# Gist by @sflender, last active April 20, 2024
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleSelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SimpleSelfAttention, self).__init__()
        self.embed_size = embed_size
        # Linear projections (no bias) for values, keys, and queries, plus a final output projection
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask=None):
        # Apply the linear transformations
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)
        # Compute the scaled dot products between queries and keys for each example in the batch
        energy = torch.matmul(queries, keys.transpose(-2, -1)) / (self.embed_size ** 0.5)
        # Optionally apply a mask to ignore certain positions (useful for padded sequences or future blinding)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        # Attention weights (softmax normalization of the energies)
        attention = torch.softmax(energy, dim=-1)
        # Weighted sum of the values using the attention weights
        out = torch.matmul(attention, values)
        # Pass through the final linear transformation
        out = self.fc_out(out)
        return out
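

# Minimal usage sketch: the batch size, sequence length, and embed_size below
# are arbitrary illustrative values, not part of the module itself.
if __name__ == "__main__":
    embed_size = 8
    x = torch.randn(2, 5, embed_size)  # (batch, seq_len, embed_size)
    attn = SimpleSelfAttention(embed_size)

    # For self-attention, the same tensor is passed as values, keys, and query
    out = attn(x, x, x)
    print(out.shape)  # torch.Size([2, 5, 8])

    # Optional causal mask: position i may only attend to positions <= i
    causal_mask = torch.tril(torch.ones(5, 5))
    out_masked = attn(x, x, x, mask=causal_mask)
    print(out_masked.shape)  # torch.Size([2, 5, 8])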