wgts = {'u_weight.weight': torch.randn(3,5),
        'i_weight.weight': torch.randn(4,5),
        'u_bias.weight'  : torch.randn(3,1),
        'i_bias.weight'  : torch.randn(4,1)}
collab_vocab = {'token': ['#na#', 'sun', 'moon', 'earth', 'mars'],
                'label': ['#na#', 'a', 'c', 'b']}
lbs_vocab = ['a', 'b', 'c']
new_wgts, missing = match_collab(wgts.copy(), collab_vocab, lbs_vocab)
test_eq(missing, 0)
test_close(wgts['u_weight.weight'], new_wgts['u_weight.weight'])
test_close(wgts['u_bias.weight'], new_wgts['u_bias.weight'])
with ExceptionExpected(ex=AssertionError, regex="close"):
    test_close(wgts['i_weight.weight'][1:], new_wgts['i_weight.weight'])
    test_close(wgts['i_bias.weight'][1:], new_wgts['i_bias.weight'])
old_w, new_w = wgts['i_weight.weight'], new_wgts['i_weight.weight']
old_b, new_b = wgts['i_bias.weight'], new_wgts['i_bias.weight']
for (old_k, old_v), (new_k, new_v) in zip(wgts.items(), new_wgts.items()):
    if old_k.startswith('u'): test_eq(old_v.size(), new_v.size())
    else: test_ne(old_v.size(), new_v.size())
    # print(f"old: {old_k} = {old_v.size()}, new: {new_k} = {new_v.size()}")
test_eq(new_w[0], old_w[1]); test_eq(new_b[0], old_b[1])
test_eq(new_w[1], old_w[3]); test_eq(new_b[1], old_b[3])
test_eq(new_w[2], old_w[2]); test_eq(new_b[2], old_b[2])
test_shuffled(list(old_b[1:].squeeze().numpy()), list(new_b.squeeze().numpy()))
test_eq(torch.sort(old_b[1:], dim=0)[0], torch.sort(new_b, dim=0)[0])
test_eq(torch.sort(old_w[1:], dim=0)[0], torch.sort(new_w, dim=0)[0])
Learner for the XML Text application

Learner suitable for transfer learning in XML text classification. The most important function of this module is xmltext_classifier_learner. It helps you define a Learner that uses a pretrained Language Model as the encoder and a pretrained Learning-to-Rank model as the decoder. (Tutorial: coming soon!) This module is inspired by fastai's TextLearner, which is based on the ULMFiT paper.
Loading label embeddings from a pretrained collab model

match_collab

match_collab (old_wgts:dict, collab_vocab:dict, lbs_vocab:list)

Convert the label embeddings in old_wgts to go from the vocabulary used during collab pre-training (collab_vocab) to lbs_vocab.

 | Type | Details
---|---|---
old_wgts | dict | Embedding weights of the collab model
collab_vocab | dict | Vocabulary of tokens and labels used for collab pre-training
lbs_vocab | list | Current labels vocabulary
Returns | dict |
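To make the remapping concrete, here is a minimal sketch of the idea, consistent with the tests above but not the library's implementation (`match_collab_sketch` is a hypothetical helper): rows of the label embedding for labels present in both vocabularies get copied over; labels the collab model never saw stay at zeros and are counted as missing.

```python
import torch

def match_collab_sketch(old_wgts, collab_vocab, lbs_vocab):
    # Hypothetical re-implementation for illustration only.
    old_lbs = collab_vocab['label']
    w, b = old_wgts['i_weight.weight'], old_wgts['i_bias.weight']
    new_w = w.new_zeros(len(lbs_vocab), w.size(1))
    new_b = b.new_zeros(len(lbs_vocab), b.size(1))
    missing = 0
    for i, lbl in enumerate(lbs_vocab):
        if lbl in old_lbs:                 # copy the pretrained row
            j = old_lbs.index(lbl)
            new_w[i], new_b[i] = w[j], b[j]
        else: missing += 1                 # leave at zeros, count it
    old_wgts['i_weight.weight'], old_wgts['i_bias.weight'] = new_w, new_b
    return old_wgts, missing
```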
Loading Pretrained Information Gain as Attention

from xcube.l2r.all import *

source_mimic = untar_xxx(XURLs.MIMIC3)
xml_vocab = load_pickle(source_mimic/'mimic3-9k_clas_full_vocab.pkl')
xml_vocab = L(xml_vocab).map(listify)

source_l2r = untar_xxx(XURLs.MIMIC3_L2R)
boot_path = join_path_file('mimic3-9k_tok_lbl_info', source_l2r, ext='.pkl')
bias_path = join_path_file('p_L', source_l2r, ext='.pkl')
l2r_bootstrap = torch.load(boot_path, map_location=default_device())
brain_bias = torch.load(bias_path, map_location=default_device())

*brain_vocab, brain = mapt(l2r_bootstrap.get, ['toks', 'lbs', 'mutual_info_jaccard'])
brain_vocab = L(brain_vocab).map(listify)
toks, lbs = brain_vocab
print(f"last two places in the brain vocab have {toks[-2:]}")
# toks = CategoryMap(toks, sort=False)
brain_bias = brain_bias[:, :, 0].squeeze(-1)
lbs_des = load_pickle(source_mimic/'code_desc.pkl')
assert isinstance(lbs_des, dict)
test_eq(brain.shape, (len(toks), len(lbs))) # last two places has 'xxfake'
test_eq(brain_bias.shape, [len(lbs)])

last two places in the brain vocab have ['xxfake', 'xxfake']
The tokens which are in the xml vocab but not in the brain:

not_found_in_brain = L(set(xml_vocab[0]).difference(set(brain_vocab[0])))
not_found_in_brain

(#20) ['cella','q2day','remiained','luteinizing','promiscuity','sharpio','calcijex','dissension','mhc','theses'...]

test_fail(lambda: toks.index('cella'), contains='is not in list')

The tokens which are in the brain but were not present in the xml vocab:

set(brain_vocab[0]).difference(xml_vocab[0])

set()

Thankfully, we have info for all the labels in the xml vocab:

assert set(xml_vocab[1]).symmetric_difference(set(brain_vocab[1])) == set()
# test_shuffled(xml_vocab[1], mimic_vocab[1])
toks_xml2brain, toks_notfnd = _xml2brain(xml_vocab[0], brain_vocab[0])

toks_found = set(toks_xml2brain).difference(set(toks_notfnd))
test_shuffled(array(xml_vocab[0])[toks_notfnd], not_found_in_brain)
some_xml_idxs = np.random.choice(array(L(toks_found)), size=10)
some_xml_toks = array(xml_vocab[0])[some_xml_idxs]
corres_brain_idxs = L(map(toks_xml2brain.get, some_xml_idxs))
corres_brain_toks = array(toks)[corres_brain_idxs]
assert all_equal(some_xml_toks, corres_brain_toks)

lbs_xml2brain, lbs_notfnd = _xml2brain(xml_vocab[1], brain_vocab[1])

lbs_found = set(lbs_xml2brain).difference(set(lbs_notfnd))
some_xml_idxs = np.random.choice(array(L(lbs_found)), size=10)
some_xml_lbs = array(xml_vocab[1])[some_xml_idxs]
corres_brain_idxs = L(map(lbs_xml2brain.get, some_xml_idxs))
corres_brain_lbs = array(lbs)[corres_brain_idxs]
assert all_equal(some_xml_lbs, corres_brain_lbs)
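_xml2brain is a private helper whose implementation is not shown here. Under the assumption (consistent with how it is used above and in brainsplant_diffntble below) that it returns a dict from xml indices to brain indices, with np.inf marking misses, plus the list of missed xml indices, a minimal sketch might look like this (`_xml2brain_sketch` is hypothetical):

```python
import numpy as np

def _xml2brain_sketch(xml_v, brain_v):
    # Map each index of xml_v to the index of the same item in brain_v;
    # np.inf marks items missing from the brain vocab.
    brain_pos = {o: i for i, o in enumerate(brain_v)}
    xml2brain = {i: brain_pos.get(o, np.inf) for i, o in enumerate(xml_v)}
    notfnd = [i for i, j in xml2brain.items() if j is np.inf]
    return xml2brain, notfnd
```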
brainsplant

brainsplant (xml_vocab, brain_vocab, brain, brain_bias, device=None)

xml_brain, xml_lbsbias, toks_map, lbs_map, toks_xml2brain, lbs_xml2brain = brainsplant(xml_vocab, brain_vocab, brain, brain_bias)

test_eq(xml_brain.shape, xml_vocab.map(len))
test_eq(xml_brain[toks_notfnd], xml_brain.new_zeros(len(toks_notfnd), len(xml_vocab[1])))
assert all_equal(array(xml_vocab[0])[toks_map.itemgot(0)], array(brain_vocab[0])[toks_map.itemgot(1)])
assert all_equal(array(xml_vocab[1])[lbs_map.itemgot(0)], array(brain_vocab[1])[lbs_map.itemgot(1)])
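Conceptually, brainsplant transplants the pretrained token-label matrix into the xml vocabulary's index space; rows for tokens missing from the brain stay at zero, which the second test above checks. A minimal sketch of that core re-indexing step, under the assumption that toks_map and lbs_map hold (xml_idx, brain_idx) pairs (`brainsplant_sketch` is hypothetical):

```python
import torch

def brainsplant_sketch(brain, toks_map, lbs_map, n_toks_xml, n_lbs_xml):
    # Re-index the pretrained matrix into xml space; unmapped rows stay zero.
    xml_brain = brain.new_zeros(n_toks_xml, n_lbs_xml)
    tx = torch.tensor([t for t, _ in toks_map]); tb = torch.tensor([b for _, b in toks_map])
    lx = torch.tensor([l for l, _ in lbs_map]);  lb = torch.tensor([b for _, b in lbs_map])
    # brain token rows land at their xml row positions; columns likewise
    xml_brain[tx[:, None], lx[None, :]] = brain[tb[:, None], lb[None, :]]
    return xml_brain
```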
# tests to ensure `brainsplant` was successful
lbl = '642.41'
lbl = '38.93'
lbl = '51.10'
lbl = '996.87'
lbl_idx_from_brn = brain_vocab[1].index(lbl)
tok_vals_from_brn, top_toks_from_brn = L(brain[:, lbl_idx_from_brn].topk(k=20)).map(Self.cpu())
lbl_idx_from_xml = xml_vocab[1].index(lbl)
tok_vals_from_xml, top_toks_from_xml = L(xml_brain[:, lbl_idx_from_xml].topk(k=20)).map(Self.cpu())

test_eq(lbs_xml2brain[lbl_idx_from_xml], lbl_idx_from_brn)
test_eq(tok_vals_from_brn, tok_vals_from_xml)
test_eq(array(brain_vocab[0])[top_toks_from_brn], array(xml_vocab[0])[top_toks_from_xml])
test_eq(brain_bias[lbl_idx_from_brn], xml_lbsbias[lbl_idx_from_xml])
print(f"For the lbl {lbl} ({lbs_des.get(lbl)}), the top tokens that need attention are:")
print('\n'.join(L(array(xml_vocab[0])[top_toks_from_xml], use_list=True).zipwith(L(tok_vals_from_xml.numpy(), use_list=True)).map(str).map(lambda o: "+ "+o)))
For the lbl 996.87 (Complications of transplanted intestine), the top tokens that need attention are:
+ ('consultued', 0.25548762)
+ ('cip', 0.25548762)
+ ('parlor', 0.24661502)
+ ('transplantations', 0.18601614)
+ ('scaffoid', 0.18601614)
+ ('epineprine', 0.18601614)
+ ('culinary', 0.17232327)
+ ('coordinates', 0.1469037)
+ ('aminotransferases', 0.12153866)
+ ('hydronephroureter', 0.12153866)
+ ('27yom', 0.12153866)
+ ('27y', 0.103684604)
+ ('hardward', 0.090407245)
+ ('leukoreduction', 0.08014185)
+ ('venting', 0.07831942)
+ ('secrete', 0.07196123)
+ ('orthogonal', 0.07196123)
+ ('naac', 0.06891022)
+ ('mgso4', 0.0662555)
+ ('septecemia', 0.065286644)
tok = 'fibrillation'
tok = 'colpo'
tok = 'amiodarone'
tok = 'flagyl'
tok = 'nasalilid'
tok = 'hemetemesis'
tok = 'restitched'
tok_idx_from_brn = brain_vocab[0].index(tok)
lbs_vals_from_brn, top_lbs_from_brn = L(brain[tok_idx_from_brn].topk(k=20)).map(Self.cpu())
tok_idx_from_xml = xml_vocab[0].index(tok)
lbs_vals_from_xml, top_lbs_from_xml = L(xml_brain[tok_idx_from_xml].topk(k=20)).map(Self.cpu())

test_eq(tok_idx_from_brn, toks_xml2brain[tok_idx_from_xml])
test_eq(lbs_vals_from_brn, lbs_vals_from_xml)
try:
    test_eq(array(brain_vocab[1])[top_lbs_from_brn], array(xml_vocab[1])[top_lbs_from_xml])
except AssertionError as e:
    print(type(e).__name__, "due to instability in sorting (nothing to worry!)")
    test_shuffled(array(brain_vocab[1])[top_lbs_from_brn], array(xml_vocab[1])[top_lbs_from_xml])
print('')
print(f"For the token {tok}, the top labels that need attention are:")
print('\n'.join(L(mapt(lbs_des.get, array(xml_vocab[1])[top_lbs_from_xml])).zipwith(L(lbs_vals_from_xml.numpy(), use_list=True)).map(str).map(lambda o: "+ "+o)))
For the token restitched, the top labels that need attention are:
+ ('Other operations on supporting structures of uterus', 0.29102018)
+ ('Other proctopexy', 0.29102018)
+ ('Other operations on cul-de-sac', 0.18601614)
+ (None, 0.07494824)
+ ('Intervertebral disc disorder with myelopathy, thoracic region', 0.055331517)
+ ('Excision of scapula, clavicle, and thorax [ribs and sternum] for graft', 0.04382947)
+ ('Other repair of omentum', 0.028067086)
+ ('Chronic lymphocytic thyroiditis', 0.01986737)
+ (None, 0.019236181)
+ ('Reclosure of postoperative disruption of abdominal wall', 0.016585195)
+ ('Other disorders of calcium metabolism', 0.009393147)
+ ('Pain in joint involving pelvic region and thigh', 0.008421187)
+ ('Exteriorization of small intestine', 0.00817792)
+ ('Fusion or refusion of 9 or more vertebrae', 0.00762466)
+ ('Kyphosis (acquired) (postural)', 0.0074228523)
+ ('Unspecified procedure as the cause of abnormal reaction of patient, or of later complication, without mention of misadventure at time of procedure', 0.0063889036)
+ ('Application or administration of adhesion barrier substance', 0.00610513)
+ ('Acute osteomyelitis involving other specified sites', 0.0054434645)
+ ('Body Mass Index less than 19, adult', 0.004719585)
+ ('Dorsal and dorsolumbar fusion, anterior technique', 0.0046444684)
some_toks = random.sample(toks_map.itemgot(0), 10)
counts = [c*6 for c in random.sample(range(10), 10)]
some_toks = random.sample(some_toks, 20, counts=counts)
# Counter(some_toks)
cors_toks_brn = L(mapt(toks_xml2brain.get, some_toks))
test_eq(array(brain_vocab[0])[cors_toks_brn], array(xml_vocab[0])[some_toks])
print("some tokens (with repetitions):\n", '\n'.join(['-'+xml_vocab[0][t] for t in some_toks]))
some tokens (with repetitions):
-disorientated
-disorientated
-dmh
-ibp
-literacy
-abruptly
-faxed
-delsym
-literacy
-delsym
-literacy
-ibp
-literacy
-delsym
-abruptly
-caox3
-caox3
-caox3
-caox3
-literacy
attn = xml_brain[some_toks]
test_eq(attn.shape, (len(some_toks), xml_brain.shape[1]))
# semantics of attn:
# for each token, the attention each label deserves is that token's row of xml_brain
for t, a in zip(some_toks, attn):
    test_eq(xml_brain[t], a)
# for each label, the attention these tokens deserve is that label's column
for lbl in range(xml_brain.shape[1]):
    test_eq(xml_brain[:, lbl][some_toks], attn[:, lbl])

pd.DataFrame([(xml_vocab[0][t], l:=xml_vocab[1][lbl_idx], val.item(), lbs_des.get(l, 'NF')) for t, lbl_idx, val in zip(some_toks, attn.max(dim=1).indices.cpu(), attn.max(dim=1).values.cpu())],
             columns=['token', 'most_relevant_lbl', 'lbl_attn', 'description']).sort_values(by='lbl_attn', ascending=False)
 | token | most_relevant_lbl | lbl_attn | description
---|---|---|---|---
7 | delsym | 344.2 | 0.096369 | Diplegia of upper limbs |
13 | delsym | 344.2 | 0.096369 | Diplegia of upper limbs |
9 | delsym | 344.2 | 0.096369 | Diplegia of upper limbs |
2 | dmh | 983.1 | 0.041843 | Toxic effect of acids |
0 | disorientated | 171.0 | 0.036627 | Malignant neoplasm of connective and other soft tissue of head, face, and neck |
1 | disorientated | 171.0 | 0.036627 | Malignant neoplasm of connective and other soft tissue of head, face, and neck |
18 | caox3 | 375.01 | 0.028963 | Acute dacryoadenitis |
17 | caox3 | 375.01 | 0.028963 | Acute dacryoadenitis |
16 | caox3 | 375.01 | 0.028963 | Acute dacryoadenitis |
15 | caox3 | 375.01 | 0.028963 | Acute dacryoadenitis |
12 | literacy | 449 | 0.018083 | Septic arterial embolism |
10 | literacy | 449 | 0.018083 | Septic arterial embolism |
8 | literacy | 449 | 0.018083 | Septic arterial embolism |
4 | literacy | 449 | 0.018083 | Septic arterial embolism |
19 | literacy | 449 | 0.018083 | Septic arterial embolism |
6 | faxed | 15.9 | 0.012289 | Other operations on extraocular muscles and tendons |
11 | ibp | 39.90 | 0.006865 | Insertion of non-drug-eluting peripheral vessel stent(s) |
3 | ibp | 39.90 | 0.006865 | Insertion of non-drug-eluting peripheral vessel stent(s) |
14 | abruptly | 315.8 | 0.006252 | Other specified delays in development |
5 | abruptly | 315.8 | 0.006252 | Other specified delays in development |
from xcube.layers import inattention

# define label inattention cutoff
k = 5
top_lbs_attn = attn.clone().unsqueeze(0).permute(0,2,1).inattention(k=k).permute(0,2,1).squeeze(0).contiguous() # applying `inattention` across the lbs dim
test_eq(top_lbs_attn.shape, (len(some_toks), xml_brain.shape[1]))
test_ne(attn, top_lbs_attn)
test_eq(top_lbs_attn.argmax(dim=1), attn.argmax(dim=1))
lbs_cf = top_lbs_attn.sum(dim=0)
test_eq(lbs_cf.shape, [top_lbs_attn.shape[1]])
idxs = lbs_cf.nonzero().flatten().cpu()
print(f"After looking at the tokens {[xml_vocab[0][t] for t in some_toks]}, I am confident about the following labels:")
pd.DataFrame([(l:=xml_vocab[1][idx], val.item(), lbs_des.get(l, 'NF')) for idx, val in zip(idxs, lbs_cf[idxs])],
             columns=['lbl', 'lbl_cf', 'description']).sort_values(by='lbl_cf', ascending=False)
After looking at the tokens ['disorientated', 'disorientated', 'dmh', 'ibp', 'literacy', 'abruptly', 'faxed', 'delsym', 'literacy', 'delsym', 'literacy', 'ibp', 'literacy', 'delsym', 'abruptly', 'caox3', 'caox3', 'caox3', 'caox3', 'literacy'], I am confident about the following labels:
 | lbl | lbl_cf | description
---|---|---|---
10 | 344.2 | 0.289106 | Diplegia of upper limbs |
22 | 367.1 | 0.289106 | Myopia |
36 | 706.1 | 0.158741 | Other acne |
35 | 691.8 | 0.145640 | Other atopic dermatitis and related conditions |
21 | 442.89 | 0.134519 | Aneurysm of other specified site |
50 | 374.89 | 0.115853 | Other disorders of eyelid |
49 | 375.01 | 0.115853 | Acute dacryoadenitis |
20 | 449 | 0.090417 | Septic arterial embolism |
11 | 438.13 | 0.087779 | NF |
23 | 423.1 | 0.080689 | Adhesive pericarditis |
13 | 304.71 | 0.080689 | Combinations of opioid type drug with any other drug dependence, continuous use |
18 | 315.9 | 0.079333 | Unspecified delay in development |
24 | 171.0 | 0.073254 | Malignant neoplasm of connective and other soft tissue of head, face, and neck |
29 | 259.9 | 0.072001 | Unspecified endocrine disorder |
28 | 701.2 | 0.072001 | Acquired acanthosis nigricans |
44 | 370.00 | 0.066347 | Corneal ulcer, unspecified |
34 | 814.00 | 0.057924 | Closed fracture of carpal bone, unspecified |
7 | E968.8 | 0.050906 | Assault by other specified means |
30 | 722.72 | 0.049080 | Intervertebral disc disorder with myelopathy, thoracic region |
17 | 444.21 | 0.047748 | Arterial embolism and thrombosis of upper extremity |
42 | 618.1 | 0.045409 | Uterine prolapse without mention of vaginal wall prolapse |
45 | 531.30 | 0.041843 | Acute gastric ulcer without mention of hemorrhage or perforation, without mention of obstruction |
41 | 54.73 | 0.041843 | Other repair of peritoneum |
37 | 983.1 | 0.041843 | Toxic effect of acids |
52 | 824.7 | 0.041843 | Trimalleolar fracture, open |
6 | 921.0 | 0.039063 | Black eye, NOS |
39 | 959.9 | 0.036281 | Other and unspecified injury to unspecified site |
16 | 290.3 | 0.032885 | Senile dementia with delirium |
43 | 784.92 | 0.032388 | NF |
15 | 444.89 | 0.028912 | Embolism and thrombosis of other artery |
32 | 39.90 | 0.013729 | Insertion of non-drug-eluting peripheral vessel stent(s) |
19 | 315.8 | 0.012504 | Other specified delays in development |
2 | 15.9 | 0.012289 | Other operations on extraocular muscles and tendons |
14 | 38.12 | 0.011341 | Endarterectomy of other vessels of head and neck |
9 | 68.16 | 0.010278 | Closed biopsy of uterus |
46 | 674.14 | 0.009737 | Disruption of cesarean wound, postpartum |
47 | 654.44 | 0.009737 | Other abnormalities in shape or position of gravid uterus and of neighboring structures, postpartum |
48 | 669.44 | 0.009737 | Other complications of obstetrical surgery and procedures, postpartum condition or complication |
51 | 205.90 | 0.009737 | Unspecified myeloid leukemia without mention of remission |
27 | 86.19 | 0.009737 | Other diagnostic procedures on skin and subcutaneous tissue |
53 | 021.8 | 0.009737 | Other specified tularemia |
40 | 362.07 | 0.009737 | Diabetic macular edema |
38 | 989.3 | 0.009737 | Toxic effect of organophosphate and carbamate |
26 | 17.7 | 0.009737 | NF |
25 | 706.9 | 0.009737 | Unspecified disease of sebaceous glands |
31 | 39.50 | 0.009523 | Angioplasty or atherectomy of other non-coronary vessel(s) |
5 | 421.0 | 0.008755 | Acute and subacute bacterial endocarditis |
0 | 348.5 | 0.007448 | Cerebral edema |
1 | 13.1 | 0.006795 | Intracapsular extraction of lens |
4 | 93.0 | 0.005886 | Diagnostic physical therapy |
3 | 21.2 | 0.005648 | Diagnostic procedures on nose |
33 | 36.05 | 0.005054 | NF |
8 | 440.1 | 0.004669 | Atherosclerosis of renal artery |
12 | 37.94 | 0.004538 | Implantation or replacement of automatic cardioverter/defibrillator, total system [AICD] |
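The exact behavior of inattention lives in xcube.layers and is not shown here. A plausible reading, consistent with the tests above (top_lbs_attn differs from attn but keeps the same argmax, and the confidence vector is sparse), is a top-k mask that keeps each row's k largest entries and zeroes the rest. A minimal sketch of that idea, offered as an assumption rather than the library code (`topk_mask` is hypothetical):

```python
import torch

def topk_mask(x, k):
    # keep each row's k largest values, zero out everything else
    vals, idxs = x.topk(k, dim=-1)
    out = torch.zeros_like(x)
    return out.scatter(-1, idxs, vals)

a = torch.rand(4, 10)
m = topk_mask(a, k=5)
assert (m != 0).sum(dim=-1).eq(5).all() and m.argmax(dim=-1).equal(a.argmax(dim=-1))
```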
l2r_wgts = torch.load(join_path_file('lin_lambdarank_full', source_l2r, ext='.pth'), map_location=default_device())
if 'model' in l2r_wgts: l2r_wgts = l2r_wgts['model']

We need to match the weights between the xml and the brain vocabularies:
def brainsplant_diffntble(xml_vocab, brain_vocab, l2r_wgts, device=None):
    toks_lbs = 'toks lbs'.split()
    mb = master_bar(range(2))
    for i in mb:
        globals().update(dict(zip((toks_lbs[i]+'_xml2brain', toks_lbs[i]+'_notfnd'), (_xml2brain(xml_vocab[i], brain_vocab[i], parent_bar=mb)))))
        mb.write(f"Finished Loop {i}")
    toks_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in toks_xml2brain.items() if brn_idx is not np.inf)
    lbs_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in lbs_xml2brain.items() if brn_idx is not np.inf)
    tf_xml = torch.zeros(len(xml_vocab[0]), 200).to(default_device() if device is None else device)
    tb_xml = torch.zeros(len(xml_vocab[0]), 1).to(default_device() if device is None else device)
    lf_xml = torch.zeros(len(xml_vocab[1]), 200).to(default_device() if device is None else device)
    lb_xml = torch.zeros(len(xml_vocab[1]), 1).to(default_device() if device is None else device)
    tf_l2r, tb_l2r, lf_l2r, lb_l2r = list(l2r_wgts.values())
    tf_xml[toks_map.itemgot(0)] = tf_l2r[toks_map.itemgot(1)].clone()
    tb_xml[toks_map.itemgot(0)] = tb_l2r[toks_map.itemgot(1)].clone()
    lf_xml[lbs_map.itemgot(0)] = lf_l2r[lbs_map.itemgot(1)].clone()
    lb_xml[lbs_map.itemgot(0)] = lb_l2r[lbs_map.itemgot(1)].clone()
    # import pdb; pdb.set_trace()
    xml_wgts = {k: xml_val for k, xml_val in zip(l2r_wgts.keys(), (tf_xml, tb_xml, lf_xml, lb_xml))}
    mod_dict = nn.ModuleDict({k.split('.')[0]: nn.Embedding(*v.size()) for k, v in xml_wgts.items()}).to(default_device() if device is None else device)
    mod_dict.load_state_dict(xml_wgts)
    return mod_dict, toks_map, lbs_map
mod_dict, toks_map, lbs_map = brainsplant_diffntble(xml_vocab, brain_vocab, l2r_wgts)
assert isinstance(mod_dict, nn.Module)
assert nn.Module in mod_dict.__class__.__mro__
test_eq(mod_dict['token_factors'].weight.data[toks_map.itemgot(0)], l2r_wgts['token_factors.weight'][toks_map.itemgot(1)])
test_eq(mod_dict['token_bias'].weight.data[toks_map.itemgot(0)], l2r_wgts['token_bias.weight'][toks_map.itemgot(1)])
test_eq(mod_dict['label_factors'].weight.data[lbs_map.itemgot(0)], l2r_wgts['label_factors.weight'][lbs_map.itemgot(1)])
test_eq(mod_dict['label_bias'].weight.data[lbs_map.itemgot(0)], l2r_wgts['label_bias.weight'][lbs_map.itemgot(1)])
mod_dict
ModuleDict(
(token_factors): Embedding(57376, 200)
(token_bias): Embedding(57376, 1)
(label_factors): Embedding(8922, 200)
(label_bias): Embedding(8922, 1)
)
some_lbs = ['996.87', '51.10', '38.93']
for lbl in some_lbs:
    print(f"{lbl}: {lbs_des.get(lbl, 'NF')}")

996.87: Complications of transplanted intestine
51.10: Endoscopic retrograde cholangiopancreatography [ERCP]
38.93: Venous catheterization, not elsewhere classified

lbs_idx = tensor(mapt(xml_vocab[1].index, some_lbs)).to(default_device())
toks_idx = torch.randint(0, len(xml_vocab[0]), (72,)).to(default_device())
print("-"+'\n-'.join(array(xml_vocab[0])[toks_idx.cpu()].tolist()))
-influx
-latissimus
-equinovarus
-deteriorates
-aap
-mvh
-135
-incipient
-rhubarb
-nizhoni
-trancutaneous
-indicaton
-subset
-largyngeal
-lemonade
-debulk
-aerations
-l34
-perserverates
-trendelenberg
-kettr
-meningitic
-bored
-hashimoto
-mountains
-wit
-asts
-ellicits
-pax
-adb
-alcholism
-violinist
-301b
-subpopulation
-intraorally
-98o2
-agreesive
-monilla
-jig
-paroxysmalatrial
-10pts
-knees
-conventionally
-soonest
-recap
-rediscuss
-spontanous
-pulmary
-repletement
-450x12
-symetrically
-fdi
-pshx
-svco2
-topimax
-2100cc
-conceal
-nauasea
-decontamination
-administrator
-fraction
-tachyarrythmia
-oversee
-dabigutran
-reiterated
-aftetr
-bues
-symettric
-powerful
-depocyte
-hyperextension
-hepsc
apprx_brain = mod_dict['token_factors'](toks_idx) @ mod_dict['label_factors'](lbs_idx).T + mod_dict['token_bias'](toks_idx) + mod_dict['label_bias'](lbs_idx).T
apprx_brain.shape
torch.Size([72, 3])
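Each entry of apprx_brain is simply a token-label dot product plus the two bias terms: apprx_brain[i, j] = token_factors[i] · label_factors[j] + token_bias[i] + label_bias[j]. A quick illustrative spot-check of one entry, using only names defined above (not part of the original notebook):

```python
# spot-check entry (i, j) against the matrix expression above
i, j = 0, 0
manual = (mod_dict['token_factors'](toks_idx)[i] @ mod_dict['label_factors'](lbs_idx)[j]
          + mod_dict['token_bias'](toks_idx)[i] + mod_dict['label_bias'](lbs_idx)[j])
test_close(manual.item(), apprx_brain[i, j].item())
```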
These are the tokens as ranked by the pretrained L2R model (which is essentially an approximation of the actual brain):

pd.DataFrame(array(xml_vocab[0])[toks_idx[apprx_brain.argsort(dim=0, descending=True)].cpu()], columns=L(zip(some_lbs, mapt(lbs_des.get, some_lbs))).map(': '.join))
 | 996.87: Complications of transplanted intestine | 51.10: Endoscopic retrograde cholangiopancreatography [ERCP] | 38.93: Venous catheterization, not elsewhere classified
---|---|---|---
0 | fraction | wit | fraction |
1 | knees | fraction | subpopulation |
2 | subpopulation | administrator | knees |
3 | wit | subset | subset |
4 | administrator | knees | pshx |
... | ... | ... | ... |
67 | paroxysmalatrial | mvh | powerful |
68 | indicaton | ellicits | indicaton |
69 | rhubarb | indicaton | perserverates |
70 | depocyte | aftetr | rhubarb |
71 | aftetr | conceal | monilla |
72 rows × 3 columns
Just to compare: this is how the actual brain would rank those tokens:

# array(xml_vocab[0])[xml_brain[:, lbl_idx].topk(k=20, dim=0).indices.cpu()]
pd.DataFrame(array(xml_vocab[0])[toks_idx[xml_brain[:, lbs_idx][toks_idx].argsort(descending=True, dim=0)].cpu()], columns=L(zip(some_lbs, mapt(lbs_des.get, some_lbs))).map(': '.join))
 | 996.87: Complications of transplanted intestine | 51.10: Endoscopic retrograde cholangiopancreatography [ERCP] | 38.93: Venous catheterization, not elsewhere classified
---|---|---|---
0 | fraction | wit | knees |
1 | knees | administrator | svco2 |
2 | hyperextension | pshx | meningitic |
3 | meningitic | hashimoto | fraction |
4 | 301b | reiterated | subset |
... | ... | ... | ... |
67 | latissimus | topimax | pshx |
68 | monilla | conceal | equinovarus |
69 | dabigutran | aftetr | debulk |
70 | trendelenberg | symettric | oversee |
71 | deteriorates | depocyte | l34 |
72 rows × 3 columns
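One way to quantify how well the L2R approximation agrees with the actual brain is a rank correlation between the two score columns for each label. This comparison is not part of the library; the sketch below assumes scipy is available and reuses apprx_brain, xml_brain, toks_idx, lbs_idx and some_lbs from above:

```python
from scipy.stats import spearmanr

# per-label Spearman correlation between approximate and actual brain scores
actual = xml_brain[:, lbs_idx][toks_idx]  # (72, 3) actual brain scores
for c, lbl in enumerate(some_lbs):
    rho, _ = spearmanr(apprx_brain[:, c].detach().cpu().numpy(), actual[:, c].cpu().numpy())
    print(f"{lbl}: spearman rho = {rho:.3f}")
```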
Base Learner for NLP

load_collab_keys

load_collab_keys (model, wgts:dict)

Load only collab wgts (i_weight and i_bias) in model, keeping the rest as is.

 | Type | Details
---|---|---
model | | Model architecture
wgts | dict | Model weights
Returns | tuple |
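The underlying pattern is ordinary selective state-dict loading: copy over only the parameters whose names match, keep everything else. A minimal sketch of that pattern (`load_matching_keys` is a hypothetical simplification; the library also remaps collab key names onto the classifier's attention layers, which is omitted here):

```python
import torch
import torch.nn as nn

def load_matching_keys(model, wgts, keys=('i_weight', 'i_bias')):
    # overwrite only entries whose names contain one of `keys`;
    # every other parameter keeps its current value
    sd = model.state_dict()
    with torch.no_grad():
        for name, new in wgts.items():
            if any(k in name for k in keys) and name in sd: sd[name].copy_(new)
    return model.load_state_dict(sd)
```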
config = awd_lstm_clas_config.copy()
config.update({'n_hid': 10, 'emb_sz': 5})
# tst = get_text_classifier(AWD_LSTM, 100, 3, config=config)
tst = get_xmltext_classifier(AWD_LSTM, 100, 3, config=config)
old_sd = tst.state_dict().copy()
r = re.compile(".*attn.*")
test_eq([key for key in old_sd if 'attn' in key], list(filter(r.match, old_sd)))
print("\n".join(list(filter(r.match, old_sd))))

1.pay_attn.lbs.weight
1.boost_attn.lin.weight
1.boost_attn.lin.bias

import copy
old_sd = copy.deepcopy(tst.state_dict())
load_collab_keys(tst, new_wgts)
# <TODO: Deb> fix the following tests later
# test_ne(old_sd['1.attn.lbs_weight.weight'], tst.state_dict()['1.attn.lbs_weight.weight'])
# test_eq(tst.state_dict()['1.pay_attn.lbs_weight.weight'], new_wgts['i_weight.weight'])
# test_ne(old_sd['1.attn.lbs_weight_dp.emb.weight'], tst.state_dict()['1.attn.lbs_weight_dp.emb.weight'])
# test_eq(tst.state_dict()['1.attn.lbs_weight_dp.emb.weight'], new_wgts['i_weight.weight'])

<All keys matched successfully>
TextLearner
TextLearner (dls:fastai.data.core.DataLoaders, model, alpha:float=2.0, beta:float=1.0, moms:tuple=(0.8, 0.7, 0.8), loss_func:callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:callable=<function trainable_params>, cbs:Callback|MutableSequence|None=None, metrics:callable|MutableSequence|None=None, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, default_cbs:bool=True)
Basic class for a Learner in NLP.
 | Type | Default | Details
---|---|---|---
dls | DataLoaders | | Text DataLoaders
model | | | A standard PyTorch model
alpha | float | 2.0 | Param for RNNRegularizer |
beta | float | 1.0 | Param for RNNRegularizer |
moms | tuple | (0.8, 0.7, 0.8) | Momentum for Cosine Annealing Scheduler |
loss_func | callable | None | None | Loss function. Defaults to dls loss |
opt_func | Optimizer | OptimWrapper | Adam | Optimization function for training |
lr | float | slice | 0.001 | Default learning rate |
splitter | callable | trainable_params | Split model into parameter groups. Defaults to one parameter group |
cbs | Callback | MutableSequence | None | None | Callback s to add to Learner |
metrics | callable | MutableSequence | None | None | Metric s to calculate on validation set |
path | str | Path | None | None | Parent directory to save, load, and export models. Defaults to dls path |
model_dir | str | Path | models | Subdirectory to save and load models |
wd | float | int | None | None | Default weight decay |
wd_bn_bias | bool | False | Apply weight decay to normalization and bias parameters |
train_bn | bool | True | Train frozen normalization layers |
default_cbs | bool | True | Include default Callback s |
LMLearner.save_decoder

LMLearner.save_decoder (file:str)

Save the decoder to file in the model directory.

 | Type | Details
---|---|---
file | str | Filename for Decoder

Adds a ModelResetter and an RNNRegularizer with alpha and beta to the callbacks; the rest is the same as Learner init.

This Learner adds functionality to the base class:

Learner convenience functions
xmltext_classifier_learner

xmltext_classifier_learner (dls, arch, seq_len=72, config=None, backwards=False, pretrained=True, collab=False, drop_mult=0.5, n_out=None, lin_ftrs=None, ps=None, max_len=1440, y_range=None, splitter=None, running_decoder=True, loss_func:Optional[callable]=None, opt_func:Union[fastai.optimizer.Optimizer,fastai.optimizer.OptimWrapper]=<function Adam>, lr:Union[float,slice]=0.001, cbs:Union[fastai.callback.core.Callback,collections.abc.MutableSequence,NoneType]=None, metrics:Union[callable,collections.abc.MutableSequence,NoneType]=None, path:Union[str,pathlib.Path,NoneType]=None, model_dir:Union[str,pathlib.Path]='models', wd:Union[float,int,NoneType]=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)

Create a Learner with a text classifier from dls and arch.
 | Type | Default | Details
---|---|---|---
dls | DataLoaders | | DataLoaders containing fastai or PyTorch DataLoader s
arch | | |
seq_len | int | 72 | |
config | NoneType | None | |
backwards | bool | False | |
pretrained | bool | True | |
collab | bool | False | |
drop_mult | float | 0.5 | |
n_out | NoneType | None | |
lin_ftrs | NoneType | None | |
ps | NoneType | None | |
max_len | int | 1440 | |
y_range | NoneType | None | |
splitter | callable | trainable_params | Split model into parameter groups. Defaults to one parameter group |
running_decoder | bool | True | |
loss_func | callable | None | None | Loss function. Defaults to dls loss |
opt_func | Optimizer | OptimWrapper | Adam | Optimization function for training |
lr | float | slice | 0.001 | Default learning rate |
cbs | Callback | MutableSequence | None | None | Callback s to add to Learner |
metrics | callable | MutableSequence | None | None | Metric s to calculate on validation set |
path | str | Path | None | None | Parent directory to save, load, and export models. Defaults to dls path |
model_dir | str | Path | models | Subdirectory to save and load models |
wd | float | int | None | None | Default weight decay |
wd_bn_bias | bool | False | Apply weight decay to normalization and bias parameters |
train_bn | bool | True | Train frozen normalization layers |
moms | tuple | (0.95, 0.85, 0.95) | Default momentum for schedulers |
default_cbs | bool | True | Include default Callback s |
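Putting it together, a hedged usage sketch: the dls construction is assumed (it depends on your data pipeline), and the import paths are assumptions based on the module layout rather than something confirmed by this page; see the tutorial when available.

```python
from xcube.text.learner import xmltext_classifier_learner  # assumed import path
from fastai.text.all import AWD_LSTM

# `dls` is assumed to be text DataLoaders for an XML classification task
learn = xmltext_classifier_learner(dls, AWD_LSTM, drop_mult=0.5,
                                   pretrained=True,  # load the pretrained LM encoder
                                   collab=True)      # load collab label embeddings for the decoder
learn.fit_one_cycle(1, 1e-2)
```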