Learner for the XML Text application:

All the functions necessary to build a Learner suitable for transfer learning in XML text classification.

The most important function of this module is xmltext_classifier_learner. It helps you define a Learner that uses a pretrained language model for the encoder and a pretrained learning-to-rank model for the decoder. (Tutorial: Coming Soon!) This module is inspired by fastai's TextLearner and based on the ULMFiT paper.

Loading label embeddings from a pretrained collab model


source

match_collab

 match_collab (old_wgts:dict, collab_vocab:dict, lbs_vocab:list)

Convert the label embeddings in old_wgts to go from the vocabulary used for collab pre-training to lbs_vocab

              Type  Details
old_wgts      dict  Embedding weights of the collab model
collab_vocab  dict  Token and label vocabularies used for collab pre-training
lbs_vocab     list  Current labels vocabulary
Returns       dict
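
The demo below pins down the matching logic: the user/token weights ('u_*') are kept as they are, while the item/label weights ('i_*') are re-indexed so that row i corresponds to lbs_vocab[i]. A minimal sketch of that logic, as a hypothetical reimplementation (not xcube's actual source):

import torch

def match_collab_sketch(old_wgts, collab_vocab, lbs_vocab):
    "Re-index the item/label embeddings of `old_wgts` to follow `lbs_vocab`."
    old_lbs = collab_vocab['label']
    w, b = old_wgts['i_weight.weight'], old_wgts['i_bias.weight']
    new_w = w.new_zeros(len(lbs_vocab), w.shape[1])
    new_b = b.new_zeros(len(lbs_vocab), 1)
    missing = 0
    for i, lbl in enumerate(lbs_vocab):
        if lbl in old_lbs:                 # label seen during collab pre-training
            j = old_lbs.index(lbl)
            new_w[i], new_b[i] = w[j], b[j]
        else: missing += 1                 # unseen label keeps its zero init
    old_wgts['i_weight.weight'], old_wgts['i_bias.weight'] = new_w, new_b
    return old_wgts, missing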
wgts = {'u_weight.weight': torch.randn(3,5), 
        'i_weight.weight': torch.randn(4,5),
        'u_bias.weight'  : torch.randn(3,1),
        'i_bias.weight'  : torch.randn(4,1)}
collab_vocab = {'token': ['#na#', 'sun', 'moon', 'earth', 'mars'],
                'label': ['#na#', 'a', 'c', 'b']}
lbs_vocab = ['a', 'b', 'c']
new_wgts, missing = match_collab(wgts.copy(), collab_vocab, lbs_vocab)
test_eq(missing, 0)
test_close(wgts['u_weight.weight'], new_wgts['u_weight.weight'])
test_close(wgts['u_bias.weight'], new_wgts['u_bias.weight'])
# the label rows were reordered, so the i_* weights should no longer match
with ExceptionExpected(ex=AssertionError, regex="close"):
    test_close(wgts['i_weight.weight'][1:], new_wgts['i_weight.weight'])
with ExceptionExpected(ex=AssertionError, regex="close"):
    test_close(wgts['i_bias.weight'][1:], new_wgts['i_bias.weight'])
old_w, new_w = wgts['i_weight.weight'], new_wgts['i_weight.weight']
old_b, new_b = wgts['i_bias.weight'], new_wgts['i_bias.weight']
for (old_k, old_v), (new_k, new_v) in zip(wgts.items(), new_wgts.items()):
    if old_k.startswith('u'): test_eq(old_v.size(), new_v.size())
    else: test_ne(old_v.size(), new_v.size())
    # print(f"old: {old_k} = {old_v.size()}, new: {new_k} = {new_v.size()}")
test_eq(new_w[0], old_w[1]); test_eq(new_b[0], old_b[1])
test_eq(new_w[1], old_w[3]); test_eq(new_b[1], old_b[3])
test_eq(new_w[2], old_w[2]); test_eq(new_b[2], old_b[2])
test_shuffled(list(old_b[1:].squeeze().numpy()), list(new_b.squeeze().numpy()))
test_eq(torch.sort(old_b[1:], dim=0)[0], torch.sort(new_b, dim=0)[0])
test_eq(torch.sort(old_w[1:], dim=0)[0], torch.sort(new_w, dim=0)[0])

Loading Pretrained Information Gain as Attention

from xcube.l2r.all import *
source_mimic = untar_xxx(XURLs.MIMIC3)
xml_vocab = load_pickle(source_mimic/'mimic3-9k_clas_full_vocab.pkl')
xml_vocab = L(xml_vocab).map(listify)
source_l2r = untar_xxx(XURLs.MIMIC3_L2R)
boot_path = join_path_file('mimic3-9k_tok_lbl_info', source_l2r, ext='.pkl')
bias_path = join_path_file('p_L', source_l2r, ext='.pkl')
l2r_bootstrap = torch.load(boot_path, map_location=default_device())
brain_bias = torch.load(bias_path, map_location=default_device())
*brain_vocab, brain = mapt(l2r_bootstrap.get, ['toks', 'lbs', 'mutual_info_jaccard'])
brain_vocab = L(brain_vocab).map(listify)
toks, lbs = brain_vocab
print(f"last two places in brain vocab has {toks[-2:]}")
# toks = CategoryMap(toks, sort=False)
brain_bias = brain_bias[:, :, 0].squeeze(-1)
lbs_des = load_pickle(source_mimic/'code_desc.pkl')
assert isinstance(lbs_des, dict)
test_eq(brain.shape, (len(toks), len(lbs))) # last two places have 'xxfake'
test_eq(brain_bias.shape, [len(lbs)])
last two places in brain vocab have ['xxfake', 'xxfake']
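
The brain is a (tokens × labels) matrix of information-gain scores bootstrapped by the L2R stage, and brain_bias holds one prior score per label. The affinity of any (token, label) pair is therefore a single lookup (a trivial illustration, assuming both items are present in the brain vocab):

tok, lbl = toks[0], lbs[0]           # placeholders: any token/label in the vocab
score = brain[toks.index(tok), lbs.index(lbl)]
prior = brain_bias[lbs.index(lbl)]   # label-level prior, independent of the token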

The tokens that are in the xml vocab but not in the brain:

not_found_in_brain = L(set(xml_vocab[0]).difference(set(brain_vocab[0])))
not_found_in_brain
(#20) ['cella','q2day','remiained','luteinizing','promiscuity','sharpio','calcijex','dissension','mhc','theses'...]
test_fail(lambda : toks.index('cella'), contains='is not in list')

The tokens that are in the brain but not in the xml vocab:

set(brain_vocab[0]).difference(xml_vocab[0])
set()

Thankfully, we have info for all the labels in the xml vocab:

assert set(brain_vocab[1]).symmetric_difference(xml_vocab[1]) == set()
# test_shuffled(xml_vocab[1], mimic_vocab[1])
toks_xml2brain, toks_notfnd = _xml2brain(xml_vocab[0], brain_vocab[0])

toks_found = set(toks_xml2brain).difference(set(toks_notfnd))
test_shuffled(array(xml_vocab[0])[toks_notfnd], not_found_in_brain)
some_xml_idxs = np.random.choice(array(L(toks_found)), size=10)
some_xml_toks = array(xml_vocab[0])[some_xml_idxs]
corres_brain_idxs = L(map(toks_xml2brain.get, some_xml_idxs))
corres_brain_toks = array(toks)[corres_brain_idxs]
assert all_equal(some_xml_toks, corres_brain_toks)
100.00% [57376/57376 00:19<00:00]
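
The private helper _xml2brain maps every xml-vocab index to the matching brain-vocab index, flagging the entries with no match. A minimal sketch of what such a helper could look like (hypothetical; the real one also drives the progress bar above):

import numpy as np

def xml2brain_sketch(xml_v, brain_v):
    "Map each xml-vocab index to its brain-vocab index (np.inf when absent)."
    brain_idx = {o: i for i, o in enumerate(brain_v)}    # O(1) lookups
    x2b = {i: brain_idx.get(o, np.inf) for i, o in enumerate(xml_v)}
    notfnd = [i for i, j in x2b.items() if j is np.inf]  # xml indices with no brain match
    return x2b, notfnd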
lbs_xml2brain, lbs_notfnd = _xml2brain(xml_vocab[1], brain_vocab[1])

lbs_found = set(lbs_xml2brain).difference(set(lbs_notfnd))
some_xml_idxs = np.random.choice(array(L(lbs_found)), size=10)
some_xml_lbs = array(xml_vocab[1])[some_xml_idxs]
corres_brain_idxs = L(map(lbs_xml2brain.get, some_xml_idxs))
corres_brain_lbs = array(lbs)[corres_brain_idxs]
assert all_equal(some_xml_lbs, corres_brain_lbs)
100.00% [8922/8922 00:00<00:00]

source

brainsplant

 brainsplant (xml_vocab, brain_vocab, brain, brain_bias, device=None)
xml_brain, xml_lbsbias, toks_map, lbs_map, toks_xml2brain, lbs_xml2brain = brainsplant(xml_vocab, brain_vocab, brain, brain_bias)
test_eq(xml_brain.shape, xml_vocab.map(len))
test_eq(xml_brain[toks_notfnd], xml_brain.new_zeros(len(toks_notfnd), len(xml_vocab[1])))
assert all_equal(array(xml_vocab[0])[toks_map.itemgot(0)], array(brain_vocab[0])[toks_map.itemgot(1)])
assert all_equal(array(xml_vocab[1])[lbs_map.itemgot(0)], array(brain_vocab[1])[lbs_map.itemgot(1)])
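
Under the hood, brainsplant needs to do little more than allocate zero tensors shaped by the xml vocab and copy over the matched rows and columns. A minimal sketch of that core step (hypothetical, built on the (xml_idx, brn_idx) pairs in toks_map and lbs_map tested above):

import torch

def brainsplant_core_sketch(xml_vocab, brain, brain_bias, toks_map, lbs_map):
    "Copy matched entries of `brain` into zero tensors indexed by the xml vocab."
    xml_brain = brain.new_zeros(len(xml_vocab[0]), len(xml_vocab[1]))
    xml_lbsbias = brain_bias.new_zeros(len(xml_vocab[1]))
    tx, tb = map(torch.as_tensor, zip(*toks_map))  # xml-side and brain-side token idxs
    lx, lb = map(torch.as_tensor, zip(*lbs_map))   # xml-side and brain-side label idxs
    xml_brain[tx[:, None], lx] = brain[tb[:, None], lb]  # cross-product indexing
    xml_lbsbias[lx] = brain_bias[lb]
    return xml_brain, xml_lbsbias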
# tests to ensure `brainsplant` was successful 
lbl = '642.41'
lbl = '38.93'
lbl = '51.10'
lbl = '996.87'
lbl_idx_from_brn = brain_vocab[1].index(lbl)
tok_vals_from_brn, top_toks_from_brn = L(brain[:, lbl_idx_from_brn].topk(k=20)).map(Self.cpu())
lbl_idx_from_xml = xml_vocab[1].index(lbl)
tok_vals_from_xml, top_toks_from_xml = L(xml_brain[:, lbl_idx_from_xml].topk(k=20)).map(Self.cpu())
test_eq(lbs_xml2brain[lbl_idx_from_xml], lbl_idx_from_brn)
test_eq(tok_vals_from_brn, tok_vals_from_xml)
test_eq(array(brain_vocab[0])[top_toks_from_brn], array(xml_vocab[0])[top_toks_from_xml])
test_eq(brain_bias[lbl_idx_from_brn], xml_lbsbias[lbl_idx_from_xml])
print(f"For the lbl {lbl} ({lbs_des.get(lbl)}), the top tokens that needs attention are:")
print('\n'.join(L(array(xml_vocab[0])[top_toks_from_xml], use_list=True).zipwith(L(tok_vals_from_xml.numpy(), use_list=True)).map(str).map(lambda o: "+ "+o)))
For the lbl 996.87 (Complications of transplanted intestine), the top tokens that need attention are:
+ ('consultued', 0.25548762)
+ ('cip', 0.25548762)
+ ('parlor', 0.24661502)
+ ('transplantations', 0.18601614)
+ ('scaffoid', 0.18601614)
+ ('epineprine', 0.18601614)
+ ('culinary', 0.17232327)
+ ('coordinates', 0.1469037)
+ ('aminotransferases', 0.12153866)
+ ('hydronephroureter', 0.12153866)
+ ('27yom', 0.12153866)
+ ('27y', 0.103684604)
+ ('hardward', 0.090407245)
+ ('leukoreduction', 0.08014185)
+ ('venting', 0.07831942)
+ ('secrete', 0.07196123)
+ ('orthogonal', 0.07196123)
+ ('naac', 0.06891022)
+ ('mgso4', 0.0662555)
+ ('septecemia', 0.065286644)
tok = 'fibrillation'
tok = 'colpo'
tok = 'amiodarone'
tok = 'flagyl'
tok = 'nasalilid'
tok = 'hemetemesis'
tok = 'restitched'
tok_idx_from_brn = brain_vocab[0].index(tok)
lbs_vals_from_brn, top_lbs_from_brn = L(brain[tok_idx_from_brn].topk(k=20)).map(Self.cpu())
tok_idx_from_xml = xml_vocab[0].index(tok)
test_eq(tok_idx_from_brn, toks_xml2brain[tok_idx_from_xml])
lbs_vals_from_xml, top_lbs_from_xml = L(xml_brain[tok_idx_from_xml].topk(k=20)).map(Self.cpu())
test_eq(lbs_vals_from_brn, lbs_vals_from_xml)
try: 
    test_eq(array(brain_vocab[1])[top_lbs_from_brn], array(xml_vocab[1])[top_lbs_from_xml])
except AssertionError as e: 
    print(type(e).__name__, "due to instability in sorting (nothing to worry about!)")
    test_shuffled(array(brain_vocab[1])[top_lbs_from_brn], array(xml_vocab[1])[top_lbs_from_xml])
print('')
print(f"For the token {tok}, the top labels that needs attention are:")
print('\n'.join(L(mapt(lbs_des.get, array(xml_vocab[1])[top_lbs_from_xml])).zipwith(L(lbs_vals_from_xml.numpy(), use_list=True)).map(str).map(lambda o: "+ "+o)))

For the token restitched, the top labels that need attention are:
+ ('Other operations on supporting structures of uterus', 0.29102018)
+ ('Other proctopexy', 0.29102018)
+ ('Other operations on cul-de-sac', 0.18601614)
+ (None, 0.07494824)
+ ('Intervertebral disc disorder with myelopathy, thoracic region', 0.055331517)
+ ('Excision of scapula, clavicle, and thorax [ribs and sternum] for graft', 0.04382947)
+ ('Other repair of omentum', 0.028067086)
+ ('Chronic lymphocytic thyroiditis', 0.01986737)
+ (None, 0.019236181)
+ ('Reclosure of postoperative disruption of abdominal wall', 0.016585195)
+ ('Other disorders of calcium metabolism', 0.009393147)
+ ('Pain in joint involving pelvic region and thigh', 0.008421187)
+ ('Exteriorization of small intestine', 0.00817792)
+ ('Fusion or refusion of 9 or more vertebrae', 0.00762466)
+ ('Kyphosis (acquired) (postural)', 0.0074228523)
+ ('Unspecified procedure as the cause of abnormal reaction of patient, or of later complication, without mention of misadventure at time of procedure', 0.0063889036)
+ ('Application or administration of adhesion barrier substance', 0.00610513)
+ ('Acute osteomyelitis involving other specified sites', 0.0054434645)
+ ('Body Mass Index less than 19, adult', 0.004719585)
+ ('Dorsal and dorsolumbar fusion, anterior technique', 0.0046444684)
some_toks = random.sample(toks_map.itemgot(0), 10)
counts = [c*6 for c in random.sample(range(10), 10)]
some_toks = random.sample(some_toks, 20, counts=counts)
# Counter(some_toks)
cors_toks_brn = L(mapt(toks_xml2brain.get, some_toks))
test_eq(array(brain_vocab[0])[cors_toks_brn], array(xml_vocab[0])[some_toks])
print("some tokens (with repetitions):\n",'\n'.join(['-'+xml_vocab[0][t]for t in some_toks]))
some tokens (with repetitions):
 -disorientated
-disorientated
-dmh
-ibp
-literacy
-abruptly
-faxed
-delsym
-literacy
-delsym
-literacy
-ibp
-literacy
-delsym
-abruptly
-caox3
-caox3
-caox3
-caox3
-literacy
attn = xml_brain[some_toks]
test_eq(attn.shape, (len(some_toks), xml_brain.shape[1]))
# semantics of attn:
# each row of attn is one token's attention over all the labels
for t, a in zip(some_toks, attn):
    test_eq(xml_brain[t], a)
# each column of attn is the attention one label receives from each of these tokens
for lbl in range(xml_brain.shape[1]):
    test_eq(xml_brain[:, lbl][some_toks], attn[:, lbl])
pd.DataFrame([(xml_vocab[0][t], l:=xml_vocab[1][lbl_idx], val.item(), lbs_des.get(l, 'NF')) for t,lbl_idx,val in zip(some_toks,attn.max(dim=1).indices.cpu(), attn.max(dim=1).values.cpu())],
            columns=['token', 'most_relevant_lbl', 'lbl_attn', 'description']).sort_values(by='lbl_attn', ascending=False)
token most_relevant_lbl lbl_attn description
7 delsym 344.2 0.096369 Diplegia of upper limbs
13 delsym 344.2 0.096369 Diplegia of upper limbs
9 delsym 344.2 0.096369 Diplegia of upper limbs
2 dmh 983.1 0.041843 Toxic effect of acids
0 disorientated 171.0 0.036627 Malignant neoplasm of connective and other soft tissue of head, face, and neck
1 disorientated 171.0 0.036627 Malignant neoplasm of connective and other soft tissue of head, face, and neck
18 caox3 375.01 0.028963 Acute dacryoadenitis
17 caox3 375.01 0.028963 Acute dacryoadenitis
16 caox3 375.01 0.028963 Acute dacryoadenitis
15 caox3 375.01 0.028963 Acute dacryoadenitis
12 literacy 449 0.018083 Septic arterial embolism
10 literacy 449 0.018083 Septic arterial embolism
8 literacy 449 0.018083 Septic arterial embolism
4 literacy 449 0.018083 Septic arterial embolism
19 literacy 449 0.018083 Septic arterial embolism
6 faxed 15.9 0.012289 Other operations on extraocular muscles and tendons
11 ibp 39.90 0.006865 Insertion of non-drug-eluting peripheral vessel stent(s)
3 ibp 39.90 0.006865 Insertion of non-drug-eluting peripheral vessel stent(s)
14 abruptly 315.8 0.006252 Other specified delays in development
5 abruptly 315.8 0.006252 Other specified delays in development
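The next cell sparsifies attn with xcube's inattention layer, keeping only the k strongest entries along one dimension and zeroing out the rest (hence the permute dance to line the dimensions up). A minimal sketch of a top-k masking op in that spirit (a hypothetical stand-in, not the real layer):

import torch

def topk_mask(x, k, dim=-1):
    "Zero out everything except the `k` largest entries of `x` along `dim`."
    idx = x.topk(k, dim=dim).indices
    return torch.zeros_like(x).scatter(dim, idx, x.gather(dim, idx))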
from xcube.layers import inattention
# define label inattention cutoff
k = 5
top_lbs_attn = attn.clone().unsqueeze(0).permute(0,2,1).inattention(k=k).permute(0,2,1).squeeze(0).contiguous() # applying `inattention` across the lbs dim
test_eq(top_lbs_attn.shape, (len(some_toks), xml_brain.shape[1]))
test_ne(attn, top_lbs_attn)
test_eq(top_lbs_attn.argmax(dim=1), attn.argmax(dim=1))
lbs_cf = top_lbs_attn.sum(dim=0)
test_eq(lbs_cf.shape, [top_lbs_attn.shape[1]])
idxs = lbs_cf.nonzero().flatten().cpu()
print(f"After looking at the tokens {[xml_vocab[0][t]for t in some_toks]}, I am confident about the following labels:")
pd.DataFrame([(l:=xml_vocab[1][idx], val.item(), lbs_des.get(l, 'NF')) for idx,val in zip(idxs,lbs_cf[idxs])],
            columns=['lbl', 'lbl_cf', 'description']).sort_values(by='lbl_cf', ascending=False)
After looking at the tokens ['disorientated', 'disorientated', 'dmh', 'ibp', 'literacy', 'abruptly', 'faxed', 'delsym', 'literacy', 'delsym', 'literacy', 'ibp', 'literacy', 'delsym', 'abruptly', 'caox3', 'caox3', 'caox3', 'caox3', 'literacy'], I am confident about the following labels:
lbl lbl_cf description
10 344.2 0.289106 Diplegia of upper limbs
22 367.1 0.289106 Myopia
36 706.1 0.158741 Other acne
35 691.8 0.145640 Other atopic dermatitis and related conditions
21 442.89 0.134519 Aneurysm of other specified site
50 374.89 0.115853 Other disorders of eyelid
49 375.01 0.115853 Acute dacryoadenitis
20 449 0.090417 Septic arterial embolism
11 438.13 0.087779 NF
23 423.1 0.080689 Adhesive pericarditis
13 304.71 0.080689 Combinations of opioid type drug with any other drug dependence, continuous use
18 315.9 0.079333 Unspecified delay in development
24 171.0 0.073254 Malignant neoplasm of connective and other soft tissue of head, face, and neck
29 259.9 0.072001 Unspecified endocrine disorder
28 701.2 0.072001 Acquired acanthosis nigricans
44 370.00 0.066347 Corneal ulcer, unspecified
34 814.00 0.057924 Closed fracture of carpal bone, unspecified
7 E968.8 0.050906 Assault by other specified means
30 722.72 0.049080 Intervertebral disc disorder with myelopathy, thoracic region
17 444.21 0.047748 Arterial embolism and thrombosis of upper extremity
42 618.1 0.045409 Uterine prolapse without mention of vaginal wall prolapse
45 531.30 0.041843 Acute gastric ulcer without mention of hemorrhage or perforation, without mention of obstruction
41 54.73 0.041843 Other repair of peritoneum
37 983.1 0.041843 Toxic effect of acids
52 824.7 0.041843 Trimalleolar fracture, open
6 921.0 0.039063 Black eye, NOS
39 959.9 0.036281 Other and unspecified injury to unspecified site
16 290.3 0.032885 Senile dementia with delirium
43 784.92 0.032388 NF
15 444.89 0.028912 Embolism and thrombosis of other artery
32 39.90 0.013729 Insertion of non-drug-eluting peripheral vessel stent(s)
19 315.8 0.012504 Other specified delays in development
2 15.9 0.012289 Other operations on extraocular muscles and tendons
14 38.12 0.011341 Endarterectomy of other vessels of head and neck
9 68.16 0.010278 Closed biopsy of uterus
46 674.14 0.009737 Disruption of cesarean wound, postpartum
47 654.44 0.009737 Other abnormalities in shape or position of gravid uterus and of neighboring structures, postpartum
48 669.44 0.009737 Other complications of obstetrical surgery and procedures, postpartum condition or complication
51 205.90 0.009737 Unspecified myeloid leukemia without mention of remission
27 86.19 0.009737 Other diagnostic procedures on skin and subcutaneous tissue
53 021.8 0.009737 Other specified tularemia
40 362.07 0.009737 Diabetic macular edema
38 989.3 0.009737 Toxic effect of organophosphate and carbamate
26 17.7 0.009737 NF
25 706.9 0.009737 Unspecified disease of sebaceous glands
31 39.50 0.009523 Angioplasty or atherectomy of other non-coronary vessel(s)
5 421.0 0.008755 Acute and subacute bacterial endocarditis
0 348.5 0.007448 Cerebral edema
1 13.1 0.006795 Intracapsular extraction of lens
4 93.0 0.005886 Diagnostic physical therapy
3 21.2 0.005648 Diagnostic procedures on nose
33 36.05 0.005054 NF
8 440.1 0.004669 Atherosclerosis of renal artery
12 37.94 0.004538 Implantation or replacement of automatic cardioverter/defibrillator, total system [AICD]
l2r_wgts = torch.load(join_path_file('lin_lambdarank_full', source_l2r, ext='.pth'), map_location=default_device())
if 'model' in l2r_wgts: l2r_wgts = l2r_wgts['model']

We need to match the weights between the xml vocab and the brain:

def brainsplant_diffntble(xml_vocab, brain_vocab, l2r_wgts, device=None):
    toks_lbs = 'toks lbs'.split()
    mb = master_bar(range(2))
    for i in mb:
        # defines `toks_xml2brain`/`toks_notfnd` (i=0) and `lbs_xml2brain`/`lbs_notfnd` (i=1) in the notebook's global scope
        globals().update(dict(zip((toks_lbs[i]+'_xml2brain', toks_lbs[i]+'_notfnd'), (_xml2brain(xml_vocab[i], brain_vocab[i], parent_bar=mb)))))
        mb.write(f"Finished loop {i}")
    toks_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in toks_xml2brain.items() if brn_idx is not np.inf)
    lbs_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in lbs_xml2brain.items() if brn_idx is not np.inf)
    dev = default_device() if device is None else device
    tf_xml = torch.zeros(len(xml_vocab[0]), 200).to(dev)  # 200 = L2R factor dimension
    tb_xml = torch.zeros(len(xml_vocab[0]), 1).to(dev)
    lf_xml = torch.zeros(len(xml_vocab[1]), 200).to(dev)
    lb_xml = torch.zeros(len(xml_vocab[1]), 1).to(dev)
    # order: token_factors, token_bias, label_factors, label_bias
    tf_l2r, tb_l2r, lf_l2r, lb_l2r = list(l2r_wgts.values())
    # copy the pretrained L2R factors/biases for every token/label found in the brain vocab
    tf_xml[toks_map.itemgot(0)] = tf_l2r[toks_map.itemgot(1)].clone()
    tb_xml[toks_map.itemgot(0)] = tb_l2r[toks_map.itemgot(1)].clone()
    lf_xml[lbs_map.itemgot(0)] = lf_l2r[lbs_map.itemgot(1)].clone()
    lb_xml[lbs_map.itemgot(0)] = lb_l2r[lbs_map.itemgot(1)].clone()
    xml_wgts = {k: xml_val for k, xml_val in zip(l2r_wgts.keys(), (tf_xml, tb_xml, lf_xml, lb_xml))}
    mod_dict = nn.ModuleDict({k.split('.')[0]: nn.Embedding(*v.size()) for k, v in xml_wgts.items()}).to(dev)
    mod_dict.load_state_dict(xml_wgts)
    return mod_dict, toks_map, lbs_map
mod_dict, toks_map, lbs_map = brainsplant_diffntble(xml_vocab, brain_vocab, l2r_wgts)
assert isinstance(mod_dict, nn.Module)
assert nn.Module in mod_dict.__class__.__mro__ 

test_eq(mod_dict['token_factors'].weight.data[toks_map.itemgot(0)], l2r_wgts['token_factors.weight'][toks_map.itemgot(1)])
test_eq(mod_dict['token_bias'].weight.data[toks_map.itemgot(0)], l2r_wgts['token_bias.weight'][toks_map.itemgot(1)])
test_eq(mod_dict['label_factors'].weight.data[lbs_map.itemgot(0)], l2r_wgts['label_factors.weight'][lbs_map.itemgot(1)])
test_eq(mod_dict['label_bias'].weight.data[lbs_map.itemgot(0)], l2r_wgts['label_bias.weight'][lbs_map.itemgot(1)])
mod_dict
ModuleDict(
  (token_factors): Embedding(57376, 200)
  (token_bias): Embedding(57376, 1)
  (label_factors): Embedding(8922, 200)
  (label_bias): Embedding(8922, 1)
)
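
Because the transplanted weights live in nn.Embedding modules inside an nn.ModuleDict, they remain ordinary trainable parameters. That is what makes this version of brainsplant differentiable: the decoder can keep refining the pretrained L2R factors during classifier fine-tuning instead of treating them as frozen constants.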
some_lbs = ['996.87', '51.10', '38.93']

for lbl in some_lbs:
    print(f"{lbl}: {lbs_des.get(lbl, 'NF')}")
996.87: Complications of transplanted intestine
51.10: Endoscopic retrograde cholangiopancreatography [ERCP]
38.93: Venous catheterization, not elsewhere classified
lbs_idx = tensor(mapt(xml_vocab[1].index, some_lbs)).to(default_device())
toks_idx = torch.randint(0, len(xml_vocab[0]), (72,)).to(default_device())
print("-"+'\n-'.join(array(xml_vocab[0])[toks_idx.cpu()].tolist()))
-influx
-latissimus
-equinovarus
-deteriorates
-aap
-mvh
-135
-incipient
-rhubarb
-nizhoni
-trancutaneous
-indicaton
-subset
-largyngeal
-lemonade
-debulk
-aerations
-l34
-perserverates
-trendelenberg
-kettr
-meningitic
-bored
-hashimoto
-mountains
-wit
-asts
-ellicits
-pax
-adb
-alcholism
-violinist
-301b
-subpopulation
-intraorally
-98o2
-agreesive
-monilla
-jig
-paroxysmalatrial
-10pts
-knees
-conventionally
-soonest
-recap
-rediscuss
-spontanous
-pulmary
-repletement
-450x12
-symetrically
-fdi
-pshx
-svco2
-topimax
-2100cc
-conceal
-nauasea
-decontamination
-administrator
-fraction
-tachyarrythmia
-oversee
-dabigutran
-reiterated
-aftetr
-bues
-symettric
-powerful
-depocyte
-hyperextension
-hepsc
apprx_brain = mod_dict['token_factors'](toks_idx) @ mod_dict['label_factors'](lbs_idx).T + mod_dict['token_bias'](toks_idx) + mod_dict['label_bias'](lbs_idx).T
apprx_brain.shape
torch.Size([72, 3])
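
Each entry of apprx_brain is the usual matrix-factorization score: for the i-th sampled token t and the j-th label l,

apprx_brain[i, j] = token_factors[t] · label_factors[l] + token_bias[t] + label_bias[l]

which is the L2R model's low-rank approximation of brain[t, l] (here with 200-dimensional factors).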

These are the tokens as ranked by the pretrained L2R model (which is essentially an approximation of the actual brain):

pd.DataFrame(array(xml_vocab[0])[toks_idx[apprx_brain.argsort(dim=0, descending=True)].cpu()], columns=L(zip(some_lbs, mapt(lbs_des.get, some_lbs))).map(': '.join))
996.87: Complications of transplanted intestine 51.10: Endoscopic retrograde cholangiopancreatography [ERCP] 38.93: Venous catheterization, not elsewhere classified
0 fraction wit fraction
1 knees fraction subpopulation
2 subpopulation administrator knees
3 wit subset subset
4 administrator knees pshx
... ... ... ...
67 paroxysmalatrial mvh powerful
68 indicaton ellicits indicaton
69 rhubarb indicaton perserverates
70 depocyte aftetr rhubarb
71 aftetr conceal monilla

72 rows × 3 columns

Just to compare, this is how the actual brain ranks those tokens:

# array(xml_vocab[0])[xml_brain[:, lbl_idx].topk(k=20, dim=0).indices.cpu()]
pd.DataFrame(array(xml_vocab[0])[toks_idx[xml_brain[:, lbs_idx][toks_idx].argsort(descending=True, dim=0)].cpu()], columns=L(zip(some_lbs, mapt(lbs_des.get, some_lbs))).map(': '.join))
996.87: Complications of transplanted intestine 51.10: Endoscopic retrograde cholangiopancreatography [ERCP] 38.93: Venous catheterization, not elsewhere classified
0 fraction wit knees
1 knees administrator svco2
2 hyperextension pshx meningitic
3 meningitic hashimoto fraction
4 301b reiterated subset
... ... ... ...
67 latissimus topimax pshx
68 monilla conceal equinovarus
69 dabigutran aftetr debulk
70 trendelenberg symettric oversee
71 deteriorates depocyte l34

72 rows × 3 columns

Base Learner for NLP


source

load_collab_keys

 load_collab_keys (model, wgts:dict)

Load only collab wgts (i_weight and i_bias) in model, keeping the rest as is

         Type   Details
model           Model architecture
wgts     dict   Model weights
Returns  tuple
config = awd_lstm_clas_config.copy()
config.update({'n_hid': 10, 'emb_sz': 5})
# tst = get_text_classifier(AWD_LSTM, 100, 3, config=config)
tst = get_xmltext_classifier(AWD_LSTM, 100, 3, config=config)
old_sd = tst.state_dict().copy()
r = re.compile(".*attn.*")
test_eq([key for key in old_sd if 'attn' in key], list(filter(r.match, old_sd)))
print("\n".join(list(filter(r.match, old_sd))))
1.pay_attn.lbs.weight
1.boost_attn.lin.weight
1.boost_attn.lin.bias
import copy
old_sd = copy.deepcopy(tst.state_dict())
load_collab_keys(tst, new_wgts)
# <TODO: Deb> fix the following tests later
# test_ne(old_sd['1.attn.lbs_weight.weight'], tst.state_dict()['1.attn.lbs_weight.weight'])
# test_eq(tst.state_dict()['1.pay_attn.lbs_weight.weight'], new_wgts['i_weight.weight'])
# test_ne(old_sd['1.attn.lbs_weight_dp.emb.weight'], tst.state_dict()['1.attn.lbs_weight_dp.emb.weight'])
# test_eq(tst.state_dict()['1.attn.lbs_weight_dp.emb.weight'], new_wgts['i_weight.weight'])
<All keys matched successfully>
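
The general pattern behind load_collab_keys can be sketched as a partial state-dict update. The printed attention keys above and the commented-out tests hint at where the collab item weights land, but the key mapping below is only illustrative (hypothetical):

def load_partial_sketch(model, wgts, key_map):
    "Overwrite only the mapped keys of `model`'s state dict, keep the rest as is."
    sd = model.state_dict()
    for src, dst in key_map.items():
        if dst in sd and sd[dst].shape == wgts[src].shape:
            sd[dst] = wgts[src]
    return model.load_state_dict(sd)

# e.g. load_partial_sketch(tst, new_wgts, {'i_weight.weight': '1.pay_attn.lbs.weight'})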

source

TextLearner

 TextLearner (dls:fastai.data.core.DataLoaders, model, alpha:float=2.0,
              beta:float=1.0, moms:tuple=(0.8, 0.7, 0.8),
              loss_func:callable|None=None,
              opt_func:Optimizer|OptimWrapper=<function Adam>,
              lr:float|slice=0.001, splitter:callable=<function
              trainable_params>, cbs:Callback|MutableSequence|None=None,
              metrics:callable|MutableSequence|None=None,
              path:str|Path|None=None, model_dir:str|Path='models',
              wd:float|int|None=None, wd_bn_bias:bool=False,
              train_bn:bool=True, default_cbs:bool=True)

Basic class for a Learner in NLP.

             Type                               Default           Details
dls          DataLoaders                                          Text DataLoaders
model                                                             A standard PyTorch model
alpha        float                              2.0               Param for RNNRegularizer
beta         float                              1.0               Param for RNNRegularizer
moms         tuple                              (0.8, 0.7, 0.8)   Momentum for Cosine Annealing Scheduler
loss_func    callable | None                    None              Loss function. Defaults to dls loss
opt_func     Optimizer | OptimWrapper           Adam              Optimization function for training
lr           float | slice                      0.001             Default learning rate
splitter     callable                           trainable_params  Split model into parameter groups. Defaults to one parameter group
cbs          Callback | MutableSequence | None  None              Callbacks to add to Learner
metrics      callable | MutableSequence | None  None              Metrics to calculate on validation set
path         str | Path | None                  None              Parent directory to save, load, and export models. Defaults to dls path
model_dir    str | Path                         models            Subdirectory to save and load models
wd           float | int | None                 None              Default weight decay
wd_bn_bias   bool                               False             Apply weight decay to normalization and bias parameters
train_bn     bool                               True              Train frozen normalization layers
default_cbs  bool                               True              Include default Callbacks

source

LMLearner.save_decoder

 LMLearner.save_decoder (file:str)

Save the decoder to file in the model directory

      Type  Details
file  str   Filename for the decoder

Adds a ModelResetter and an RNNRegularizer with alpha and beta to the callbacks; the rest is the same as the Learner init.

This Learner adds functionality to the base class:

Learner convenience functions


source

xmltext_classifier_learner

 xmltext_classifier_learner (dls, arch, seq_len=72, config=None,
                             backwards=False, pretrained=True, collab=False,
                             drop_mult=0.5, n_out=None, lin_ftrs=None,
                             ps=None, max_len=1440, y_range=None,
                             splitter=None, running_decoder=True,
                             loss_func:callable|None=None,
                             opt_func:Optimizer|OptimWrapper=<function Adam>,
                             lr:float|slice=0.001,
                             cbs:Callback|MutableSequence|None=None,
                             metrics:callable|MutableSequence|None=None,
                             path:str|Path|None=None,
                             model_dir:str|Path='models',
                             wd:float|int|None=None, wd_bn_bias:bool=False,
                             train_bn:bool=True,
                             moms:tuple=(0.95, 0.85, 0.95),
                             default_cbs:bool=True)

Create a Learner with a text classifier from dls and arch.

                 Type                               Default             Details
dls              DataLoaders                                            DataLoaders containing fastai or PyTorch DataLoaders
arch
seq_len          int                                72
config           NoneType                           None
backwards        bool                               False
pretrained       bool                               True
collab           bool                               False
drop_mult        float                              0.5
n_out            NoneType                           None
lin_ftrs         NoneType                           None
ps               NoneType                           None
max_len          int                                1440
y_range          NoneType                           None
splitter         callable                           trainable_params    Split model into parameter groups. Defaults to one parameter group
running_decoder  bool                               True
loss_func        callable | None                    None                Loss function. Defaults to dls loss
opt_func         Optimizer | OptimWrapper           Adam                Optimization function for training
lr               float | slice                      0.001               Default learning rate
cbs              Callback | MutableSequence | None  None                Callbacks to add to Learner
metrics          callable | MutableSequence | None  None                Metrics to calculate on validation set
path             str | Path | None                  None                Parent directory to save, load, and export models. Defaults to dls path
model_dir        str | Path                         models              Subdirectory to save and load models
wd               float | int | None                 None                Default weight decay
wd_bn_bias       bool                               False               Apply weight decay to normalization and bias parameters
train_bn         bool                               True                Train frozen normalization layers
moms             tuple                              (0.95, 0.85, 0.95)  Default momentum for schedulers
default_cbs      bool                               True                Include default Callbacks
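
Putting it together, a typical call looks like the sketch below. This is hedged: dls is a placeholder for an XML-text DataLoaders built beforehand, the encoder file name is hypothetical, the import assumes xcube exposes a fastai-style all module, and the returned Learner is assumed to behave like fastai's text_classifier_learner:

from xcube.text.all import *   # assumption: fastai-style star import

learn = xmltext_classifier_learner(dls, AWD_LSTM, drop_mult=0.5,
                                   metrics=accuracy_multi)
learn.load_encoder('finetuned_lm_encoder')  # encoder saved from a fine-tuned LM
learn.fit_one_cycle(1, 2e-2)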