Layers
Basic manipulations and resizing
One can easily create a beautiful layer with minimum boilerplate using fastai utilities. We will show a few simple examples here. For details and extensive illustrations, please refer to the fastai layers documentation.
Lambda
Lambda (func)
An easy way to create a pytorch layer for a simple func
def _add2(x): return x+2
tst = Lambda(_add2)
x = torch.randn(10,20)
test_eq(tst(x), x+2)
tst2 = pickle.loads(pickle.dumps(tst))
test_eq(tst2(x), x+2)
PartialLambda
PartialLambda (func)
Layer that applies partial(func, **kwargs)
def test_func(a,b=2): return a+b
tst = PartialLambda(test_func, b=5)
test_eq(tst(x), x+5)
Linear
ElemWiseLin
ElemWiseLin (dim0, dim1, add_bias=False, **kwargs)
Same as nn.Module, but no need for subclasses to call super().__init__
bs, dim0, dim1 = 10, 1271, 400
tst = ElemWiseLin(dim0, dim1)
test_eq(tst.lin.weight.shape, (dim0, dim1))
x = torch.randn(bs, dim0, dim1)
test_eq(tst(x).shape, (bs, dim0, dim1))
BatchNorm Layers
LinBnFlatDrop
LinBnFlatDrop (n_in, n_out, bn=True, p=0.0, act=None, lin_first=False)
Module grouping BatchNorm1dFlat, Dropout and Linear layers
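There is no usage example for LinBnFlatDrop in this notebook, so here is a minimal sketch, assuming it stacks its sub-modules the same way LinBnDrop does below (the exact ordering depends on lin_first):

```python
# Minimal usage sketch (assumes the notebook's imports); the module ordering is
# assumed from the docstring above, not tested here.
tst = LinBnFlatDrop(1024, 512, p=0.1, act=nn.ReLU(inplace=True))
mods = list(tst.children())
print(mods)  # expect BatchNorm1dFlat, Dropout, Linear and the activation (with lin_first=False)
```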
LinBnDrop
LinBnDrop (n_in, n_out=None, bn=True, ln=True, p=0.0, act=None, lin_first=False, ndim=1)
Module grouping BatchNorm1d, Dropout and Linear layers
LinBnDrop is just like fastai's LinBnDrop with an extra argument ln, which provides the option of skipping the linear layer. That is, the BatchNorm or the Linear layer is skipped if bn=False or ln=False respectively, as is the dropout if p=0. Optionally, you can add an activation after the linear layer with act.
tst = LinBnDrop(10, 20)
mods = list(tst.children())
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Linear)
The LinBnDrop layer will not add an activation (even if provided) if ln is False, and it raises an error if ln=False and lin_first=True:
tst = LinBnDrop(10, 20, ln=False, p=0.02, act=nn.ReLU(inplace=True))
mods = list(tst.children())
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Dropout)
test_fail(lambda : LinBnDrop(10, 20, ln=False, lin_first=True), contains='AssertionError')
Embeddings
Embedding
Embedding (ni, nf, std=0.01, **kwargs)
Embedding layer with truncated normal initialization
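As a quick illustration (not one of the library's tests), the truncated-normal initialization can be sanity-checked like this; the bound is an arbitrary choice for the sketch:

```python
# Sanity-check sketch (assumes the notebook's imports): weights come from a truncated
# normal with std=0.01, so their empirical std should sit near (slightly below) 0.01.
emb = Embedding(1000, 64, std=0.01)
assert emb.weight.std().item() < 0.02  # 0.02 is an arbitrary loose bound for this sketch
```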
Attention Layers for Extreme Multi-Label Classification
_linear_attention
_linear_attention (sentc:torch.Tensor, based_on:Union[torch.nn.modules.sparse.Embedding,fastai.torch_core.Module])
| | Type | Details |
|---|---|---|
| sentc | Tensor | Sentence, typically (bs, bptt, nh) |
| based_on | nn.Embedding \| Module | xcube's Embedding(n_lbs, nh) layer holding the label embeddings, or a full-fledged model |
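Judging from the Linear_Attention tests further down, when based_on is an nn.Embedding this boils down to a matrix product of the sentence with the transposed label-embedding matrix. A minimal sketch of that computation (names are illustrative only):

```python
import torch

def linear_attention_sketch(sentc, lbl_emb_weight):
    # (bs, bptt, nh) @ (nh, n_lbs) -> label-wise scores of shape (bs, bptt, n_lbs)
    return sentc @ lbl_emb_weight.transpose(0, 1)

sentc = torch.randn(16, 72, 100)       # (bs, bptt, nh)
lbl_emb_weight = torch.randn(10, 100)  # (n_lbs, nh)
assert linear_attention_sketch(sentc, lbl_emb_weight).shape == (16, 72, 10)
```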
_planted_attention
_planted_attention (sentc:torch.Tensor, brain:torch.Tensor)
| | Type | Details |
|---|---|---|
| sentc | Tensor | Sentence, typically (bs, bptt), containing the vocab idxs that go inside the encoder |
| brain | Tensor | Label-specific attn wgts for each token in the vocab, typically of shape (vocab_sz, n_lbs) |
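As the Planted_Attention tests further down confirm, planted attention is simply a lookup of the pre-computed ("planted") weights by token index. A minimal sketch (names are illustrative only):

```python
import torch

def planted_attention_sketch(sentc_idxs, brain):
    # (bs, bptt) integer token idxs indexed into (vocab_sz, n_lbs) -> (bs, bptt, n_lbs)
    return brain[sentc_idxs.long()]

sentc_idxs = torch.zeros(16, 72).random_(100)  # (bs, bptt) vocab idxs
brain = torch.randn(100, 10)                   # (vocab_sz, n_lbs)
assert planted_attention_sketch(sentc_idxs, brain).shape == (16, 72, 10)
```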
_diffntble_attention
_diffntble_attention (inp:torch.Tensor, based_on:torch.nn.modules.container.ModuleDict)
| | Type | Details |
|---|---|---|
| inp | Tensor | Sentence, typically (bs, bptt), containing the vocab idxs that go inside the encoder |
| based_on | nn.ModuleDict | Dictionary of pretrained nn.Embedding from the l2r model |
Linear_Attention
Linear_Attention (based_on:fastai.torch_core.Module)
bs, bptt, nh, n_lbs = 16, 72, 100, 10
tst_lbs = Embedding(n_lbs, nh)
tst_Lin_Attn = Linear_Attention(tst_lbs)
attn_layer = Lambda(tst_Lin_Attn)
sentc = torch.randn(bs, bptt, nh)
test_eq(tst_Lin_Attn(sentc).shape, (bs, bptt, n_lbs))
test_eqs(attn_layer(sentc), tst_Lin_Attn(sentc), sentc @ tst_lbs.weight.transpose(0,1))
attn_layer2 = pickle.loads(pickle.dumps(attn_layer))
test_eqs(attn_layer2(sentc), sentc @ tst_lbs.weight.transpose(0,1))
Planted_Attention
Planted_Attention (brain:torch.Tensor)
bs, bptt, vocab_sz, n_lbs = 16, 72, 100, 10
inp = torch.zeros((bs, bptt)).random_(vocab_sz)
brain = torch.randn(vocab_sz, n_lbs)
tst_planted_Attn = Planted_Attention(brain)
attn_layer = Lambda(tst_planted_Attn)
attn = brain[inp.long()]
test_eq(attn.shape, (bs, bptt, n_lbs))
test_eqs(attn, tst_planted_Attn(inp), attn_layer(inp))
# test_eq(brain[sentc[8].long()][:, 4], attn[8, :, 4]) # looking at the attn wgts of the 8th sentence and 4th label
Diff_Planted_Attention
Diff_Planted_Attention (based_on:fastai.torch_core.Module)
lincomb
lincomb (t, wgts=None)
returns the linear combination of the rows (dim 1) of the 3d tensor t, based on wgts (if wgts is None, just adds the rows)
= torch.randn(16, 72, 100)
t = t.new_ones(t.size(0), 1, t.size(1))
wgts
test_eq(torch.bmm(wgts, t), lincomb(t))= t.new_empty(t.size(0), 15, t.size(1)).random_(10)
rand_wgts # test_eq(lincomb(t, wgts=rand_wgts), torch.bmm(rand_wgts, t))
= PartialLambda(lincomb, wgts=rand_wgts)
tst_LinComb test_eq(tst_LinComb(t), torch.bmm(rand_wgts, t))
topkmax
topkmax (k=None, dim=1)
returns the softmax over dim of the 3d tensor x after zeroing out values in x smaller than the kth largest. If k is None, behaves like x.softmax(dim=dim). Intuitively, topkmax hedges more compared to F.softmax
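A minimal sketch of the behaviour described above, assuming only what the docstring and the tests below state (the library's topkmax is patched onto Tensor; topkmax_sketch here is a hypothetical standalone helper, not the actual implementation):

```python
import torch
import torch.nn.functional as F

def topkmax_sketch(x, k=None, dim=1):
    # Zero out entries smaller than the k-th largest along `dim`, then softmax over `dim`.
    if k is None:
        return x.softmax(dim=dim)
    kth = x.topk(k, dim=dim).values.narrow(dim, k - 1, 1)  # k-th largest, kept broadcastable
    return x.masked_fill(x < kth, 0.).softmax(dim=dim)

x = torch.randn(2, 7, 3)
assert torch.allclose(topkmax_sketch(x), F.softmax(x, dim=1))  # k=None falls back to softmax
```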
split_sort
split_sort (t, sp_dim, sort_dim, sp_sz=500, **kwargs)
t = torch.randn(16, 106, 819)
s_t = split_sort(t, sp_dim=1, sort_dim=-1, sp_sz=14)
test_eq(t.sort(dim=-1).values, s_t)
inattention
inattention (k=None, sort_dim=0, sp_dim=0)
returns self after zeroing out values smaller than the kth largest in dimension sort_dim. If k is None, returns self unchanged.
TODO (DEB): make it work for other dims; hyperparam-schedule the k in topkmax (start high, gradually decrease).
x = torch.randn((2, 7, 3))
test_eq(x.topkmax(), F.softmax(x, dim=1))
# test_fail(topkmax, args=(x, ), kwargs=dict(dim=-1)) # NotImplemented
test_fail(x.topkmax, kwargs=dict(dim=-1)) # NotImplemented
test_eq(x.inattention(k=2, sort_dim=-1),
        torch.where(x < x.sort(dim=-1, descending=True).values[:, :, 2].unsqueeze(dim=-1), 0, x))
x = torch.randn((8820,))
x_inattn = torch.where(x < x.sort(dim=0, descending=True).values[2].unsqueeze(dim=0), 0, x)
x_inattn1 = x.inattention(k=2, sort_dim=0)
test_eq(x_inattn, x_inattn1)
XMLAttention
XMLAttention (n_lbs, emb_sz, embed_p=0.0)
Compute label specific attention weights for each token in a sequence
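Based purely on the tests below, the linear-attention path of the forward pass can be sketched as: score each token against the label embeddings, softmax over the sequence dimension, zero out padded positions with the mask, and take the label-specific linear combination of the sequence. A compact, illustrative sketch (it ignores tok_wgts and lbs_cf):

```python
import torch
import torch.nn.functional as F

def xml_attention_sketch(sentc, lbl_emb_weight, mask):
    # sentc: (bs, bptt, nh), lbl_emb_weight: (n_lbs, nh), mask: (bs, bptt) True at padding
    scores = sentc @ lbl_emb_weight.transpose(0, 1)   # (bs, bptt, n_lbs)
    wgts = F.softmax(scores, dim=1)                   # attention over the sequence dimension
    wgts = wgts.masked_fill(mask[:, :, None], 0)      # drop padded positions
    return torch.bmm(wgts.transpose(1, 2), sentc)     # (bs, n_lbs, nh)

sentc = torch.randn(16, 72, 100)              # (bs, bptt, nh)
lbl_emb_weight = torch.randn(10, 100)         # (n_lbs, nh)
mask = torch.zeros(16, 72, dtype=torch.bool)  # no padding in this toy example
assert xml_attention_sketch(sentc, lbl_emb_weight, mask).shape == (16, 10, 100)
```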
# testing linear attention
inp = torch.zeros(bs, bptt).random_(100)
sentc = torch.randn(bs, bptt, nh)
mask = sentc.new_empty(sentc.size()[:-1]).random_(2).bool()
test_eq(mask.unique(), tensor([0., 1.]))
xml_attn = XMLAttention(n_lbs, nh)
attn, tok_wgts, lbs_cf = xml_attn(inp, sentc, mask)
test_eq(attn.shape, (bs, n_lbs, nh))
tst_lbs = xml_attn.lbs
tst_Lin_Attn = Linear_Attention(tst_lbs)
lin_attn_layer = Lambda(tst_Lin_Attn)
attn_wgts = F.softmax(lin_attn_layer(sentc), dim=1) # topkmax(attn_layer(sentc), dim=1)
test_eq(attn, torch.bmm(attn_wgts.masked_fill(mask[:, :, None], 0).transpose(1,2), sentc))
# testing planted attention followed by inattention
assert xml_attn.attn.func.f is _linear_attention
inp = torch.zeros((bs, bptt)).random_(vocab_sz)
brain = torch.randn(vocab_sz, n_lbs)
plant_attn_layer = Lambda(Planted_Attention(brain))
# xml_attn.attn = plant_attn_layer
setattr(xml_attn, 'attn', plant_attn_layer)
assert xml_attn.attn.func.f is _planted_attention
attn, tok_wgts, lbs_cf = xml_attn(inp, sentc, mask)
test_eqs(tok_wgts,
         plant_attn_layer(inp).masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1),
         brain[inp.long()].masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1)
)
test_eq(attn,
        lincomb(sentc,
                wgts=brain[inp.long()].masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1).transpose(1,2)
        )
)
Test that masking works:
for attn_layer in (lin_attn_layer, plant_attn_layer):
    setattr(xml_attn, 'attn', attn_layer)
    inp = torch.zeros(bs, bptt).random_(100)
    sentc = torch.randn(bs, bptt, nh)
    sentc = sentc.masked_fill(mask[:, :, None], 0)
    assert sentc[mask].sum().item() == 0
    attn, tok_wgts, lbs_cf = xml_attn(inp, sentc, mask)
    assert sentc[mask].sum().item() == 0
    attn_wgts = F.softmax(attn_layer(sentc), dim=1) if attn_layer is lin_attn_layer else attn_layer(inp).masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1) # topkmax(attn_layer(sentc), dim=1)
    test_eq(attn, torch.bmm(attn_wgts.transpose(1,2), sentc))