Layers

Basic manipulations and resizing

One can easily create a beautiful layer with minimal boilerplate using fastai utilities. We show a few simple examples here. For details and extensive illustrations, please refer to the decorated fastai layers.

source

Lambda

 Lambda (func)

An easy way to create a pytorch layer for a simple func

def _add2(x): return x+2
tst = Lambda(_add2)
x = torch.randn(10,20)
test_eq(tst(x), x+2)
tst2 = pickle.loads(pickle.dumps(tst))
test_eq(tst2(x), x+2)

source

PartialLambda

 PartialLambda (func)

Layer that applies partial(func, **kwargs)

def test_func(a,b=2): return a+b
tst = PartialLambda(test_func, b=5)
test_eq(tst(x), x+5)

Linear


source

ElemWiseLin

 ElemWiseLin (dim0, dim1, add_bias=False, **kwargs)

Element-wise linear layer with a learnable weight of shape (dim0, dim1) (plus an optional bias when add_bias=True); it preserves the shape of its input

bs, dim0, dim1 = 10, 1271, 400
tst = ElemWiseLin(dim0, dim1)
test_eq(tst.lin.weight.shape, (dim0, dim1))
x = torch.randn(bs, dim0, dim1)
test_eq(tst(x).shape, (bs, dim0, dim1))

BatchNorm Layers


source

LinBnFlatDrop

 LinBnFlatDrop (n_in, n_out, bn=True, p=0.0, act=None, lin_first=False)

Module grouping BatchNorm1dFlat, Dropout and Linear layers
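
For instance, one might check the expected output shape like this (a minimal sketch with hypothetical sizes, reusing the torch and test_eq helpers from the cells above; it assumes the Linear layer maps n_in to n_out as in fastai's LinBnDrop):

tst = LinBnFlatDrop(10, 20, p=0.1)  # hypothetical n_in=10, n_out=20
x = torch.randn(8, 10)
test_eq(tst(x).shape, (8, 20))  # the Linear(n_in, n_out) layer determines the output width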


source

LinBnDrop

 LinBnDrop (n_in, n_out=None, bn=True, ln=True, p=0.0, act=None,
            lin_first=False, ndim=1)

Module grouping BatchNorm1d, Dropout and Linear layers

LinBnDrop is just like fastai’s LinBnDrop with an extra argument ln, which provides the option of skipping the linear layer. That is, the BatchNorm layer is skipped if bn=False, the Linear layer if ln=False, and the dropout if p=0. Optionally, you can add an activation after the linear layer with act.

tst = LinBnDrop(10, 20)
mods = list(tst.children())
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Linear)

The LinBnDrop layer does not add an activation (even if one is provided) when ln=False, and it raises an AssertionError when ln=False and lin_first=True:

tst = LinBnDrop(10, 20, ln=False, p=0.02, act=nn.ReLU(inplace=True))
mods = list(tst.children())
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Dropout)
test_fail(lambda : LinBnDrop(10, 20, ln=False, lin_first=True), contains='AssertionError')

Embeddings


source

Embedding

 Embedding (ni, nf, std=0.01, **kwargs)

Embedding layer with truncated normal initialization
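
A quick shape-and-scale check (a minimal sketch with hypothetical sizes, reusing torch and test_eq from the earlier cells):

tst = Embedding(1000, 64, std=0.01)  # hypothetical ni=1000, nf=64
test_eq(tst.weight.shape, (1000, 64))
assert tst.weight.std().item() < 0.05  # weights are drawn from a truncated normal with std close to 0.01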

Attention Layers for Extreme Multi-Label Classification


source

_linear_attention

 _linear_attention (sentc:torch.Tensor,
                    based_on:Union[torch.nn.modules.sparse.Embedding, fastai.torch_core.Module])

sentc (Tensor): Sentence, typically (bs, bptt, nh)
based_on (nn.Embedding | Module): xcube’s Embedding(n_lbs, nh) layer holding the label embeddings, or a full-fledged model

source

_planted_attention

 _planted_attention (sentc:torch.Tensor, brain:torch.Tensor)
sentc (Tensor): Sentence, typically (bs, bptt), containing the vocab idxs that go into the encoder
brain (Tensor): Label-specific attn wgts for each token in the vocab, typically of shape (vocab_sz, n_lbs)

source

_diffntble_attention

 _diffntble_attention (inp:torch.Tensor,
                       based_on:torch.nn.modules.container.ModuleDict)
inp (Tensor): Sentence, typically (bs, bptt), containing the vocab idxs that go into the encoder
based_on (nn.ModuleDict): Dictionary of pretrained nn.Embedding layers from the l2r model

source

Linear_Attention

 Linear_Attention (based_on:fastai.torch_core.Module)
bs, bptt, nh, n_lbs = 16, 72, 100, 10
tst_lbs = Embedding(n_lbs, nh)
tst_Lin_Attn = Linear_Attention(tst_lbs)
attn_layer = Lambda(tst_Lin_Attn)
sentc = torch.randn(bs, bptt, nh)
test_eq(tst_Lin_Attn(sentc).shape, (bs, bptt, n_lbs))
test_eqs(attn_layer(sentc), tst_Lin_Attn(sentc), sentc @ tst_lbs.weight.transpose(0,1))

attn_layer2 = pickle.loads(pickle.dumps(attn_layer))
test_eqs(attn_layer2(sentc), sentc @ tst_lbs.weight.transpose(0,1))

source

Planted_Attention

 Planted_Attention (brain:torch.Tensor)
bs, bptt, vocab_sz, n_lbs = 16, 72, 100, 10
inp = torch.zeros((bs, bptt)).random_(vocab_sz)
brain = torch.randn(vocab_sz, n_lbs)
tst_planted_Attn = Planted_Attention(brain)
attn_layer = Lambda(tst_planted_Attn)
attn = brain[inp.long()]
test_eq(attn.shape, (bs, bptt, n_lbs))
test_eqs(attn, tst_planted_Attn(inp), attn_layer(inp))
# test_eq(brain[sentc[8].long()][:, 4], attn[8, :, 4]) # looking at the attn wgts of the 8th sentence and 4th label

source

Diff_Planted_Attention

 Diff_Planted_Attention (based_on:fastai.torch_core.Module)

source

lincomb

 lincomb (t, wgts=None)

Returns the linear combination of the rows (dim 1) of the 3d tensor t based on wgts (if wgts is None, it simply sums the rows)

t = torch.randn(16, 72, 100)
wgts = t.new_ones(t.size(0), 1, t.size(1))
test_eq(torch.bmm(wgts, t), lincomb(t))
rand_wgts = t.new_empty(t.size(0), 15, t.size(1)).random_(10)
# test_eq(lincomb(t, wgts=rand_wgts), torch.bmm(rand_wgts, t))
tst_LinComb = PartialLambda(lincomb, wgts=rand_wgts)
test_eq(tst_LinComb(t), torch.bmm(rand_wgts, t))

topkmax

 topkmax (k=None, dim=1)

Returns the softmax over dim 1 of the 3d tensor x after zeroing out values in x smaller than the kth largest. If k is None it behaves like x.softmax(dim=dim). Intuitively, topkmax hedges more compared to F.softmax.


source

split_sort

 split_sort (t, sp_dim, sort_dim, sp_sz=500, **kwargs)
t = torch.randn(16, 106, 819)
s_t = split_sort(t, sp_dim=1, sort_dim=-1, sp_sz=14)
test_eq(t.sort(dim=-1).values, s_t)

inattention

 inattention (k=None, sort_dim=0, sp_dim=0)

Returns self after zeroing out values smaller than the kth largest along sort_dim. If k is None, returns self unchanged.

TODO (DEB): Make it work for other dims. Hyperparam-schedule the k in topkmax (start high, gradually decrease).

x = torch.randn((2, 7, 3))
test_eq(x.topkmax() , F.softmax(x, dim=1))
# test_fail(topkmax, args=(x, ), kwargs=dict(dim=-1)) # NotImplemented
test_fail(x.topkmax, kwargs=dict(dim=-1)) # NotImplemented
test_eq(x.inattention(k=2, sort_dim=-1), 
        torch.where(x < x.sort(dim=-1, descending=True).values[:, :, 2].unsqueeze(dim=-1), 0, x))
x = torch.randn((8820,) )
x_inattn = torch.where(x < x.sort(dim=0, descending=True).values[2].unsqueeze(dim=0), 0, x)
x_inattn1 = x.inattention(k=2, sort_dim=0)
test_eq(x_inattn, x_inattn1)

source

XMLAttention

 XMLAttention (n_lbs, emb_sz, embed_p=0.0)

Compute label specific attention weights for each token in a sequence

# testing linear attention
inp = torch.zeros(bs, bptt).random_(100)
sentc = torch.randn(bs, bptt, nh)
mask = sentc.new_empty(sentc.size()[:-1]).random_(2).bool()
test_eq(mask.unique(), tensor([0., 1.]))
xml_attn = XMLAttention(n_lbs, nh)
attn, tok_wgts, lbs_cf = xml_attn(inp, sentc, mask)
test_eq(attn.shape, (bs, n_lbs, nh))
tst_lbs = xml_attn.lbs
tst_Lin_Attn = Linear_Attention(tst_lbs)
lin_attn_layer = Lambda(tst_Lin_Attn)
attn_wgts = F.softmax(lin_attn_layer(sentc), dim=1) # topkmax(attn_layer(sentc), dim=1)
test_eq(attn, torch.bmm(attn_wgts.masked_fill(mask[:, :, None], 0).transpose(1,2), sentc))

# testing planted attention followed by inattention
assert xml_attn.attn.func.f is _linear_attention
inp = torch.zeros((bs, bptt)).random_(vocab_sz)
brain = torch.randn(vocab_sz, n_lbs)
plant_attn_layer = Lambda(Planted_Attention(brain))
# xml_attn.attn = plant_attn_layer
setattr(xml_attn, 'attn', plant_attn_layer)
assert xml_attn.attn.func.f is _planted_attention
attn, tok_wgts, lbs_cf = xml_attn(inp, sentc, mask)
test_eqs(tok_wgts, 
         plant_attn_layer(inp).masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1), 
         brain[inp.long()].masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1)
        )
test_eq(attn, 
        lincomb(sentc, 
                wgts=brain[inp.long()].masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1).transpose(1,2)
               )
       )

Test that masking works:

for attn_layer in (lin_attn_layer, plant_attn_layer):
    setattr(xml_attn, 'attn', attn_layer)
    inp = torch.zeros(bs, bptt).random_(100)
    sentc = torch.randn(bs, bptt, nh)
    sentc = sentc.masked_fill(mask[:, :, None], 0)
    assert sentc[mask].sum().item() == 0
    attn, tok_wgts, lbs_cf = xml_attn(inp, sentc, mask)
    assert sentc[mask].sum().item() == 0
    if attn_layer is lin_attn_layer:
        attn_wgts = F.softmax(attn_layer(sentc), dim=1)  # topkmax(attn_layer(sentc), dim=1)
    else:
        attn_wgts = attn_layer(inp).masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1)
    test_eq(attn, torch.bmm(attn_wgts.transpose(1,2), sentc))