@pohanchi (last active May 8, 2025 23:25)
Self-attention-pooling module for speaker classification
import torch
from torch import nn


class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """
    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        # Learnable scoring layer: maps each frame's H-dim representation to a scalar logit
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: hidden dimension

        attention_weight:
            att_w : size (N, T, 1)

        return:
            utter_rep: size (N, H)
        """
        # Score each frame, then normalize the scores over the time axis (dim=1).
        # Passing dim explicitly avoids PyTorch's implicit-dim softmax deprecation warning.
        att_w = nn.functional.softmax(self.W(batch_rep).squeeze(-1), dim=1).unsqueeze(-1)
        # Attention-weighted sum over time collapses (N, T, H) to one utterance-level (N, H) vector
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        return utter_rep
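A minimal usage sketch, not part of the original gist: it only exercises the module above, and the batch size, frame count, and feature dimension are arbitrary illustration values.

# Usage sketch (illustration only; shapes are arbitrary).
# A batch of 4 utterances, each with 100 frames of 256-dim encoder features.
frames = torch.randn(4, 100, 256)                 # (N, T, H)
pooling = SelfAttentionPooling(input_dim=256)
utterance_embedding = pooling(frames)             # (N, H)
print(utterance_embedding.shape)                  # torch.Size([4, 256])

The (N, H) output can then be fed to a speaker-classification head, e.g. a linear layer over the number of speakers.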