import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleSelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SimpleSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask=None):
        # Apply the linear transformations
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)

        # Compute the scaled dot products between queries and keys for each example in the batch
        energy = torch.matmul(queries, keys.transpose(-2, -1)) / (self.embed_size ** 0.5)

        # Optionally apply a mask to ignore certain positions (useful for padded sequences or future blinding)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Attention weights (softmax normalization of the energies)
        attention = torch.softmax(energy, dim=-1)

        # Multiply the attention weights by the values
        out = torch.matmul(attention, values)

        # Pass through a final linear transformation
        out = self.fc_out(out)
        return out