Self-attention-pooling module for speaker classification
import torch
from torch import nn


class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """
    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        # A single learned linear projection scores each time step.
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: hidden dimension

        attention_weight:
            att_w : size (N, T, 1)

        return:
            utter_rep: size (N, H)
        """
        # Normalize the per-frame scores over the time axis (dim=1) to obtain attention weights.
        att_w = nn.functional.softmax(self.W(batch_rep).squeeze(-1), dim=1).unsqueeze(-1)
        # Attention-weighted sum over time yields a single utterance-level representation.
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        return utter_rep
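
A minimal usage sketch: given a batch of frame-level features of shape (N, T, H), the module collapses the time axis into one utterance-level vector of shape (N, H). The batch size, sequence length, and hidden dimension below are arbitrary illustrative values.

if __name__ == "__main__":
    # Illustrative shapes only; any (N, T, H) float tensor works.
    batch_size, seq_len, hidden_dim = 4, 100, 256
    frames = torch.randn(batch_size, seq_len, hidden_dim)  # (N, T, H) frame-level features
    pooling = SelfAttentionPooling(input_dim=hidden_dim)
    utterance = pooling(frames)                             # (N, H) pooled representation
    print(utterance.shape)                                  # torch.Size([4, 256])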