import torch
import torch.nn as nn
import pytorch_lightning as pl


class BCESkipLoss(nn.Module):
    """
    BCE loss that skips target entries that are all zeros.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()

    def forward(self, out, tgt):
        # mask holds 1s on timestamps with at least one event in the target
        # and 0s on timestamps without any
        mask = tgt.max(2).values.unsqueeze(2).repeat(1, 1, tgt.shape[2]).detach()
        # zero out predictions on event-free timestamps; those entries then
        # match their all-zero targets and add nothing to the loss
        loss = self.bce(out * mask, tgt)
        return loss
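
# Illustration (an assumption about the intended data layout, not from the
# original paste): with a multi-hot target of shape (seq_len, batch, n_events),
#
#     tgt = torch.tensor([[[1., 0.]], [[0., 0.]]])  # event at step 0 only
#
# tgt.max(2).values is [[1.], [0.]], so the mask keeps step 0 and zeroes
# the predictions at step 1, which therefore contribute no BCE penalty.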


class Session2Session(pl.LightningModule):

    def __init__(self, n_events: int, max_len: int = 45):
        super().__init__()
        # layers
        d_model = n_events + 2   # + <sos> + <eos>
        nhead = 2
        nlayers = 2
        dim_feedforward = 128
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
        self.encoder = nn.TransformerEncoder(encoder_layer, nlayers)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
        self.decoder = nn.TransformerDecoder(decoder_layer, nlayers)
        # shifted causal mask: a decoder position must not attend to the
        # target element at its own index (see generate_square_subsequent_mask)
        self.tgt_mask = generate_square_subsequent_mask(max_len, attend_same=False)
        # training settings
        self.criterion = BCESkipLoss()

    def forward(self, src, tgt):
        memory = self.encoder(src)
        # type_as casts the precomputed mask to the batch's dtype (and, on
        # CUDA, its device), since it is a plain attribute, not a buffer
        out = self.decoder(tgt, memory, self.tgt_mask.type_as(tgt))
        out = torch.sigmoid(out)
        return out

    def training_step(self, batch, batch_nb):
        # (batch, n_features, seq_len) -> (seq_len, batch, n_features)
        src, tgt = [x.permute(2, 0, 1) for x in batch]
        tgt_hat = self(src, tgt)
        loss = self.criterion(tgt_hat, tgt)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_nb):
        src, tgt = [x.permute(2, 0, 1) for x in batch]
        tgt_hat = self(src, tgt)
        loss = self.criterion(tgt_hat, tgt)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_nb):
        src, tgt = [x.permute(2, 0, 1) for x in batch]
        tgt_hat = self(src, tgt)
        loss = self.criterion(tgt_hat, tgt)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

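# Training sketch (assumed usage, not part of the original paste): expects a
# DataLoader yielding (src, tgt) pairs of shape (batch, n_features, seq_len);
# the steps above permute them to the seq-first layout nn.Transformer expects.
#
#     model = Session2Session(n_events=50)
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, train_loader, val_loader)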


def generate_square_subsequent_mask(sz, attend_same=True):
    if attend_same:
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    else:
        # the default square mask lets each position attend to the element
        # with the same index, which is unsuitable e.g. for attention over
        # the target, so we use a version of the mask shifted by 1
        mask = torch.triu(torch.ones(sz + 1, sz)).transpose(0, 1)[:, 1:]
        # however, a row of all -inf would make softmax return NaN,
        # so we hack the first row
        mask[0, 0] = 1
    mask = mask.float().masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, float(0.0))
    return mask
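

# Smoke test (a sketch under assumed shapes, not part of the original paste):
# builds a random multi-hot batch, runs one forward pass, and prints the
# shifted mask so the off-by-one structure is visible.
if __name__ == "__main__":
    n_events, seq_len, batch_size = 6, 45, 2
    d_model = n_events + 2  # matches the model's internal d_model
    model = Session2Session(n_events=n_events, max_len=seq_len)
    # batches are assumed to arrive as (batch, n_features, seq_len)
    src = torch.randint(0, 2, (batch_size, d_model, seq_len)).float()
    tgt = torch.randint(0, 2, (batch_size, d_model, seq_len)).float()
    with torch.no_grad():
        out = model(src.permute(2, 0, 1), tgt.permute(2, 0, 1))
    print(out.shape)  # torch.Size([45, 2, 8])
    print(generate_square_subsequent_mask(4, attend_same=False))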