Skip to content

Instantly share code, notes, and snippets.

@Algomancer
Created November 21, 2024 01:06
Show Gist options
  • Select an option

  • Save Algomancer/5a46eb8f0d9d09ce2ef1b1022ee36d93 to your computer and use it in GitHub Desktop.

Select an option

Save Algomancer/5a46eb8f0d9d09ce2ef1b1022ee36d93 to your computer and use it in GitHub Desktop.
class Smear(nn.Module):
    """Blend every key vector with the key at the preceding sequence position.

    For each position ``t >= 1`` the output is
    ``w * k[t] + (1 - w) * k[t - 1]`` with ``w = sigmoid(alpha)``, where
    ``alpha`` is a learnable logit held per head and per adjacent-position
    pair. Position 0 has no predecessor and is passed through unchanged.

    Args:
        n_heads: Number of attention heads (size of dim 1 of the input).
        seq_len: Maximum sequence length the module will be called with.
    """

    def __init__(self, n_heads, seq_len):
        super().__init__()
        # One logit per head and per (t - 1, t) pair, hence seq_len - 1
        # entries. Initialised to 1: sigmoid(1) ~= 0.73, so each output starts
        # as ~73% of the key at its own position and ~27% of the previous key.
        self.alpha_values = torch.nn.Parameter(torch.ones(1, n_heads, seq_len - 1, 1))

    def forward(self, k: torch.Tensor) -> torch.Tensor:
        """Return the smeared keys, same shape as ``k``.

        Args:
            k: Tensor indexed as (batch_size, n_heads, seq_len, d_k). Its
                sequence length may be shorter than the configured maximum.

        Returns:
            Tensor of the same shape as ``k``.

        Raises:
            ValueError: If ``k`` is longer than the configured ``seq_len``.
        """
        steps = k.shape[2]
        max_steps = self.alpha_values.shape[2] + 1
        if steps > max_steps:
            raise ValueError(
                f"sequence length {steps} exceeds the configured maximum {max_steps}"
            )

        # Number of adjacent pairs; clamp so empty and length-1 inputs slice
        # to an empty weight tensor instead of wrapping to a negative index.
        n_pairs = max(steps - 1, 0)

        # Use the SAME weights for both terms so they are complementary:
        # alpha_values[..., t - 1, :] governs output position t.
        mix = torch.sigmoid(self.alpha_values[:, :, :n_pairs, :])

        current = k[:, :, 1:, :]
        previous = k[:, :, :-1, :]
        smeared_k = current * mix + previous * (1 - mix)

        # Re-attach the untouched first position.
        return torch.cat([k[:, :, 0:1, :], smeared_k], dim=2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment