# Core XML Text Modules
The models provided here are variations of the ones provided by fastai with modifications tailored for XML.
## Basic Models
### SequentialRNN

`SequentialRNN (*args)`

A sequential PyTorch module that passes the `reset` call to its children.
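As a rough sketch of that behavior (not necessarily this library's exact source), such a module can be implemented as an `nn.Sequential` that forwards `reset` to any child that defines it:

```python
import torch.nn as nn

class SequentialRNNSketch(nn.Sequential):
    "A sequential container that propagates `reset` to any child that defines it."
    def reset(self):
        for child in self.children():
            if hasattr(child, 'reset'):
                child.reset()   # e.g. an AWD_LSTM would reset its hidden state here
```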
## Classification Models
The `SentenceEncoder` below is fastai's source code, copied here to understand its components before changing it into `AttentiveSentenceEncoder`:
### SentenceEncoder

`SentenceEncoder (bptt, module, pad_idx=1, max_len=None)`

Create an encoder over `module` that can process a full sentence.

This module expects the inputs padded with most of the padding first, with the sequence beginning at a round multiple of `bptt` (and the rest of the padding at the end). Use `pad_input_chunk` to get your data in a suitable format.
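For example, with `bptt=72` a 100-token document in a batch whose longest document has 230 tokens gets 72 pad tokens up front (whole `bptt` chunks) and the remaining 58 at the end, so the real tokens start at a round multiple of `bptt`. A minimal sketch of that layout (an illustration of the expected format, not the actual `pad_input_chunk` implementation):

```python
import torch

bptt, pad_idx = 72, 1
doc      = torch.arange(2, 102)       # 100 "real" token ids (none equal to pad_idx)
batch_sl = 230                        # length of the longest document in the batch

pad_len   = batch_sl - len(doc)       # 130 pad tokens needed in total
front_pad = (pad_len // bptt) * bptt  # 72: whole bptt chunks of padding go first
back_pad  = pad_len - front_pad       # 58: the remainder goes at the end

padded = torch.cat([torch.full((front_pad,), pad_idx), doc,
                    torch.full((back_pad,), pad_idx)])
assert (padded != pad_idx).nonzero()[0].item() == front_pad  # real tokens start at index 72
```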
### AttentiveSentenceEncoder

`AttentiveSentenceEncoder (bptt, module, decoder, pad_idx=1, max_len=None, running_decoder=True)`

Create an encoder over `module` that can process a full sentence.
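A hedged construction sketch: the AWD_LSTM sizes and the choice of `LabelAttentionClassifier` as the decoder below are illustrative assumptions (in practice, `get_xmltext_classifier2` further down assembles this for you):

```python
from fastai.text.all import AWD_LSTM

# assumed illustration: a language-model encoder plus a label-attention decoder,
# chunked over bptt=72 tokens, up to max_len=72*20 tokens per document
lstm    = AWD_LSTM(vocab_sz=60_000, emb_sz=400, n_hid=1152, n_layers=3, pad_token=1)
decoder = LabelAttentionClassifier(400, 1271)   # documented below on this page
encoder = AttentiveSentenceEncoder(72, lstm, decoder, pad_idx=1,
                                   max_len=72*20, running_decoder=True)
```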
### masked_concat_pool

`masked_concat_pool (output, mask, bptt)`

Pool `MultiBatchEncoder` outputs into one vector `[last_hidden, max_pool, avg_pool]`.
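A quick shape illustration, assuming the fastai convention that `True` in the mask marks padding positions: with a hidden size of 400 the pooled vector is 3 × 400 = 1200 wide, which is where the 1200 figure in `XPoolingLinearClassifier` below comes from.

```python
import torch

out  = torch.randn(16, 144, 400)               # (bs, seq_len, n_hidden) from the encoder
mask = torch.zeros(16, 144, dtype=torch.bool)   # True = padding position
mask[:, :72] = True                             # e.g. the first bptt chunk is all padding

pooled = masked_concat_pool(out, mask, bptt=72)
print(pooled.shape)                             # torch.Size([16, 1200])
```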
### XPoolingLinearClassifier

`XPoolingLinearClassifier (dims, ps, bptt, y_range=None)`

Same as `nn.Module`, but no need for subclasses to call `super().__init__`.

Note that `XPoolingLinearClassifier` is exactly the same as fastai's `PoolingLinearClassifier`, except that we do not compress the pooled features from 1200 down to 50 linear features.
Note: Also try `XPoolingLinearClassifier` without dropout and batch normalization (this still needs verification, but so far it does not seem to work as well as the version with batch normalization).
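A minimal sketch of what such a head looks like, modeled on fastai's `PoolingLinearClassifier`; the exact layer composition here is an assumption for illustration, not this library's source. The point is that with `dims=[1200, n_lbs]` the pooled features feed the output layer directly, with no 1200 → 50 bottleneck in between:

```python
import torch.nn as nn
from fastai.layers import LinBnDrop, SigmoidRange

class XPoolingLinearClassifierSketch(nn.Module):
    "Pool the encoder output, then map the 3*n_hidden pooled features straight to the labels."
    def __init__(self, dims, ps, bptt, y_range=None):
        super().__init__()
        acts = [nn.ReLU(inplace=True)] * (len(dims) - 2) + [None]
        layers = [LinBnDrop(i, o, p=p, act=a)
                  for i, o, p, a in zip(dims[:-1], dims[1:], ps, acts)]
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        self.layers, self.bptt = nn.Sequential(*layers), bptt

    def forward(self, inp):
        out, mask = inp                                # encoder outputs and padding mask
        x = masked_concat_pool(out, mask, self.bptt)   # [last_hidden, max_pool, avg_pool]
        return self.layers(x), out, out                # extra outputs for RNN regularization
```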
### LabelAttentionClassifier

`LabelAttentionClassifier (n_hidden, n_lbs, y_range=None)`

Same as `nn.Module`, but no need for subclasses to call `super().__init__`.
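A quick shape check: with `n_hidden=400` and `n_lbs=1271`, a batch of 16 documents of 72*20 tokens yields one score per label.

```python
attn_clas = LabelAttentionClassifier(400, 1271)
test_eq(getattrs(attn_clas, 'n_hidden', 'n_lbs'), (400, 1271))
inps, outs, mask = torch.zeros(16, 72*20).random_(10), torch.randn(16, 72*20, 400), torch.randint(2, size=(16, 72*20))
x, *_ = attn_clas((inps, outs, mask))
test_eq(x.shape, (16, 1271))
```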
TODOs (Deb):

- ~~Find out what happens with respect to the RNN Regularizer callback after `LabelAttentionClassifier` returns a tuple of 3. (Check the learner cbs and follow the `RNNCallback`.)~~
- ~~Check whether we lose anything by ignoring the mask in `LabelAttentionClassifier`, i.e. should we exclude the masked tokens when computing the attention weights?~~
- Change the label bias initial distribution from uniform to the one we learned separately.
- ~~Implement Teacher Forcing.~~
### get_xmltext_classifier

`get_xmltext_classifier (arch, vocab_sz, n_class, seq_len=72, config=None, drop_mult=1.0, pad_idx=1, max_len=1440, y_range=None)`

Create a text classifier from `arch` and its `config`, maybe `pretrained`.
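A hedged usage sketch; the vocabulary size, label count, and `drop_mult` below are placeholder values:

```python
from fastai.text.all import AWD_LSTM

model = get_xmltext_classifier(AWD_LSTM, vocab_sz=60_000, n_class=1271,
                               seq_len=72, max_len=72*20, drop_mult=0.5)
```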
### get_xmltext_classifier2

`get_xmltext_classifier2 (arch, vocab_sz, n_class, seq_len=72, config=None, drop_mult=1.0, pad_idx=1, max_len=1440, y_range=None, running_decoder=True)`

Create a text classifier from `arch` and its `config`, maybe `pretrained`.
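And the attentive variant, which additionally exposes `running_decoder` (same placeholder sizes as above):

```python
from fastai.text.all import AWD_LSTM

model2 = get_xmltext_classifier2(AWD_LSTM, vocab_sz=60_000, n_class=1271,
                                 seq_len=72, max_len=72*20, drop_mult=0.5,
                                 running_decoder=True)
```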