Boot L2R

Bootstrapping a learning-to-rank model
! [ -e /content ] && pip install -Uqq xcube # upgrade xcube on colab
from fastai.data.core import *
from xcube.l2r.all import *

Make sure we have that “beast”:

ic(torch.cuda.get_device_name(default_device()));
test_eq(torch.cuda.get_device_name(0), torch.cuda.get_device_name(default_device()))
test_eq(default_device(), torch.device(0))
print(f"GPU memory = {torch.cuda.get_device_properties(default_device()).total_memory/1024**3}GB")
ic| torch.cuda.get_device_name(default_device()): 'Quadro RTX 8000'
GPU memory = 44.99969482421875GB

Setting some environment variables:

# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

Setting defaults for pandas and matplotlib:

# Set the default figure size
plt.rcParams["figure.figsize"] = (8, 4)
# Set pandas column width
pd.set_option('display.max_colwidth', None)

Altering some default jupyter settings:

from IPython.core.interactiveshell import InteractiveShell
# InteractiveShell.ast_node_interactivity = "last" # "all"

In this tutorial we will find a needle in a haystack using mutual information gain:

Mutual-Information Computation

source = untar_xxx(XURLs.MIMIC3_L2R)
(#11) [Path('/home/deb/.xcube/data/mimic3_l2r/info.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/code_descriptions.csv'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k_tok_lbl_info.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/code_desc.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/p_TL.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/trn_val_split.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k_tok.ft'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k_lbl.ft'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k.csv'),Path('/home/deb/.xcube/data/mimic3_l2r/scored_tokens.pth')...]
data = source/'mimic3-9k.csv'
df = pd.read_csv(data,
                 header=0,
                 names=['subject_id', 'hadm_id', 'text', 'labels', 'length', 'is_valid'],
                 dtype={'subject_id': str, 'hadm_id': str, 'text': str, 'labels': str, 'length': np.int64, 'is_valid': bool})
df[['text', 'labels']] = df[['text', 'labels']].astype(str)
len(df)
52726
df.head(3)
subject_id hadm_id text labels length is_valid
0 86006 111912 admission date discharge date date of birth sex f service surgery allergies patient recorded as having no known allergies to drugs attending first name3 lf chief complaint 60f on coumadin was found slightly drowsy tonight then fell down stairs paramedic found her unconscious and she was intubated w o any medication head ct shows multiple iph transferred to hospital1 for further eval major surgical or invasive procedure none past medical history her medical history is significant for hypertension osteoarthritis involving bilateral knee joints with a dependence on cane for ambulation chronic... 801.35;348.4;805.06;807.01;998.30;707.24;E880.9;427.31;414.01;401.9;V58.61;V43.64;707.00;E878.1;96.71 230 False
1 85950 189769 admission date discharge date service neurosurgery allergies sulfa sulfonamides attending first name3 lf chief complaint cc cc contact info major surgical or invasive procedure none history of present illness hpi 88m who lives with family had fall yesterday today had decline in mental status ems called pt was unresponsive on arrival went to osh head ct showed large r sdh pt was intubated at osh and transferred to hospital1 for further care past medical history cad s p mi in s p cabg in ventricular aneurysm at that time cath in with occluded rca unable to intervene chf reported ef 1st degre... 852.25;E888.9;403.90;585.9;250.00;414.00;V45.81;96.71 304 False
2 88025 180431 admission date discharge date date of birth sex f service surgery allergies no known allergies adverse drug reactions attending first name3 lf chief complaint s p fall major surgical or invasive procedure none history of present illness 45f etoh s p fall from window at feet found ambulating and slurring speech on scene intubated en route for declining mental status in the er the patient was found to be bradycardic to the s with bp of systolic she was given atropine dilantin and was started on saline past medical history unknown social history unknown family history unknown physical exam ex... 518.81;348.4;348.82;801.25;427.89;E882;V49.86;305.00;96.71;38.93 359 False

The file 'code_desc.pkl' contains a short description for the labels.

# with open(source/'code_desc.pkl', 'rb') as f: lbs_desc = pickle.load(f)
lbs_desc = load_pickle(source/'code_desc.pkl')
assert isinstance(lbs_desc, dict)
test_eq(mapt(lbs_desc.get, ['00.93', '427.31', '00.09']), ('Transplant from cadaver',
                                                          'Atrial fibrillation',
                                                          'Other therapeutic ultrasound'))

Note that performing some computations in this notebook on the full dataset is going to take a lot of time. But don’t worry, untar_xxx has already downloaded everything you need. You can still run the following cells if you want to generate everything from scratch; preferably, run them on a sampled dataset for quick iterations.

Run the cell below only if you want to sample from the full dataset to create a tiny dataset for the purpose of quick iterations.

Technical Point: If we want to sample to perform quick iterations, we need to make sure the number of data points in the sample is a multiple of bs, so that we do not have to set drop_last=True while creating the DataLoaders. This is because we are about to do some probability computations, and dropping data points is not a good idea as the probabilities would no longer sum to 1.

bs = 8
cut = len(df) - len(df)%bs
df = df[:cut]
len(df)
52720
# Candidate sample sizes: multiples of `bs` strictly between 500 and 1000,
# so the sampled df still has a length divisible by `bs` (no drop_last needed).
_arr = np.arange(0, len(df), bs)
# mask = (_arr > 4000) & (_arr < 5000)
mask = (_arr > 500) & (_arr < 1000)
# `np.random.choice(..., 1)` would return a length-1 array; `DataFrame.sample` expects an int `n`.
_n = int(np.random.choice(_arr[mask]))
df = df.sample(n=_n, random_state=89, ignore_index=True)
len(df)
744
df.head(3)
subject_id hadm_id text labels length is_valid
0 2258 139169 admission date discharge date date of birth sex m service cardiothoracic surgery history of present illness the patient is a year old male with a past medical history significant for poorly controlled diabetes mellitus and hypertension as well as known coronary disease and a previous non q myocardial infarction and right coronary artery stenting in he was admitted to an outside hospital on the day prior to admission with unstable angina and found to have borderline positive troponin hypertension and st depressions in the lateral lead he was given aspirin nitrates beta blockers morphine and... 414.01;998.31;411.1;599.0;412;V45.82;250.00;401.9;530.81;36.13;37.22;36.15;36.19;39.61;39.64;88.56;88.53;33.23;96.56;33.24;78.41 1271 False
1 41217 161582 admission date discharge date date of birth sex m service medicine allergies no known allergies adverse drug reactions attending first name3 lf chief complaint new diagnosis of scc of base of tongue major surgical or invasive procedure egd w biopsy history of present illness yo man with h o cad heavy smoking and new diagnosis of scc of base of tongue with lymph node involvement pt was referred to dr last name stitle ent in for a rt neck mass at that time a cm rt cervical lymph node was palpated and fiberoptic laryngoscopy showed a cm rt base of tongue mass a ct and biopsy were recommended ... 141.0;507.0;196.0;293.0;519.09;786.30;286.9;427.89;790.29;276.52;414.01;338.3;280.0;272.0;412;V69.4;V15.82;V45.82;V66.7;E879.8;E932.0;31.42;25.01;42.23;43.11;96.6;38.93;99.25;38.93 2743 False
2 30204 172114 admission date discharge date date of birth sex f service medicine allergies etomidate norpace quinidine demerol penicillins lipitor attending doctor first name chief complaint cardiac tamponade s p pulmonary vein isolation major surgical or invasive procedure attempted pulmonary vein isolation pericardiocentesis history of present illness year old woman with a long history of paroxysmal atrial fibrillation refractory to mulitple pharmacologic interventions and multiple cardioversions who presents to the ccu with cardiac tamponade s p pulmonary vein isolation procedure past medical history... 427.31;998.2;423.3;423.9;573.0;276.6;E878.8;37.34;37.27;37.0;37.21 1764 False

Mutual Information

Pictorial representation of simple neural network

The mutual information of two jointly discrete random variables \(T\) (a token) and \(L\) (a label) is calculated as a double sum:

\[I(T;L) = \sum_{l \in \mathcal{L}} \sum_{t \in \mathcal{T}} P_{(T,L)}(t,l) \log \Bigg(\frac{P_{(T,L)}(t,l)}{P_T(t) P_L(l)} \Bigg)\]

where \(P_{(T,L)}\) is the joint probability mass function of \(T\) and \(L\), and \(P_T\) and \(P_L\) are the marginal probability mass functions of \(T\) and \(L\) respectively. To compute \(I\), the only quantity we need to compute is the joint pmf \(P_{(T,L)}\), as the marginal pmfs can be computed from the joint pmf.

With regard to implementation, \(P_{(T,L)}\) can be thought of as a 2x2 tensor as shown below:

p_TL = pd.DataFrame(0, columns=['t', 'not t'], index=['lbl', 'not lbl'])
p_TL
t not t
lbl 0 0
not lbl 0 0

…and we need to compute this \(P_{(T,L)}\) for every token-label pair. In other words, we need to fill in the joint dataframe shown below. Note that each cell in the joint dataframe can be thought of as being further subdivided into a 2x2 grid containing the corresponding p_TL.

bs, chnk_sz = 8, 200
info = MutualInfoGain(df, bs=bs, chnk_sz=chnk_sz, lbs_desc=source/'code_desc.pkl') # provide lbs_desc if you have it
dsets = info.onehotify()
CPU times: user 597 ms, sys: 244 ms, total: 842 ms
Wall time: 4.21 s
toks, lbs = dsets.vocab
L(toks), L(lbs), len(toks)*len(lbs)
((#10632) ['xxunk','xxpad','xxbos','xxeos','xxfld','xxrep','xxwrep','xxup','xxmaj','the'...],
 (#2150) ['008.45','008.8','009.0','009.1','031.0','031.2','038.0','038.10','038.11','038.19'...],
 22858800)
joint = pd.DataFrame(0, columns=range(len(lbs)), index=range(len(toks)))
joint.index.name = 'toks (T)'
joint.columns.name = 'lbs (L)'
joint
lbs (L) 0 1 2 3 4 5 6 7 8 9 ... 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149
toks (T)
0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
2 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
3 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
4 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
10627 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
10628 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
10629 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
10630 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
10631 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0

10632 rows × 2150 columns

We can perform a tensorized computation if we think of p_TL as a 4-dim tensor of size (len(toks), len(lbs), 2, 2). Next, to estimate p_TL we just need to iterate over the dataset and, for each data point and each token-label pair, record the p_TL information in the last two dimensions of the tensor p_TL. At the end, we divide by the size of the dataset.

Some more implementation details (skip this if not interested):

  • We are going to one-hot encode the dataset (both the text and labels fields in the df). This is done by onehotify.
  • For efficiency, in reality we are not going to iterate over the dataset one data point at a time; instead we are going to use a dataloader and perform the p_TL computation on a mini-batch.
  • Unless you are doing this in 2035 you probably do not have enough GPU RAM to fit the entire p_TL tensor of dimension (len(toks), len(lbs), 2, 2). So we are going to split the lbs dimension into chunks. (Why the lbs dimension and not the toks? Because in XML datasets the toks number approximately 60000, but the number of lbs could be really large, of the order of millions.) With regard to implementation, this means that instead of one dataloader we roll with multiple dataloaders. Each dataloader loads the dataset in a way that its mini-batches contain the full one-hot encoding of the text field but only a certain chunk of the one-hot encoded labels field in df. Another way to think about this is that each data point, specifically its labels, is split across multiple dataloaders. This way, once we are done iterating over one such dataloader we will have filled a certain chunk of the joint dataframe shown above, and we fill the entire joint only once we are done iterating over all the dataloaders.
x, y = dsets[0]
test_eq(tensor(dsets.tfms[1][2].decode(y)), torch.where(y==1)[0])
test_eq(tensor(dsets.tfms[0][-1].decode(x)), torch.where(x==1)[0])
' '.join(L(toks)[torch.where(x==1)[0]])
'xxunk xxbos the and to of was with a on in for mg no patient is he blood at name or discharge as day his one left last history were had right by this admission date that pain hospital an from p pt normal first has have which but medications up d chest o hours also well given status time care dr after stable course follow started please stitle disease known x continued days two service prior per showed artery m it q medical without namepattern1 glucose past cardiac post heart present unit physical aortic pulmonary i weeks transferred year md allergies edema due t pressure did surgery surgical number condition fluid b found procedure lower prn remained admitted soft hypertension further non coronary rate all placed diagnosis should bilaterally increased three sodium birth abdomen over bilateral aspirin illness social than old sex secondary however primary examination following some positive significant disposition floor take lung room moderate insulin bleeding namepattern4 extremities f upper back use count lasix regular therapy sinus intubated rhythm discharged underwent facility lungs continue job alert off felt sounds hematocrit heparin transfer anterior made distress contrast wound nausea four pulses down diabetes alcohol very extended postoperative followed white both removed diet creatinine hospital3 drip morning note do greater drainage male previous name11 stent intensive oriented worsening oral s1 sent currently catheterization site carotid bypass pattern1 through control weaned office albuterol extubated extremity incision warm controlled tolerated lesion increase amiodarone dictated plavix issues st levofloxacin strength physician patch difficulty infarction night next medquist36 repair increasing minimal dyspnea descending lesions laboratory morphine afebrile though hospital6 venous beta operative lateral outside later lopressor peripheral inferior operating incisions persistent masses denied go myocardial perfused cardiology murmurs bun diuresis went 
help evening colace lipitor began mellitus smoking nontender wheezing meds occasional ap drain cardiologist cardiothoracic commands wave knee intermittent sternal nondistended intra ambulating bruits rubs pump troponin system nitroglycerin despite lovenox tubes dependent instructed bronchoscopy meq ibuprofen main aggressive puffs came platelet referred palpable dilaudid obese decision heavy pacing stabilized secretions balloon jugular exertion minute neo follows seven leads wires many propofol frequency waves put electrocardiogram attempt paroxysmal begun grafting stenting little ten distention half treat diabetic depressions reflux coarse angina index orthopnea hepatosplenomegaly diameter diffusely urinalysis ready zantac afternoon poorly transdermal gastroesophageal borderline synephrine hemodynamic occurred quite lead refer mobility dominant moved encouraged codeine wire standpoint circumflex includes activities extremely coughing nocturnal unstable allergic stability uneventful toilet observation pole enteric blockers 1l drips incentive coated nicotine sternum weaning dye pericarditis inhalers inversions dobutamine kcl nexium packs burning participate serosanguinous complain lyme exercises spent noninsulin restenosis amaryl cks uncooperative spirometry urination isordil disabled rhonchorous nitrates build concurrently marginally'
lbs.map_ids(torch.where(y==1)[0])
(#21) ['250.00','33.23','33.24','36.13','36.15','36.19','37.22','39.61','39.64','401.9'...]
dls = info.lbs_chunked()
assert isinstance(dls[0], TfmdDL)
# One dataloader per chunk of labels; tie the check to `chnk_sz` (the value given to
# MutualInfoGain) so it keeps holding if the chunk size is changed.
test_eq(len(dls),  np.ceil(len(lbs)/chnk_sz))
test_eq(len(dls[0]), np.ceil(len(dsets)/bs)) # drop_last is False
# test to prove that the labels for each data point are split across multiple dataloaders:
# concatenating the first item's label chunk from every dataloader rebuilds its full label vector
lbs_0 = torch.cat([yb[0] for dl in dls for _,yb in itertools.islice(dl, 1)])
y = y.to(default_device())
test_eq(lbs_0, y)

Now let’s compute the joint_pmf table we had seen earlier.

p_TL = info.joint_pmf()
100.00% [11/11 00:13<00:00]
CPU times: user 8.73 s, sys: 2.58 s, total: 11.3 s
Wall time: 13.6 s
test_eq(p_TL.shape, (info.toksize, info.lblsize, 2, 2))

Technicality: p_TL is not really the joint pmf (yes, I lied before!) but it contains all the information needed to compute the joint pmf p_TxL and the mutual information gain I_TL. This computation is performed by compute:

p_T, p_L, p_TxL, H_T, H_L, I_TL = info.compute()
CPU times: user 751 ms, sys: 253 ms, total: 1 s
Wall time: 2.41 s

All this while, if you have been working with the sampled dataset, you can continue to do so for the rest of this notebook. But if you want a real feel for how things look, at this point you can load the pregenerated p_TL and (p_T, p_L, p_TxL, H_T, H_L, I_TL) for the full dataset, which untar_xxx downloaded:

# print('\n'.join(L(source.glob("**/*.pkl")).map(str)))
# or better yet
!tree -sh -P "*.pkl" {source}
/home/deb/.xcube/data/mimic3_l2r
├── [1.2M]  code_desc.pkl
├── [9.5G]  info.pkl
├── [3.8G]  mimic3-9k_tok_lbl_info.pkl
├── [7.6G]  p_TL.pkl
└── [7.6G]  trn_val_split.pkl

0 directories, 5 files
# %%time 
p_TL = torch.load(source/'p_TL.pkl', map_location=torch.device('cpu'))
p_T, p_L, p_TxL, H_T, H_L, I_TL = torch.load(source/'info.pkl', map_location=torch.device('cpu'))

Make sure there aren’t any of those pesky nans or negs:

def test_nanegs(*args):
    """Raise if any of the tensors in `args` contains a NaN or a negative entry.

    Tensors are checked in order; the first offending one raises a `ValueError`
    (a subclass of `Exception`) whose message is "<name> has nans" or "<name> has negs".
    `<name>` is the global variable bound to that tensor, or "argument <i>" if none is.
    """
    for i, o in enumerate(args):
        # `.any()`: a single NaN is a failure. (`.all()` only fired when *every* entry was NaN,
        # so partial NaNs were misreported as negatives since `nan >= 0` is False.)
        has_nans = bool(o.isnan().any())
        has_negs = bool((o < 0).any())
        if not (has_nans or has_negs): continue
        # Same lookup as `namestr(o, globals())`, but with a fallback instead of an IndexError
        # when `o` is not bound to a global name.
        name = next((k for k, v in globals().items() if v is o), f"argument {i}")
        if has_nans: raise ValueError(f"{name} has nans")
        raise ValueError(f"{name} has negs")
test_fail(test_nanegs, args=(p_T, p_L, p_TxL, H_T, H_L, I_TL), contains='I_TL has negs')

Theoretically, mutual information as defined here is supposed to be non-negative (this can be proved using Jensen’s inequality). But in practice it turns out I_TL has some negatives, because we distorted p_TL and p_TxL with eps in the I_TL computation.

torch.topk(I_TL.flatten(), 10, largest=False)
torch.return_types.topk(
values=TensorMultiCategory([-1.9016e-07, -1.8314e-07, -1.8314e-07, -1.7385e-07,
                     -1.7277e-07, -1.7277e-07, -1.6798e-07, -1.6798e-07,
                     -1.6798e-07, -1.6767e-07]),
indices=TensorMultiCategory([22423614,  2735913,  2731838,  1911099,  6393113,  6389159,
                      6695355,  6695018,  6693073, 32253137]))
howmany = torch.where(I_TL < 0, True, False).sum().item()
negs = torch.where(I_TL < 0, I_TL, I_TL.new_zeros(I_TL.shape))
negs.sum()/howmany
TensorMultiCategory(-3.9054e-08)

On average those negatives are pretty close to zero, so we need not worry. Let’s roll!

test_eq(p_TL.shape, (info.toksize, info.lblsize, 2, 2))
test_eq(p_T.shape, (info.toksize, 2, 1))
test_eq(p_L.shape, (info.lblsize, 1, 2))
test_eq(p_TxL.shape, (info.toksize, info.lblsize, 2, 2))
test_eq(H_T.shape, [info.toksize])
test_eq(H_L.shape, [info.lblsize])
test_eq(I_TL.shape, (info.toksize, info.lblsize))
# Tiny constant added to the denominators below to avoid division by zero.
eps = I_TL.new_empty(1).fill_(1e-15)
# Mutual information normalized by the label entropy: I(T;L) / H(L), shape (toksize, lblsize)
info_lbl_entropy = I_TL/(H_L + eps)
# Jaccard-style normalization: I(T;L) / (H(T) + H(L) - I(T;L));
# H_T is (toksize,) and H_L is (lblsize,), so unsqueeze broadcasts them to (toksize, lblsize).
info_jaccard = I_TL/(H_T.unsqueeze(-1) + H_L.unsqueeze(0) - I_TL + eps)
# Any single NaN is a failure (`.all()` would only catch the case where every entry is NaN).
assert not info_lbl_entropy.isnan().any(); assert not info_jaccard.isnan().any()
l2r_bootstrap = {'toks': toks, 'lbs': lbs, 'mut_info_lbl_entropy': info_lbl_entropy, 'mutual_info_jaccard': info_jaccard}

l2r_bootstrap for the full dataset was downloaded by untar_xxx to boot_path (defined in the Scratchpad section below). You can load it in the following cell. l2r_bootstrap will be used to bootstrap our learning-to-rank model.

Save those Mutual Information Gain values

# l2r_bootstrap = torch.load(boot_path)
# l2r_bootstrap['info_jaccard'] = l2r_bootstrap.pop('mutual_info_jaccard')
# globals().update(l2r_bootstrap)
# info_jaccard.shape

Let’s take a look at the Mutual Information Gain (I_TL) for each of the labels:

with tempfile.TemporaryDirectory() as tmpdir:
    args = (p_TL, p_T, p_L, info_jaccard, H_T, H_L)
    kwargs = {'k':10, 'save_as': Path(tmpdir)/'mut_info_jaccard.ft'}
    df_info = info.show(*args, **kwargs)
    assert (Path(tmpdir)/'mut_info_jaccard.ft').exists()
df_info.head()
label freq prob entropy description top-k (token, prob, entropy, joint, info)
0 008.45 17 0.022849 0.108931 Intestinal infection due to clostridium difficile [['difficile' '0.05107527' '0.20166922' '0.014784946' '0.11373689']\n ['cdiff' '0.010752688' '0.05943229' '0.005376344' '0.088108845']\n ['loosely' '0.002688172' '0.018595558' '0.002688172' '0.08804317']\n ['reformatted' '0.00672043' '0.04031744' '0.004032258' '0.08063226']\n ['colitis' '0.05913979' '0.22459403' '0.01344086' '0.07943697']\n ['flagyl' '0.17069893' '0.45699275' '0.021505376' '0.064542644']\n ['enteritis' '0.004032258' '0.026255678' '0.002688172' '0.061064955']\n ['ogt' '0.004032258' '0.026255678' '0.002688172' '0.061064955']\n ['retardation' '0.004032258' '0.026255678' '0.00...
1 008.8 2 0.002688 0.018596 Intestinal infection due to other organism, not elsewhere classified [['vasotec' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['gastroenteritis' '0.010752688' '0.05943235' '0.002688172' '0.19164896']\n ['tachyarrhythmia' '0.004032258' '0.026255678' '0.001344086'\n '0.14864387']\n ['bumex' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['rta' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['neoral' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['probenecid' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['electronics' '0.004032258' '0.026255678' '0.001344086' '0.14864387']\n ['cauterized' '0.004032258...
2 009.0 1 0.001344 0.010230 Infectious colitis, enteritis, and gastroenteritis [['presacral' '0.001344086' '0.010230334' '0.001344086' '0.9999958']\n ['vibrio' '0.001344086' '0.010230334' '0.001344086' '0.9999958']\n ['yersinia' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['ova' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['parasites' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['resucitation' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['exlap' '0.004032258' '0.026255678' '0.001344086' '0.26589572']\n ['tenting' '0.004032258' '0.026255678' '0.001344086' '0.26589572']\n ['adhesiolysis' '0.004032258' '0.026255678...
3 009.1 2 0.002688 0.018596 Colitis, enteritis, and gastroenteritis of presumed infectious origin [['44yf' '0.001344086' '0.010230334' '0.001344086' '0.40896738']\n ['ischioanal' '0.001344086' '0.010230334' '0.001344086' '0.40896738']\n ['perianal' '0.001344086' '0.010230334' '0.001344086' '0.40896738']\n ['paraplegia' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['hunger' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['intrathecal' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['paraplegic' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['vicarious' '0.002688172' '0.018595558' '0.001344086' '0.2120071']\n ['spasticity' '0.002688172' '0.01859...
4 031.0 1 0.001344 0.010230 Pulmonary diseases due to other mycobacteria [['gist' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['disrupted' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['77f' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['eroding' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['vaginitis' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['circumscribed' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['discern' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['inseparable' '0.002688172' '0.018595558' '0.001344086' '0.40896738']\n ['pearls' '0.002688172' '0.018595558...

Let’s look at those Mutual-Information Gain values:

mask = (df_info.freq>50) & (df_info.freq<150)
# with pd.option_context('display.max_colwidth', 100):
# pd.reset_option('all')
df_info = df_info[mask].reset_index(drop=True)
len(df_info)
29

The dataframe below shows the top 10 tokens (based on the mutual-info-gain values) for labels that are rare (freq between 50 and 150). Feel free to ask ChatGPT about the label descriptions and the tokens to find out whether we were able to find the needle in the haystack.

pd.set_option('display.max_colwidth', None)
df_info.head()
label freq prob entropy description top-k (token, prob, entropy, joint, info)
0 038.9 52 0.069892 0.253361 Unspecified septicemia [['pressors' '0.13037634' '0.38710147' '0.041666668' '0.079481095']\n ['expired' '0.10215054' '0.32978266' '0.034946237' '0.073719725']\n ['septic' '0.06586021' '0.24279651' '0.024193548' '0.05889168']\n ['spectrum' '0.061827958' '0.23196787' '0.021505376' '0.04939346']\n ['levophed' '0.06989247' '0.25336072' '0.022849463' '0.04748282']\n ['sepsis' '0.18548387' '0.47960785' '0.041666668' '0.04552333']\n ['tpn' '0.0483871' '0.1937385' '0.017473118' '0.04391181']\n ['lactate' '0.26478493' '0.5780026' '0.049731184' '0.041423406']\n ['rescusitated' '0.004032258' '0.026255678' '0.004032258' '0.04032902']\n ['broad' '0.07123656' '0.25682586' '0.021505376' '0.040083304']]
1 244.9 71 0.095430 0.314924 Unspecified hypothyroidism [['hypothyroidism' '0.10887097' '0.34414828' '0.083333336' '0.4098433']\n ['levothyroxine' '0.10752688' '0.34131324' '0.07392473' '0.28809547']\n ['synthroid' '0.04973118' '0.19772297' '0.033602152' '0.11988581']\n ['levoxyl' '0.018817205' '0.093399525' '0.016129032' '0.08422138']\n ['hypothyroid' '0.014784946' '0.07698107' '0.01344086' '0.07722074']\n ['mcg' '0.36021507' '0.6535418' '0.076612905' '0.047165737']\n ['88mcg' '0.005376344' '0.03345727' '0.005376344' '0.038052656']\n ['cystitis' '0.004032258' '0.026255678' '0.004032258' '0.028801844']\n ['kyphotic' '0.004032258' '0.026255678' '0.004032258' '0.028801844']\n ['cvat' '0.004032258' '0.026255678' '0.004032258' '0.028801844']]
2 250.00 115 0.154570 0.430555 type II diabetes mellitus [non-insulin dependent type] [NIDDM type] [adult-onset type] or unspecified type, not stated as uncontrolled, without mention of complication [['diabetes' '0.2876344' '0.60002' '0.1155914' '0.09059631']\n ['metformin' '0.0672043' '0.24634655' '0.045698926' '0.08375815']\n ['glyburide' '0.037634406' '0.1603519' '0.028225806' '0.063347906']\n ['dm' '0.14784946' '0.4189606' '0.06317204' '0.048482798']\n ['mellitus' '0.14919354' '0.42130768' '0.061827958' '0.044503253']\n ['noninsulin' '0.010752688' '0.05943235' '0.010752688' '0.043445654']\n ['tricor' '0.00672043' '0.04031744' '0.00672043' '0.027659079']\n ['avandia' '0.012096774' '0.06542839' '0.0094086025' '0.024442347']\n ['insulin' '0.2795699' '0.59254664' '0.08064516' '0.024319947']\n ['glipizide' '0.021505376' '0.103841364' '0.01344086' '0.024206813']]
3 272.0 91 0.122312 0.371506 Pure hypercholesterolemia [['hypercholesterolemia' '0.13978495' '0.40457296' '0.07795699'\n '0.14934917']\n ['lipitor' '0.17204301' '0.45911095' '0.049731184' '0.027298862']\n ['aspirin' '0.54569894' '0.68896455' '0.10080645' '0.022963593']\n ['crestor' '0.016129032' '0.08256498' '0.0094086025' '0.022421718']\n ['carotids' '0.02688172' '0.12372976' '0.012096774' '0.018986586']\n ['nonreactive' '0.08736559' '0.29640004' '0.0' '0.018241761']\n ['gallop' '0.010752688' '0.05943235' '0.00672043' '0.018127378']\n ['crossclamp' '0.010752688' '0.05943235' '0.00672043' '0.018127378']\n ['palate' '0.086021505' '0.29323512' '0.0' '0.018028865']\n ['mrs' '0.037634406' '0.1603519' '0.014784946' '0.01788708']]
4 272.4 131 0.176075 0.465390 Other and unspecified hyperlipidemia [['hyperlipidemia' '0.17876343' '0.46951324' '0.11155914' '0.14867449']\n ['dyslipidemia' '0.0672043' '0.24634655' '0.04032258' '0.048867557']\n ['medquist36' '0.28225806' '0.59507334' '0.004032258' '0.048413806']\n ['brief' '0.75' '0.56233513' '0.1733871' '0.045796935']\n ['invasive' '0.72983867' '0.5834192' '0.17204301' '0.045730278']\n ['major' '0.7338709' '0.5793706' '0.17204301' '0.044849273']\n ['job' '0.29435483' '0.606005' '0.00672043' '0.04331313']\n ['attending' '0.74596775' '0.56672186' '0.17204301' '0.042243805']\n ['dictated' '0.29435483' '0.606005' '0.008064516' '0.039907183']\n ['exam' '0.77956986' '0.5274518' '0.1733871' '0.03951623']]
# pd.reset_option('all')
# df_info.to_excel('jaccard.xls', index=False)

Scratchpad

from fastai.data.transforms import *
source = untar_xxx(XURLs.MIMIC3_L2R)
boot_path = source/'mimic3-9k_tok_lbl_info.pkl'
assert boot_path.exists()
l2r_bootstrap = torch.load(boot_path)
toks, lbs, info = mapt(l2r_bootstrap.get, ['toks', 'lbs', 'mutual_info_jaccard'])
print(toks[-2:]) # last two places has 'xxfake'
# toks = CategoryMap(toks, sort=False)
lbs_des = load_pickle(source/'code_desc.pkl')
assert isinstance(lbs_des, dict)
test_eq(info.shape, (len(toks), len(lbs))) # last two places has 'xxfake'
['xxfake', 'xxfake']
L(toks[89:99]), lbs[:10]
((#10) ['course','follow','after','disease','stitle','needed','known','capsule','refills','started'],
 (#10) ['003.0','003.1','003.8','003.9','004.1','004.8','004.9','005.1','005.81','005.9'])
icd_codes = ['008.45', '009.0', '244.9', '250.00']
d= dict(zip(icd_codes, mapt(lbs_des.get, icd_codes)))
df_des = pd.DataFrame(d, index=range(1))

lbs_idxs = lbs.map_objs(icd_codes)
top_infos = info[:, lbs_idxs].topk(dim=0, k=10, largest=True).values.cpu()
top_idxs = info[:, lbs_idxs].topk(dim=0, k=10, largest=True).indices.cpu().long()
toks = array(toks).astype(str)
# toks.map_ids(tok_idxs[:, 0])
df_toks = pd.DataFrame(toks[top_idxs], columns=icd_codes)
df_infos = pd.DataFrame(top_infos, columns=icd_codes)
display(df_des, df_toks, df_infos)
008.45 009.0 244.9 250.00
0 Intestinal infection due to clostridium difficile Infectious colitis, enteritis, and gastroenteritis Unspecified hypothyroidism type II diabetes mellitus [non-insulin dependent type] [NIDDM type] [adult-onset type] or unspecified type, not stated as uncontrolled, without mention of complication
008.45 009.0 244.9 250.00
0 difficile ileoloop hypothyroidism diabetes
1 colitis entercort levothyroxine metformin
2 diff coumidin synthroid mellitus
3 clostridium 33u hypothyroid dm
4 metronidazole 8x mcg insulin
5 flagyl proctocolitis levoxyl glyburide
6 toxin chux tsh glipizide
7 cdiff metronidzole t4 dm2
8 megacolon bayview 88mcg sliding
9 feces 117bpm 50mcg scale
008.45 009.0 244.9 250.00
0 0.121274 0.027056 0.353851 0.089817
1 0.097799 0.027056 0.319011 0.085321
2 0.093017 0.027056 0.093766 0.063133
3 0.086582 0.027056 0.076361 0.052869
4 0.064245 0.023884 0.063794 0.043760
5 0.061297 0.023884 0.052179 0.040929
6 0.057576 0.023737 0.020855 0.040014
7 0.042272 0.023737 0.014761 0.033789
8 0.032625 0.023737 0.013494 0.030756
9 0.026234 0.023737 0.012336 0.023395
vocab = load_pickle(Path.cwd()/'tmp/models/mimic3-9k_clas_full_vocab.pkl')
toks = L(toks, use_list=True)

The tokens which are there in the xml vocab but we do not have any ‘info’ on:

not_found = L(set(vocab[0]).difference(set(toks)))
not_found
(#20) ['foi','luteinizing','dobhoof','theses','q2day','promiscuity','dissension','sharpio','mhc','remiained'...]
test_fail(lambda : toks.index('unrmarkable'), contains='is not in list')

The tokens which we have info about but which were not present in the xml vocab:

set(toks).difference(vocab[0])
set()

Thankfully, we have info for all the labels in the xml vocab:

test_shuffled(vocab[1], lbs)

Now we need to create a mapping between the indices of the xml vocab and the information gain vocab:

def xml2info(xml_vocab, info_vocab):
    """Create a mapping between the indices of the xml vocab and the information-gain vocab.

    Returns `(mapping, notfnd)`:
      - `mapping`: dict from every index `i` of `xml_vocab` to the index of `xml_vocab[i]` in
        `info_vocab` (its first occurrence, like `list.index`), or `np.inf` if it is absent.
      - `notfnd`: the `xml_vocab` indices, in order, whose item does not occur in `info_vocab`.
    """
    # Build the item -> first-index lookup once. A per-item `in` + `.index` scan would be
    # O(len(xml_vocab) * len(info_vocab)), i.e. quadratic for large vocabs.
    # setdefault keeps the first occurrence (e.g. the duplicated trailing 'xxfake' tokens).
    first_idx = {}
    for j, o in enumerate(info_vocab): first_idx.setdefault(o, j)
    mapping = {i: first_idx.get(o, np.inf) for i, o in enumerate(xml_vocab)}
    notfnd = [i for i, o in enumerate(xml_vocab) if o not in first_idx]
    return mapping, notfnd
toks_xml2info, toks_notfnd = xml2info(vocab[0], toks)
toks_found = set(toks_xml2info).difference(set(toks_notfnd))
test_shuffled(array(vocab[0])[toks_notfnd], not_found)
some_xml_idxs = np.random.choice(array(L(toks_found)), size=10)
some_xml_toks = array(vocab[0])[some_xml_idxs]
corres_info_idxs = L(map(toks_xml2info.get, some_xml_idxs))
corres_info_toks = array(toks)[corres_info_idxs]
assert all_equal(some_xml_toks, corres_info_toks)
lbs_xml2info, lbs_notfnd = xml2info(L(vocab[1]), L(lbs))
lbs_found = set(lbs_xml2info).difference(set(lbs_notfnd))
some_xml_idxs = np.random.choice(array(L(lbs_found)), size=10)
some_xml_lbs = array(vocab[1])[some_xml_idxs]
corres_info_idxs = L(map(lbs_xml2info.get, some_xml_idxs))
corres_info_lbs = array(lbs)[corres_info_idxs]
assert all_equal(some_xml_lbs, corres_info_lbs)