! [ -e /content ] && pip install -Uqq xcube # upgrade xcube on colab
Boot L2R
from fastai.data.core import *
from xcube.l2r.all import *
Make sure we have that “beast”:
ic(torch.cuda.get_device_name(default_device()));
test_eq(torch.cuda.get_device_name(0), torch.cuda.get_device_name(default_device()))
test_eq(default_device(), torch.device(0))
print(f"GPU memory = {torch.cuda.get_device_properties(default_device()).total_memory/1024**3}GB")
ic| torch.cuda.get_device_name(default_device()): 'Quadro RTX 8000'
GPU memory = 44.99969482421875GB
Setting some environment variables:
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
Setting defaults for pandas and matplotlib:
# Set the default figure size
plt.rcParams["figure.figsize"] = (8, 4)
# Set pandas column width
pd.set_option('display.max_colwidth', None)
Altering some default jupyter settings:
from IPython.core.interactiveshell import InteractiveShell
# InteractiveShell.ast_node_interactivity = "last" # "all"
In this tutorial we will find a needle in the haystack with mutual information gain:
Mutual-Information Computation
source = untar_xxx(XURLs.MIMIC3_L2R)
source.ls()
(#11) [Path('/home/deb/.xcube/data/mimic3_l2r/info.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/code_descriptions.csv'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k_tok_lbl_info.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/code_desc.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/p_TL.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/trn_val_split.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k_tok.ft'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k_lbl.ft'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k.csv'),Path('/home/deb/.xcube/data/mimic3_l2r/scored_tokens.pth')...]
data = source/'mimic3-9k.csv'
df = pd.read_csv(data,
                 header=0,
                 names=['subject_id', 'hadm_id', 'text', 'labels', 'length', 'is_valid'],
                 dtype={'subject_id': str, 'hadm_id': str, 'text': str, 'labels': str, 'length': np.int64, 'is_valid': bool})
df[['text', 'labels']] = df[['text', 'labels']].astype(str)
len(df)
52726
df.head(3)
 | subject_id | hadm_id | text | labels | length | is_valid
---|---|---|---|---|---|---
0 | 86006 | 111912 | admission date discharge date date of birth sex f service surgery allergies patient recorded as having no known allergies to drugs attending first name3 lf chief complaint 60f on coumadin was found slightly drowsy tonight then fell down stairs paramedic found her unconscious and she was intubated w o any medication head ct shows multiple iph transferred to hospital1 for further eval major surgical or invasive procedure none past medical history her medical history is significant for hypertension osteoarthritis involving bilateral knee joints with a dependence on cane for ambulation chronic... | 801.35;348.4;805.06;807.01;998.30;707.24;E880.9;427.31;414.01;401.9;V58.61;V43.64;707.00;E878.1;96.71 | 230 | False |
1 | 85950 | 189769 | admission date discharge date service neurosurgery allergies sulfa sulfonamides attending first name3 lf chief complaint cc cc contact info major surgical or invasive procedure none history of present illness hpi 88m who lives with family had fall yesterday today had decline in mental status ems called pt was unresponsive on arrival went to osh head ct showed large r sdh pt was intubated at osh and transferred to hospital1 for further care past medical history cad s p mi in s p cabg in ventricular aneurysm at that time cath in with occluded rca unable to intervene chf reported ef 1st degre... | 852.25;E888.9;403.90;585.9;250.00;414.00;V45.81;96.71 | 304 | False |
2 | 88025 | 180431 | admission date discharge date date of birth sex f service surgery allergies no known allergies adverse drug reactions attending first name3 lf chief complaint s p fall major surgical or invasive procedure none history of present illness 45f etoh s p fall from window at feet found ambulating and slurring speech on scene intubated en route for declining mental status in the er the patient was found to be bradycardic to the s with bp of systolic she was given atropine dilantin and was started on saline past medical history unknown social history unknown family history unknown physical exam ex... | 518.81;348.4;348.82;801.25;427.89;E882;V49.86;305.00;96.71;38.93 | 359 | False |
The file 'code_desc.pkl' contains a short description for the labels.
# with open(source/'code_desc.pkl', 'rb') as f: lbs_desc = pickle.load(f)
lbs_desc = load_pickle(source/'code_desc.pkl')
assert isinstance(lbs_desc, dict)
test_eq(mapt(lbs_desc.get, ['00.93', '427.31', '00.09']),
        ('Transplant from cadaver', 'Atrial fibrillation', 'Other therapeutic ultrasound'))
Note that performing some computations in this notebook on the full dataset is going to take a lot of time. But don't worry: untar_xxx has already downloaded everything you need. You can still run the following cells if you want to generate everything from scratch; preferably, run them on a sampled dataset for quick iterations.
Run the cell below only if you want to sample from the full dataset to create a tiny dataset for the purpose of quick iterations.
Technical Point: If we want to sample for quick iterations, we need to make sure the number of data points in the sample is a multiple of bs, so that we do not have to set drop_last=True while creating the DataLoaders. This matters because we are about to do some probability computations, and dropping data points is not a good idea: the probabilities would no longer sum to 1.
bs = 8
cut = len(df) - len(df)%bs
df = df[:cut]
len(df)
52720
_arr = np.arange(0, len(df), bs)
# mask = (_arr > 4000) & (_arr < 5000)
mask = (_arr > 500) & (_arr < 1000)
_n = np.random.choice(_arr[mask], 1)
df = df.sample(n=_n, random_state=89, ignore_index=True)
len(df)
744
df.head(3)
 | subject_id | hadm_id | text | labels | length | is_valid
---|---|---|---|---|---|---
0 | 2258 | 139169 | admission date discharge date date of birth sex m service cardiothoracic surgery history of present illness the patient is a year old male with a past medical history significant for poorly controlled diabetes mellitus and hypertension as well as known coronary disease and a previous non q myocardial infarction and right coronary artery stenting in he was admitted to an outside hospital on the day prior to admission with unstable angina and found to have borderline positive troponin hypertension and st depressions in the lateral lead he was given aspirin nitrates beta blockers morphine and... | 414.01;998.31;411.1;599.0;412;V45.82;250.00;401.9;530.81;36.13;37.22;36.15;36.19;39.61;39.64;88.56;88.53;33.23;96.56;33.24;78.41 | 1271 | False |
1 | 41217 | 161582 | admission date discharge date date of birth sex m service medicine allergies no known allergies adverse drug reactions attending first name3 lf chief complaint new diagnosis of scc of base of tongue major surgical or invasive procedure egd w biopsy history of present illness yo man with h o cad heavy smoking and new diagnosis of scc of base of tongue with lymph node involvement pt was referred to dr last name stitle ent in for a rt neck mass at that time a cm rt cervical lymph node was palpated and fiberoptic laryngoscopy showed a cm rt base of tongue mass a ct and biopsy were recommended ... | 141.0;507.0;196.0;293.0;519.09;786.30;286.9;427.89;790.29;276.52;414.01;338.3;280.0;272.0;412;V69.4;V15.82;V45.82;V66.7;E879.8;E932.0;31.42;25.01;42.23;43.11;96.6;38.93;99.25;38.93 | 2743 | False |
2 | 30204 | 172114 | admission date discharge date date of birth sex f service medicine allergies etomidate norpace quinidine demerol penicillins lipitor attending doctor first name chief complaint cardiac tamponade s p pulmonary vein isolation major surgical or invasive procedure attempted pulmonary vein isolation pericardiocentesis history of present illness year old woman with a long history of paroxysmal atrial fibrillation refractory to mulitple pharmacologic interventions and multiple cardioversions who presents to the ccu with cardiac tamponade s p pulmonary vein isolation procedure past medical history... | 427.31;998.2;423.3;423.9;573.0;276.6;E878.8;37.34;37.27;37.0;37.21 | 1764 | False |
The mutual information of two jointly discrete random variables \(T\) and \(L\) is calculated as a double sum:
\[I(T;L) = \sum_{l \in \mathcal{L}} \sum_{t \in \mathcal{T}} P_{(T,L)}(t,l) \log \Bigg(\frac{P_{(T,L)}(t,l)}{P_T(t) P_L(l)} \Bigg)\]
where \(P_{(T,L)}\) is the joint probability mass function of \(T\) and \(L\), and \(P_T\) and \(P_L\) are the marginal probability mass functions of \(T\) and \(L\) respectively. To compute \(I\), the only quantity we need to compute is the joint pmf \(P_{(T,L)}\), as the marginal pmfs can be computed from it.
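Concretely, the marginals fall out of the joint pmf by summing over the other variable:

\[P_T(t) = \sum_{l \in \mathcal{L}} P_{(T,L)}(t,l), \qquad P_L(l) = \sum_{t \in \mathcal{T}} P_{(T,L)}(t,l)\]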
With regard to implementation, \(P_{(T,L)}\) can be thought of as a 2x2 tensor as shown below:
p_TL = pd.DataFrame(0, columns=['t', 'not t'], index=['lbl', 'not lbl'])
p_TL
 | t | not t
---|---|---
lbl | 0 | 0
not lbl | 0 | 0
…and we need to compute this \(P_{(T,L)}\) for every token-label pair. In other words, we need to fill in the joint dataframe shown below. Note that each cell of the joint dataframe can be thought of as further subdivided into a 2x2 grid containing the corresponding p_TL.
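Spelled out, and matching the layout of the p_TL dataframe above (rows index the label, columns the token), the 2x2 grid for a given token-label pair is

\[P_{(T,L)} = \begin{pmatrix} P(t{=}1, l{=}1) & P(t{=}0, l{=}1) \\ P(t{=}1, l{=}0) & P(t{=}0, l{=}0) \end{pmatrix}\]

where, for instance, \(P(t{=}1, l{=}1)\) is estimated as the fraction of documents that contain the token and carry the label.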
bs, chnk_sz = 8, 200
info = MutualInfoGain(df, bs=bs, chnk_sz=chnk_sz, lbs_desc=source/'code_desc.pkl') # provide lbs_desc if you have it
%%time
dsets = info.onehotify()
CPU times: user 597 ms, sys: 244 ms, total: 842 ms
Wall time: 4.21 s
toks, lbs = dsets.vocab
L(toks), L(lbs), len(toks)*len(lbs)
((#10632) ['xxunk','xxpad','xxbos','xxeos','xxfld','xxrep','xxwrep','xxup','xxmaj','the'...],
(#2150) ['008.45','008.8','009.0','009.1','031.0','031.2','038.0','038.10','038.11','038.19'...],
22858800)
joint = pd.DataFrame(0, columns=range(len(lbs)), index=range(len(toks)))
joint.index.name = 'toks (T)'
joint.columns.name = 'lbs (L)'
joint
lbs (L) | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 2140 | 2141 | 2142 | 2143 | 2144 | 2145 | 2146 | 2147 | 2148 | 2149 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
toks (T) | |||||||||||||||||||||
0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
10627 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
10628 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
10629 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
10630 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
10631 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
10632 rows × 2150 columns
We can perform tensorized computation if we think of p_TL as a 4-dim tensor of size (len(toks), len(lbs), 2, 2). Next, to estimate p_TL we just need to iterate over the dataset and, for each data point and each token-label pair, record the p_TL information in the last two dimensions of the tensor. At the end, we divide by the size of the dataset.
Some more implementation details (skip this if not interested):

- We are going to one-hot encode the dataset (both the text and labels fields in the df). This is done by onehot_dsets.
- For efficiency, in reality we are not going to iterate over the dataset one data point at a time; instead we are going to use a dataloader and perform the p_TL computation on a mini-batch (see the sketch after this list).
- Unless you are doing this in 2035, you probably do not have enough GPU RAM to fit the entire p_TL tensor of dimension (len(toks), len(lbs), 2, 2). So we are going to split the lbs dimension into chunks. (Why the lbs dimension and not toks? Because in XML datasets toks is approximately 60000, but the number of lbs can be really large, of the order of millions.) With regard to implementation, this means that instead of one dataloader we roll with multiple dataloaders, and each dataloader loads the dataset in such a way that mini-batches contain the full one-hot encoding of the text field but only a certain chunk of the one-hot encoded labels field in df. Another way to think about this is that each data point, specifically its labels, is split across multiple dataloaders. This way, once we are done iterating over one such dataloader we will have filled a certain chunk of the joint dataframe shown above, and we will have filled the entire joint only once we are done iterating over all the dataloaders.
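To make the mini-batch counting concrete, here is a minimal sketch of the idea, not the library's actual implementation (the helper name joint_counts and the random one-hot batches are ours): given a one-hot token batch of shape (bs, n_toks) and a one-hot label chunk of shape (bs, chnk_sz), the four cells of every 2x2 grid are outer-product-style counts:

import torch

def joint_counts(xb, yb):
    "Accumulate 2x2 co-occurrence counts for every (token, label) pair in a mini-batch."
    n11 = torch.einsum('bt,bl->tl', xb, yb)         # token present, label present
    n10 = torch.einsum('bt,bl->tl', xb, 1 - yb)     # token present, label absent
    n01 = torch.einsum('bt,bl->tl', 1 - xb, yb)     # token absent, label present
    n00 = torch.einsum('bt,bl->tl', 1 - xb, 1 - yb) # token absent, label absent
    # stack into (n_toks, chnk_sz, 2, 2); summing over mini-batches and dividing
    # by the dataset size turns these counts into an estimate of p_TL
    return torch.stack([torch.stack([n00, n01], -1), torch.stack([n10, n11], -1)], -2)

xb, yb = torch.randint(0, 2, (8, 5)).float(), torch.randint(0, 2, (8, 3)).float()
counts = joint_counts(xb, yb)
test_eq(counts.shape, (5, 3, 2, 2)); test_eq(counts.sum().item(), 8*5*3)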
x, y = dsets[0]
test_eq(tensor(dsets.tfms[0][-1].decode(x)), torch.where(x==1)[0])
test_eq(tensor(dsets.tfms[1][2].decode(y)), torch.where(y==1)[0])
' '.join(L(toks)[torch.where(x==1)[0]])
'xxunk xxbos the and to of was with a on in for mg no patient is he blood at name or discharge as day his one left last history were had right by this admission date that pain hospital an from p pt normal first has have which but medications up d chest o hours also well given status time care dr after stable course follow started please stitle disease known x continued days two service prior per showed artery m it q medical without namepattern1 glucose past cardiac post heart present unit physical aortic pulmonary i weeks transferred year md allergies edema due t pressure did surgery surgical number condition fluid b found procedure lower prn remained admitted soft hypertension further non coronary rate all placed diagnosis should bilaterally increased three sodium birth abdomen over bilateral aspirin illness social than old sex secondary however primary examination following some positive significant disposition floor take lung room moderate insulin bleeding namepattern4 extremities f upper back use count lasix regular therapy sinus intubated rhythm discharged underwent facility lungs continue job alert off felt sounds hematocrit heparin transfer anterior made distress contrast wound nausea four pulses down diabetes alcohol very extended postoperative followed white both removed diet creatinine hospital3 drip morning note do greater drainage male previous name11 stent intensive oriented worsening oral s1 sent currently catheterization site carotid bypass pattern1 through control weaned office albuterol extubated extremity incision warm controlled tolerated lesion increase amiodarone dictated plavix issues st levofloxacin strength physician patch difficulty infarction night next medquist36 repair increasing minimal dyspnea descending lesions laboratory morphine afebrile though hospital6 venous beta operative lateral outside later lopressor peripheral inferior operating incisions persistent masses denied go myocardial perfused cardiology murmurs bun diuresis went help evening colace lipitor began mellitus smoking nontender wheezing meds occasional ap drain cardiologist cardiothoracic commands wave knee intermittent sternal nondistended intra ambulating bruits rubs pump troponin system nitroglycerin despite lovenox tubes dependent instructed bronchoscopy meq ibuprofen main aggressive puffs came platelet referred palpable dilaudid obese decision heavy pacing stabilized secretions balloon jugular exertion minute neo follows seven leads wires many propofol frequency waves put electrocardiogram attempt paroxysmal begun grafting stenting little ten distention half treat diabetic depressions reflux coarse angina index orthopnea hepatosplenomegaly diameter diffusely urinalysis ready zantac afternoon poorly transdermal gastroesophageal borderline synephrine hemodynamic occurred quite lead refer mobility dominant moved encouraged codeine wire standpoint circumflex includes activities extremely coughing nocturnal unstable allergic stability uneventful toilet observation pole enteric blockers 1l drips incentive coated nicotine sternum weaning dye pericarditis inhalers inversions dobutamine kcl nexium packs burning participate serosanguinous complain lyme exercises spent noninsulin restenosis amaryl cks uncooperative spirometry urination isordil disabled rhonchorous nitrates build concurrently marginally'
lbs.map_ids(torch.where(y==1)[0])
(#21) ['250.00','33.23','33.24','36.13','36.15','36.19','37.22','39.61','39.64','401.9'...]
dls = info.lbs_chunked()
assert isinstance(dls[0], TfmdDL)
test_eq(len(dls), np.ceil(len(lbs)/200))
test_eq(len(dls[0]), np.ceil(len(dsets)/bs)) # drop_last is False
# test to prove that the labels for each data point are split across multiple dataloaders
lbs_0 = torch.cat([yb[0] for dl in dls for _,yb in itertools.islice(dl, 1)])
y = y.to(default_device())
test_eq(lbs_0, y)
Now let’s compute the joint_pmf table we had seen earlier.
%%time
p_TL = info.joint_pmf()
CPU times: user 8.73 s, sys: 2.58 s, total: 11.3 s
Wall time: 13.6 s
test_eq(p_TL.shape, (info.toksize, info.lblsize, 2, 2))
Technicality: p_TL is not really the joint pmf (yes, I lied before!) but it contains all the information needed to compute the joint pmf p_TxL and the mutual info gain I_TL. This computation is performed by compute:
%%time
p_T, p_L, p_TxL, H_T, H_L, I_TL = info.compute()
CPU times: user 751 ms, sys: 253 ms, total: 1 s
Wall time: 2.41 s
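For intuition, here is a minimal sketch of what compute could be doing under the hood, assuming p_TL really were the per-pair joint pmf of shape (n_toks, n_lbs, 2, 2); the actual implementation may differ, but the shapes match the tests further below, and the eps guard is the kind of distortion discussed next:

import torch

def mi_from_joint(p_TL, eps=1e-15):
    "Derive marginals, entropies and mutual info from a (n_toks, n_lbs, 2, 2) joint pmf."
    # marginals: sum out the other indicator; identical for every pairing, so take the first
    p_T = p_TL.sum(dim=-1)[:, 0].unsqueeze(-1)  # (n_toks, 2, 1)
    p_L = p_TL.sum(dim=-2)[0].unsqueeze(-2)     # (n_lbs, 1, 2)
    H_T = -(p_T * torch.log(p_T + eps)).sum(dim=(-2, -1))  # token entropies, (n_toks,)
    H_L = -(p_L * torch.log(p_L + eps)).sum(dim=(-2, -1))  # label entropies, (n_lbs,)
    p_TxL = p_T[:, None] * p_L[None]            # product of marginals, (n_toks, n_lbs, 2, 2)
    I_TL = (p_TL * torch.log((p_TL + eps)/(p_TxL + eps))).sum(dim=(-2, -1))
    return p_T, p_L, p_TxL, H_T, H_L, I_TL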
All this while, if you have been working with the sampled dataset, you can continue to do so for the rest of this notebook. But if you want a real feel of how things look, at this point you can load the pregenerated p_TL and (p_T, p_L, p_TxL, H_T, H_L, I_TL) for the full dataset which untar_xxx downloaded:
# print('\n'.join(L(source.glob("**/*.pkl")).map(str)))
# or better yet
!tree -sh -P "*.pkl" {source}
/home/deb/.xcube/data/mimic3_l2r
├── [1.2M] code_desc.pkl
├── [9.5G] info.pkl
├── [3.8G] mimic3-9k_tok_lbl_info.pkl
├── [7.6G] p_TL.pkl
└── [7.6G] trn_val_split.pkl
0 directories, 5 files
# %%time
p_TL = torch.load(source/'p_TL.pkl', map_location=torch.device('cpu'))
p_T, p_L, p_TxL, H_T, H_L, I_TL = torch.load(source/'info.pkl', map_location=torch.device('cpu'))
Make sure there aren’t any of those pesky nans or negs:
def test_nanegs(*args):
    for o in args:
        has_nans = o.isnan().any() # check for nans
        has_negs = not torch.where(o>=0, True, False).all() # check for negatives
        if has_nans: raise Exception(f"{namestr(o, globals())[0]} has nans")
        if has_negs: raise Exception(f"{namestr(o, globals())[0]} has negs")

test_fail(test_nanegs, args=(p_T, p_L, p_TxL, H_T, H_L, I_TL), contains='I_TL has negs')
Theoretically, Mutual-Info as defined here is supposed to be non-negative (this can be proved by tossing in Jensen's inequality). But practically, it turns out I_TL has some negs, because we distorted p_TL and p_TxL with eps in the I_TL computation.
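For completeness, the Jensen step: since \(\log\) is concave,

\[-I(T;L) = \sum_{t,l} P_{(T,L)}(t,l)\,\log\frac{P_T(t)\,P_L(l)}{P_{(T,L)}(t,l)} \;\le\; \log\sum_{t,l} P_T(t)\,P_L(l) = \log 1 = 0,\]

so with exact pmfs \(I(T;L) \ge 0\); the tiny negatives below are purely an eps artifact.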
torch.topk(I_TL.flatten(), 10, largest=False)
torch.return_types.topk(
values=TensorMultiCategory([-1.9016e-07, -1.8314e-07, -1.8314e-07, -1.7385e-07,
-1.7277e-07, -1.7277e-07, -1.6798e-07, -1.6798e-07,
-1.6798e-07, -1.6767e-07]),
indices=TensorMultiCategory([22423614, 2735913, 2731838, 1911099, 6393113, 6389159,
6695355, 6695018, 6693073, 32253137]))
howmany = torch.where(I_TL < 0, True, False).sum().item()
negs = torch.where(I_TL < 0, I_TL, I_TL.new_zeros(I_TL.shape))
negs.sum()/howmany
TensorMultiCategory(-3.9054e-08)
Those negs on average are pretty close to zero. So we need not worry. Let’s roll!
test_eq(p_TL.shape, (info.toksize, info.lblsize, 2, 2))
test_eq(p_T.shape, (info.toksize, 2, 1))
test_eq(p_L.shape, (info.lblsize, 1, 2))
test_eq(p_TxL.shape, (info.toksize, info.lblsize, 2, 2))
test_eq(H_T.shape, [info.toksize])
test_eq(H_L.shape, [info.lblsize])
test_eq(I_TL.shape, (info.toksize, info.lblsize))
eps = I_TL.new_empty(1).fill_(1e-15)
info_lbl_entropy = I_TL/(H_L + eps)
info_jaccard = I_TL/(H_T.unsqueeze(-1) + H_L.unsqueeze(0) - I_TL + eps)
assert not info_lbl_entropy.isnan().any(); assert not info_jaccard.isnan().any()
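In equations, the two scores computed above are

\[\frac{I(T;L)}{H(L)} \qquad \text{and} \qquad \frac{I(T;L)}{H(T) + H(L) - I(T;L)} = \frac{I(T;L)}{H(T,L)},\]

the latter being the information-theoretic analogue of the Jaccard index (its denominator is the joint entropy \(H(T,L)\)), so both scores live in \([0,1]\) up to the eps distortion.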
l2r_bootstrap = {'toks': toks, 'lbs': lbs, 'mut_info_lbl_entropy': info_lbl_entropy, 'mutual_info_jaccard': info_jaccard}
l2r_bootstrap for the full dataset was downloaded by untar_xxx in boot_path. You can load it up in the following cell. l2r_bootstrap will be used to bootstrap our learning-to-rank model.
Save those Mutual Information Gain values
# l2r_bootstrap = torch.load(boot_path)
# l2r_bootstrap['info_jaccard'] = l2r_bootstrap.pop('mutual_info_jaccard')
# globals().update(l2r_bootstrap)
# info_jaccard.shape
Let’s take a look at the Mutual Information Gain (I_TL) for each of the labels:
with tempfile.TemporaryDirectory() as tmpdir:
    args = (p_TL, p_T, p_L, info_jaccard, H_T, H_L)
    kwargs = {'k':10, 'save_as': Path(tmpdir)/'mut_info_jaccard.ft'}
    df_info = info.show(*args, **kwargs)
    assert (Path(tmpdir)/'mut_info_jaccard.ft').exists()
df_info.head()
 | label | freq | prob | entropy | description | top-k (token, prob, entropy, joint, info)
---|---|---|---|---|---|---
0 | 008.45 | 17 | 0.022849 | 0.108931 | Intestinal infection due to clostridium difficile | [['difficile' '0.05107527' '0.20166922' '0.014784946' '0.11373689']\n ['cdiff' '0.010752688' '0.05943229' '0.005376344' '0.088108845']\n ['loosely' '0.002688172' '0.018595558' '0.002688172' '0.08804317']\n ['reformatted' '0.00672043' '0.04031744' '0.004032258' '0.08063226']\n ['colitis' '0.05913979' '0.22459403' '0.01344086' '0.07943697']\n ['flagyl' '0.17069893' '0.45699275' '0.021505376' '0.064542644']\n ['enteritis' '0.004032258' '0.026255678' '0.002688172' '0.061064955']\n ['ogt' '0.004032258' '0.026255678' '0.002688172' '0.061064955']\n ['retardation' '0.004032258' '0.026255678' '0.00... |
1 | 008.8 | 2 | 0.002688 | 0.018596 | Intestinal infection due to other organism, not elsewhere classified | [['vasotec' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['gastroenteritis' '0.010752688' '0.05943235' '0.002688172' '0.19164896']\n ['tachyarrhythmia' '0.004032258' '0.026255678' '0.001344086'\n '0.14864387']\n ['bumex' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['rta' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['neoral' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['probenecid' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['electronics' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['cauterized' '0.004032258... |
2 | 009.0 | 1 | 0.001344 | 0.010230 | Infectious colitis, enteritis, and gastroenteritis | [['presacral' '0.001344086' '0.010230334' '0.001344086' '0.9999958']\n ['vibrio' '0.001344086' '0.010230334' '0.001344086' '0.9999958']\n ['yersinia' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['ova' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['parasites' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['resucitation' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['exlap' '0.004032258' '0.026255678' '0.001344086' '0.26589572']\n ['tenting' '0.004032258' '0.026255678' '0.001344086' '0.26589572']\n ['adhesiolysis' '0.004032258' '0.026255678... |
3 | 009.1 | 2 | 0.002688 | 0.018596 | Colitis, enteritis, and gastroenteritis of presumed infectious origin | [['44yf' '0.001344086' '0.010230334' '0.001344086' '0.40896738']\n ['ischioanal' '0.001344086' '0.010230334' '0.001344086' '0.40896738']\n ['perianal' '0.001344086' '0.010230334' '0.001344086' '0.40896738']\n ['paraplegia' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['hunger' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['intrathecal' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['paraplegic' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['vicarious' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['spasticity' '0.002688172' '0.01859... |
4 | 031.0 | 1 | 0.001344 | 0.010230 | Pulmonary diseases due to other mycobacteria | [['gist' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['disrupted' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['77f' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['eroding' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['vaginitis' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['circumscribed' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['discern' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['inseparable' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['pearls' '0.002688172' '0.018595558... |
Let’s look at those Mutual-Information Gain values:
mask = (df_info.freq>50) & (df_info.freq<150)
# with pd.option_context('display.max_colwidth', 100):
# pd.reset_option('all')
df_info = df_info[mask].reset_index(drop=True)
len(df_info)
29
The dataframe below shows the top 10 tokens (based on the mutual-info-gain values) for labels that are rare (freq between 50 and 150). Feel free to ChatGPT the label descriptions and the tokens to find out if we’re able to find the needle in a haystack.
pd.set_option('display.max_colwidth', None)
df_info.head()
 | label | freq | prob | entropy | description | top-k (token, prob, entropy, joint, info)
---|---|---|---|---|---|---
0 | 038.9 | 52 | 0.069892 | 0.253361 | Unspecified septicemia | [['pressors' '0.13037634' '0.38710147' '0.041666668' '0.079481095']\n ['expired' '0.10215054' '0.32978266' '0.034946237' '0.073719725']\n ['septic' '0.06586021' '0.24279651' '0.024193548' '0.05889168']\n ['spectrum' '0.061827958' '0.23196787' '0.021505376' '0.04939346']\n ['levophed' '0.06989247' '0.25336072' '0.022849463' '0.04748282']\n ['sepsis' '0.18548387' '0.47960785' '0.041666668' '0.04552333']\n ['tpn' '0.0483871' '0.1937385' '0.017473118' '0.04391181']\n ['lactate' '0.26478493' '0.5780026' '0.049731184' '0.041423406']\n ['rescusitated' '0.004032258' '0.026255678' '0.004032258' '0.04032902']\n ['broad' '0.07123656' '0.25682586' '0.021505376' '0.040083304']] |
1 | 244.9 | 71 | 0.095430 | 0.314924 | Unspecified hypothyroidism | [['hypothyroidism' '0.10887097' '0.34414828' '0.083333336' '0.4098433']\n ['levothyroxine' '0.10752688' '0.34131324' '0.07392473' '0.28809547']\n ['synthroid' '0.04973118' '0.19772297' '0.033602152' '0.11988581']\n ['levoxyl' '0.018817205' '0.093399525' '0.016129032' '0.08422138']\n ['hypothyroid' '0.014784946' '0.07698107' '0.01344086' '0.07722074']\n ['mcg' '0.36021507' '0.6535418' '0.076612905' '0.047165737']\n ['88mcg' '0.005376344' '0.03345727' '0.005376344' '0.038052656']\n ['cystitis' '0.004032258' '0.026255678' '0.004032258' '0.028801844']\n ['kyphotic' '0.004032258' '0.026255678' '0.004032258' '0.028801844']\n ['cvat' '0.004032258' '0.026255678' '0.004032258' '0.028801844']] |
2 | 250.00 | 115 | 0.154570 | 0.430555 | type II diabetes mellitus [non-insulin dependent type] [NIDDM type] [adult-onset type] or unspecified type, not stated as uncontrolled, without mention of complication | [['diabetes' '0.2876344' '0.60002' '0.1155914' '0.09059631']\n ['metformin' '0.0672043' '0.24634655' '0.045698926' '0.08375815']\n ['glyburide' '0.037634406' '0.1603519' '0.028225806' '0.063347906']\n ['dm' '0.14784946' '0.4189606' '0.06317204' '0.048482798']\n ['mellitus' '0.14919354' '0.42130768' '0.061827958' '0.044503253']\n ['noninsulin' '0.010752688' '0.05943235' '0.010752688' '0.043445654']\n ['tricor' '0.00672043' '0.04031744' '0.00672043' '0.027659079']\n ['avandia' '0.012096774' '0.06542839' '0.0094086025' '0.024442347']\n ['insulin' '0.2795699' '0.59254664' '0.08064516' '0.024319947']\n ['glipizide' '0.021505376' '0.103841364' '0.01344086' '0.024206813']] |
3 | 272.0 | 91 | 0.122312 | 0.371506 | Pure hypercholesterolemia | [['hypercholesterolemia' '0.13978495' '0.40457296' '0.07795699'\n '0.14934917']\n ['lipitor' '0.17204301' '0.45911095' '0.049731184' '0.027298862']\n ['aspirin' '0.54569894' '0.68896455' '0.10080645' '0.022963593']\n ['crestor' '0.016129032' '0.08256498' '0.0094086025' '0.022421718']\n ['carotids' '0.02688172' '0.12372976' '0.012096774' '0.018986586']\n ['nonreactive' '0.08736559' '0.29640004' '0.0' '0.018241761']\n ['gallop' '0.010752688' '0.05943235' '0.00672043' '0.018127378']\n ['crossclamp' '0.010752688' '0.05943235' '0.00672043' '0.018127378']\n ['palate' '0.086021505' '0.29323512' '0.0' '0.018028865']\n ['mrs' '0.037634406' '0.1603519' '0.014784946' '0.01788708']] |
4 | 272.4 | 131 | 0.176075 | 0.465390 | Other and unspecified hyperlipidemia | [['hyperlipidemia' '0.17876343' '0.46951324' '0.11155914' '0.14867449']\n ['dyslipidemia' '0.0672043' '0.24634655' '0.04032258' '0.048867557']\n ['medquist36' '0.28225806' '0.59507334' '0.004032258' '0.048413806']\n ['brief' '0.75' '0.56233513' '0.1733871' '0.045796935']\n ['invasive' '0.72983867' '0.5834192' '0.17204301' '0.045730278']\n ['major' '0.7338709' '0.5793706' '0.17204301' '0.044849273']\n ['job' '0.29435483' '0.606005' '0.00672043' '0.04331313']\n ['attending' '0.74596775' '0.56672186' '0.17204301' '0.042243805']\n ['dictated' '0.29435483' '0.606005' '0.008064516' '0.039907183']\n ['exam' '0.77956986' '0.5274518' '0.1733871' '0.03951623']] |
# pd.reset_option('all')
# df_info.to_excel('jaccard.xls', index=False)
Scratchpad
from fastai.data.transforms import *
source = untar_xxx(XURLs.MIMIC3_L2R)
boot_path = source/'mimic3-9k_tok_lbl_info.pkl'
assert boot_path.exists()
l2r_bootstrap = torch.load(boot_path)
toks, lbs, info = mapt(l2r_bootstrap.get, ['toks', 'lbs', 'mutual_info_jaccard'])
print(toks[-2:]) # last two places have 'xxfake'
# toks = CategoryMap(toks, sort=False)
lbs_des = load_pickle(source/'code_desc.pkl')
assert isinstance(lbs_des, dict)
test_eq(info.shape, (len(toks), len(lbs))) # last two places have 'xxfake'
['xxfake', 'xxfake']
L(toks[89:99]), lbs[:10]
((#10) ['course','follow','after','disease','stitle','needed','known','capsule','refills','started'],
(#10) ['003.0','003.1','003.8','003.9','004.1','004.8','004.9','005.1','005.81','005.9'])
icd_codes = ['008.45', '009.0', '244.9', '250.00']
d = dict(zip(icd_codes, mapt(lbs_des.get, icd_codes)))
df_des = pd.DataFrame(d, index=range(1))
lbs_idxs = lbs.map_objs(icd_codes)
top_infos = info[:, lbs_idxs].topk(dim=0, k=10, largest=True).values.cpu()
top_idxs = info[:, lbs_idxs].topk(dim=0, k=10, largest=True).indices.cpu().long()
toks = array(toks).astype(str)
# toks.map_ids(tok_idxs[:, 0])
df_toks = pd.DataFrame(toks[top_idxs], columns=icd_codes)
df_infos = pd.DataFrame(top_infos, columns=icd_codes)
display(df_des, df_toks, df_infos)
 | 008.45 | 009.0 | 244.9 | 250.00
---|---|---|---|---
0 | Intestinal infection due to clostridium difficile | Infectious colitis, enteritis, and gastroenteritis | Unspecified hypothyroidism | type II diabetes mellitus [non-insulin dependent type] [NIDDM type] [adult-onset type] or unspecified type, not stated as uncontrolled, without mention of complication |
 | 008.45 | 009.0 | 244.9 | 250.00
---|---|---|---|---
0 | difficile | ileoloop | hypothyroidism | diabetes |
1 | colitis | entercort | levothyroxine | metformin |
2 | diff | coumidin | synthroid | mellitus |
3 | clostridium | 33u | hypothyroid | dm |
4 | metronidazole | 8x | mcg | insulin |
5 | flagyl | proctocolitis | levoxyl | glyburide |
6 | toxin | chux | tsh | glipizide |
7 | cdiff | metronidzole | t4 | dm2 |
8 | megacolon | bayview | 88mcg | sliding |
9 | feces | 117bpm | 50mcg | scale |
 | 008.45 | 009.0 | 244.9 | 250.00
---|---|---|---|---
0 | 0.121274 | 0.027056 | 0.353851 | 0.089817 |
1 | 0.097799 | 0.027056 | 0.319011 | 0.085321 |
2 | 0.093017 | 0.027056 | 0.093766 | 0.063133 |
3 | 0.086582 | 0.027056 | 0.076361 | 0.052869 |
4 | 0.064245 | 0.023884 | 0.063794 | 0.043760 |
5 | 0.061297 | 0.023884 | 0.052179 | 0.040929 |
6 | 0.057576 | 0.023737 | 0.020855 | 0.040014 |
7 | 0.042272 | 0.023737 | 0.014761 | 0.033789 |
8 | 0.032625 | 0.023737 | 0.013494 | 0.030756 |
9 | 0.026234 | 0.023737 | 0.012336 | 0.023395 |
vocab = load_pickle(Path.cwd()/'tmp/models/mimic3-9k_clas_full_vocab.pkl')
toks = L(toks, use_list=True)
The tokens which are in the xml vocab but which we do not have any ‘info’ on:
not_found = L(set(vocab[0]).difference(set(toks)))
not_found
(#20) ['foi','luteinizing','dobhoof','theses','q2day','promiscuity','dissension','sharpio','mhc','remiained'...]
test_fail(lambda : toks.index('unrmarkable'), contains='is not in list')
The tokens which we have info about but which were not present in the xml vocab:
set(toks).difference(vocab[0])
set()
Thankfully, we have info for all the labels in the xml vocab:
test_shuffled(vocab[1], lbs)
Now we need to create a mapping between the indices of the xml vocab and the information gain vocab:
def xml2info(xml_vocab, info_vocab):
    "Creates a mapping between the indices of the xml vocab and the information-gain vocab"
    xml2info = {i: info_vocab.index(o) if o in info_vocab else np.inf for i,o in enumerate(xml_vocab)}
    xml2info_notfnd = [o for o in xml2info if xml2info[o] is np.inf]
    return xml2info, xml2info_notfnd

toks_xml2info, toks_notfnd = xml2info(vocab[0], toks)
toks_found = set(toks_xml2info).difference(set(toks_notfnd))
test_shuffled(array(vocab[0])[toks_notfnd], not_found)
some_xml_idxs = np.random.choice(array(L(toks_found)), size=10)
some_xml_toks = array(vocab[0])[some_xml_idxs]
corres_info_idxs = L(map(toks_xml2info.get, some_xml_idxs))
corres_info_toks = array(toks)[corres_info_idxs]
assert all_equal(some_xml_toks, corres_info_toks)
lbs_xml2info, lbs_notfnd = xml2info(L(vocab[1]), L(lbs))
lbs_found = set(lbs_xml2info).difference(set(lbs_notfnd))
some_xml_idxs = np.random.choice(array(L(lbs_found)), size=10)
some_xml_lbs = array(vocab[1])[some_xml_idxs]
corres_info_idxs = L(map(lbs_xml2info.get, some_xml_idxs))
corres_info_lbs = array(lbs)[corres_info_idxs]
assert all_equal(some_xml_lbs, corres_info_lbs)