Skip to content

Instantly share code, notes, and snippets.

@shark8me
Created September 16, 2022 13:45
Show Gist options
  • Select an option

  • Save shark8me/dc13cd39e7d16df4bb89c529cd5b4e52 to your computer and use it in GitHub Desktop.

Select an option

Save shark8me/dc13cd39e7d16df4bb89c529cd5b4e52 to your computer and use it in GitHub Desktop.
Default convolution stack for AMT
class ConvStack2d(nn.Module):
    """Convolution stack that reduces a (batch, frames, features) input to one vector per item.

    Pipeline (shapes follow the code below):

    1. ``mel`` of shape ``(batch, frames, input_features)`` is viewed as a
       single-channel image ``(batch, 1, frames, input_features)``.
    2. Three 3x3 convolutions (each followed by batch norm and ReLU) and two
       ``MaxPool2d((1, 2))`` layers produce
       ``(batch, output_features // 8, frames, input_features // 4)``; the
       pooling only shrinks the feature axis, never the frame axis.
    3. Channels and features are flattened per frame and ``fc`` projects every
       frame down to a single scalar, giving ``(batch, frames)``.
    4. ``fc2`` maps the ``frames`` scalars to ``output_features`` values, so
       the result is ``(batch, output_features)``.

    Args:
        input_features: Size of the last axis of ``mel``. Only
            ``input_features // 4`` columns survive the two poolings.
        output_features: Width of the returned vector. Channel counts are
            derived as ``output_features // 16`` and ``output_features // 8``,
            so values below 16 yield zero-channel layers.
        num_steps: Number of frames (``mel.size(1)``) the model will be fed;
            it is the input width of ``fc2``. When omitted, a module-level
            ``num_steps`` global is used, which is how earlier versions of
            this class obtained the value.

    Raises:
        ValueError: If ``num_steps`` is not passed and no module-level
            ``num_steps`` is defined.
    """

    def __init__(self, input_features, output_features, num_steps=None):
        super().__init__()

        if num_steps is None:
            # Legacy behaviour: the value used to be read implicitly from a
            # module-level global. Keep honouring it, but fail with an
            # actionable message instead of a bare NameError.
            try:
                num_steps = globals()["num_steps"]
            except KeyError:
                raise ValueError(
                    "num_steps must be passed to ConvStack2d (or defined as a "
                    "module-level variable); it is the number of frames in "
                    "the input."
                ) from None

        narrow_channels = output_features // 16
        wide_channels = output_features // 8

        # input is batch_size * 1 channel * frames * input_features
        self.cnn = nn.Sequential(
            # layer 0
            nn.Conv2d(1, narrow_channels, (3, 3), padding=1),
            nn.BatchNorm2d(narrow_channels),
            nn.ReLU(),
            # layer 1
            nn.Conv2d(narrow_channels, narrow_channels, (3, 3), padding=1),
            nn.BatchNorm2d(narrow_channels),
            nn.ReLU(),
            # layer 2
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.25),
            nn.Conv2d(narrow_channels, wide_channels, (3, 3), padding=1),
            nn.BatchNorm2d(wide_channels),
            nn.ReLU(),
            # layer 3
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.25),
        )
        # Per-frame projection: flattened (channels * pooled features) ->
        # output_features -> 1 scalar per frame.
        self.fc = nn.Sequential(
            nn.Linear(wide_channels * (input_features // 4), output_features),
            nn.Linear(output_features, 1),
        )
        # Across-frame projection: one scalar per frame -> output_features.
        self.fc2 = nn.Sequential(
            nn.Linear(num_steps, output_features),
        )

    def forward(self, mel):
        """Run the stack on ``mel`` and return a ``(batch, output_features)`` tensor.

        Args:
            mel: Tensor of shape ``(batch, frames, input_features)``.
                ``frames`` must equal the ``num_steps`` the module was built
                with, because ``fc2`` consumes the frame axis.

        Returns:
            Tensor of shape ``(batch, output_features)``.
        """
        # Insert the single input channel expected by Conv2d.
        x = mel.view(mel.size(0), 1, mel.size(1), mel.size(2))
        x = self.cnn(x)
        # (batch, channels, frames, feats) -> (batch, frames, channels * feats)
        x = x.transpose(1, 2).flatten(-2)
        x = self.fc(x)
        # Drop only the trailing singleton produced by the last Linear. A bare
        # squeeze() would also remove a batch (or frame) axis of size 1 and
        # change the rank of the output.
        x = x.squeeze(-1)
        x = self.fc2(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment