import torch
import torch.nn as nn
import pytorch_lightning as pl


class BCESkipLoss(nn.Module):
    """
    BCE loss that skips target entries with all zeros
    """
    def __init__(self):
        super(BCESkipLoss, self).__init__()
        self.bce = nn.BCELoss()

    def forward(self, out, tgt):
        # mask contains 1s on timestamps with any events in target
        # and 0s on timestamps without them
        mask = tgt.max(2).values.unsqueeze(2).repeat(1, 1, tgt.shape[2]).detach()
        # mask out timestamps without events in the output,
        # so those positions contribute zero loss
        loss = self.bce(out * mask, tgt)
        return loss


class Session2Session(pl.LightningModule):
    def __init__(self, n_events: int, max_len: int = 45):
        super(Session2Session, self).__init__()
        # layers
        d_model = n_events + 2  # + <sos> + <eos> tokens
        nhead = 2
        nlayers = 2
        dim_feedforward = 128
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=dim_feedforward)
        self.encoder = nn.TransformerEncoder(encoder_layer, nlayers)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=dim_feedforward)
        self.decoder = nn.TransformerDecoder(decoder_layer, nlayers)
        self.tgt_mask = generate_square_subsequent_mask(max_len, attend_same=False)
        # training settings
        self.criterion = BCESkipLoss()

    def forward(self, src, tgt):
        memory = self.encoder(src)
        out = self.decoder(tgt, memory, self.tgt_mask.type_as(tgt))
        out = torch.sigmoid(out)
        return out

    def training_step(self, batch, batch_nb):
        # permute from (batch, d_model, seq_len) to the (seq_len, batch, d_model)
        # layout expected by nn.Transformer modules
        src, tgt = [x.permute(2, 0, 1) for x in batch]
        tgt_hat = self(src, tgt)
        loss = self.criterion(tgt_hat, tgt)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_nb):
        src, tgt = [x.permute(2, 0, 1) for x in batch]
        tgt_hat = self(src, tgt)
        loss = self.criterion(tgt_hat, tgt)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_nb):
        src, tgt = [x.permute(2, 0, 1) for x in batch]
        tgt_hat = self(src, tgt)
        loss = self.criterion(tgt_hat, tgt)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def generate_square_subsequent_mask(sz, attend_same=True):
    if attend_same:
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    else:
        # the default square mask lets each position attend to the element
        # with the same index, which is not suitable for e.g. attention
        # to the target, so we use a version of the mask shifted by 1
        mask = torch.triu(torch.ones(sz + 1, sz)).transpose(0, 1)[:, 1:]
        # however we can't have a row of all -infs, so we hack the first row
        mask[0, 0] = 1
    mask = mask.float().masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, float(0.0))
    return mask
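
For reference, here is a minimal smoke-test sketch of how the module could be driven. It assumes sessions are encoded as multi-hot tensors of shape (batch, n_events + 2, max_len), the layout implied by the permute(2, 0, 1) calls in the step methods; the dataset size, batch size, and random data below are illustrative only and not part of the original setup.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

# illustrative sizes (assumptions, not from the original post)
n_events, max_len, batch_size = 10, 45, 8
d_model = n_events + 2

# random multi-hot "sessions" with values in {0, 1}, shape (N, d_model, max_len)
src = (torch.rand(64, d_model, max_len) > 0.8).float()
tgt = (torch.rand(64, d_model, max_len) > 0.8).float()
loader = DataLoader(TensorDataset(src, tgt), batch_size=batch_size)

model = Session2Session(n_events=n_events, max_len=max_len)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, loader)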