import torch
import torch.nn as nn
import pytorch_lightning as pl

class BCESkipLoss(nn.Module):
    """
    BCE loss that skips target timesteps whose entries are all zeros.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()

    def forward(self, out, tgt):
        # mask contains 1s at timesteps with at least one event in the target
        # and 0s at timesteps without any
        mask = tgt.max(2).values.unsqueeze(2).repeat(1, 1, tgt.shape[2]).detach()
        # zero out predictions at event-free timesteps; there BCE(0, 0) = 0,
        # so those positions add nothing to the loss sum (they still count
        # toward BCELoss's default mean reduction)
        loss = self.bce(out * mask, tgt)
        return loss
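
# A quick sketch of the masking, assuming the (seq_len, batch, events) layout
# used below in Session2Session (toy values for illustration only):
#
#   tgt = torch.tensor([[[0., 0.]],        # t=0: no events -> masked out
#                       [[1., 0.]]])       # t=1: one event  -> kept
#   tgt.max(2).values                      # -> [[0.], [1.]]
#   mask (after unsqueeze + repeat)        # -> [[[0., 0.]], [[1., 1.]]]
#
# With out * mask, predictions at t=0 are forced to 0 and incur no BCE penalty.
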
class Session2Session(pl.LightningModule):
    def __init__(self, n_events: int, max_len: int = 45):
        super().__init__()
        # layers
        d_model = n_events + 2  # + <sos> + <eos>
        nhead = 2
        nlayers = 2
        dim_feedforward = 128
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, nlayers)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, nlayers)
        # defined below; sequences are expected to be padded to max_len
        self.tgt_mask = generate_square_subsequent_mask(max_len, attend_same=False)
        # training settings
        self.criterion = BCESkipLoss()

    def forward(self, src, tgt):
        # src, tgt: (seq_len, batch, d_model), the default nn.Transformer layout
        memory = self.encoder(src)
        out = self.decoder(tgt, memory, self.tgt_mask.type_as(tgt))
        out = torch.sigmoid(out)
        return out

    def training_step(self, batch, batch_nb):
        # (batch, d_model, seq_len) -> (seq_len, batch, d_model)
        src, tgt = [x.permute(2, 0, 1) for x in batch]
        tgt_hat = self(src, tgt)
        loss = self.criterion(tgt_hat, tgt)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_nb):
        src, tgt = [x.permute(2, 0, 1) for x in batch]
        tgt_hat = self(src, tgt)
        loss = self.criterion(tgt_hat, tgt)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_nb):
        src, tgt = [x.permute(2, 0, 1) for x in batch]
        tgt_hat = self(src, tgt)
        loss = self.criterion(tgt_hat, tgt)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
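
# Typical Lightning usage (a sketch; `train_loader` is a hypothetical DataLoader
# yielding (src, tgt) batches shaped (batch, d_model, max_len)):
#
#   model = Session2Session(n_events=6)
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(model, train_loader)
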
def generate_square_subsequent_mask(sz, attend_same=True):
    if attend_same:
        # standard causal mask: position i attends to positions <= i
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    else:
        # the standard square mask lets each element attend to the element
        # with the same index, which is not suitable e.g. for attention
        # over the target, so we use a version shifted by 1
        mask = torch.triu(torch.ones(sz + 1, sz)).transpose(0, 1)[:, 1:]
        # a row of all -infs would produce NaNs in softmax,
        # so we hack the first row to attend to position 0
        mask[0, 0] = 1
    mask = mask.float().masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, float(0.0))
    return mask
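

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original training setup):
    # random binary event tensors in the (seq_len, batch, d_model) layout that
    # forward() expects, with sequences padded to max_len.
    n_events, max_len, batch_size = 6, 45, 4
    d_model = n_events + 2
    model = Session2Session(n_events=n_events, max_len=max_len)
    src = torch.randint(0, 2, (max_len, batch_size, d_model)).float()
    tgt = torch.randint(0, 2, (max_len, batch_size, d_model)).float()
    out = model(src, tgt)
    loss = model.criterion(out, tgt)
    print(out.shape, loss.item())  # torch.Size([45, 4, 8]) and a scalar loss

    # The shifted mask for sz=3 (attend_same=False) looks like:
    #   [[0., -inf, -inf],
    #    [0., -inf, -inf],
    #    [0.,   0., -inf]]
    # i.e. each position attends strictly to earlier positions, except the
    # first row, which is hacked to attend to position 0 to avoid NaNs.
    print(generate_square_subsequent_mask(3, attend_same=False))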