Reference Implementation of Silero V4 VAD model
import torch
import torch.nn as nn
import torch.nn.functional as F


# Magnitude STFT implemented as a strided Conv1d; the DFT basis is copied from the
# original model into forward_basis_buffer when the state dict is loaded below.
class STFT(nn.Module):
    def __init__(self, filter_length, hop_length):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.padding = nn.ReflectionPad1d((filter_length - hop_length) // 2)
        self.register_buffer(
            "forward_basis_buffer", torch.zeros([filter_length + 2, 1, filter_length])
        )

    def forward(self, input_data):
        input_data = self.padding(input_data).unsqueeze(1)
        forward_transform = F.conv1d(
            input_data, self.forward_basis_buffer, stride=self.hop_length
        )
        cutoff = self.filter_length // 2 + 1
        real_part = forward_transform[:, :cutoff]
        imag_part = forward_transform[:, cutoff:]
        magnitude = torch.sqrt(real_part.pow(2) + imag_part.pow(2))
        return magnitude
# Log-compresses the spectrogram and subtracts a Gaussian-smoothed mean; no learned parameters.
class AdaptiveAudioNormalization(nn.Module):
    def __init__(self):
        super(AdaptiveAudioNormalization, self).__init__()
        self.padding = nn.ReflectionPad1d(3)
        self.filter = self.gaussian_kernel(7, 1.5).reshape(1, 1, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        log_spec = torch.log1p(x * 1048576)
        if log_spec.ndim == 2:
            log_spec = log_spec.unsqueeze(0)
        mean_log_spec = log_spec.mean(dim=1, keepdim=True)
        mean_log_spec = self.padding(mean_log_spec)
        mean_log_spec = torch.conv1d(mean_log_spec, self.filter)
        mean_log_spec = mean_log_spec.mean(dim=-1, keepdim=True)
        normalized = log_spec - mean_log_spec
        return normalized

    @staticmethod
    def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
        kernel_range = torch.arange(-(size // 2), (size // 2) + 1)
        kernel = torch.exp(-0.5 * (kernel_range / sigma) ** 2)
        kernel = kernel / kernel.sum()
        return kernel
# Depthwise-separable convolution with a residual projection, followed by a strided
# 1x1 convolution, batch norm, and ReLU.
class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int = 129,
        out_channels: int = 16,
        stride: int = 2,
        has_out_proj: bool = True,
    ):
        super(ConvBlock, self).__init__()
        self.relu = nn.ReLU()
        self.dw_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=5,
            padding=2,
            groups=in_channels,
        )
        self.pw_conv = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        if has_out_proj:
            self.proj = nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            )
        else:
            self.proj = nn.Identity()
        self.dropout = nn.Dropout(p=0.15)
        self.conv = nn.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
        )
        self.batch_norm = nn.BatchNorm1d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dw_conv(x)
        x = self.relu(x)
        x = self.pw_conv(x)
        x += self.proj(residual)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x
class Encoder(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(Encoder, self).__init__()
        assert sampling_rate in [
            8000,
            16000,
        ], "Supported sampling rates are 8000 and 16000"
        self.feature_extractor = STFT(256, 64)
        self.adaptive_normalization = AdaptiveAudioNormalization()
        self.conv_block_0 = ConvBlock(
            in_channels=258, out_channels=16, stride=2, has_out_proj=True
        )
        self.conv_block_1 = ConvBlock(
            in_channels=16, out_channels=32, stride=2, has_out_proj=True
        )
        self.conv_block_2 = ConvBlock(
            in_channels=32,
            out_channels=32,
            stride=2 if sampling_rate == 16000 else 1,
            has_out_proj=False,
        )
        self.conv_block_3 = ConvBlock(
            in_channels=32, out_channels=64, stride=1, has_out_proj=True
        )

    def forward(self, x):
        x_feature = self.feature_extractor(x)
        x_norm = self.adaptive_normalization(x_feature)
        x = torch.cat([x_feature, x_norm], 1)
        x = self.conv_block_0(x)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        return x
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn = torch.nn.LSTM(
            input_size=64, hidden_size=64, num_layers=2, batch_first=True, dropout=0.1
        )
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
        self.conv1d = torch.nn.Conv1d(in_channels=64, out_channels=1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, h, c):
        input, (h, c) = self.rnn(input.permute([0, 2, 1]), (h, c))
        input = self.dropout(input.permute([0, 2, 1]))
        input = self.relu(input)
        input = self.conv1d(input)
        input = self.sigmoid(input)
        return input, h, c
class VadModel(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(VadModel, self).__init__()
        self.encoder = Encoder(sampling_rate)
        self.decoder = Decoder()

    @staticmethod
    def get_initial_states(batch_size: int, device: torch.device):
        h = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        c = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        return h, c

    def single_step_forward(self, x, h, c):
        x = self.encoder(x)
        out, h, c = self.decoder(x, h, c)
        out = torch.mean(out, -1)
        return out, h, c

    def forward(self, x, num_samples=512):
        assert (
            x.ndim == 2
        ), "Input should be a 2D tensor with size (batch_size, num_samples)"
        num_audio = x.size(0)
        h, c = self.get_initial_states(num_audio, device=x.device)
        x = x.reshape(-1, num_samples)
        x = self.encoder(x)
        x = x.reshape(num_audio, -1, 64, x.size(-1))
        decoder_outputs = []
        for window in x.unbind(1):
            out, h, c = self.decoder(window, h, c)
            decoder_outputs.append(out)
        out = torch.stack(decoder_outputs, dim=1).squeeze(-1)
        return out.mean(dim=-1)
# Maps parameter names in this reimplementation to the corresponding keys in the
# JIT model's state_dict.
model_mapping = {
    "encoder.feature_extractor.forward_basis_buffer": "_model.feature_extractor.forward_basis_buffer",
    "encoder.conv_block_0.dw_conv.weight": "_model.first_layer.0.dw_conv.0.weight",
    "encoder.conv_block_0.dw_conv.bias": "_model.first_layer.0.dw_conv.0.bias",
    "encoder.conv_block_0.pw_conv.weight": "_model.first_layer.0.pw_conv.0.weight",
    "encoder.conv_block_0.pw_conv.bias": "_model.first_layer.0.pw_conv.0.bias",
    "encoder.conv_block_0.proj.weight": "_model.first_layer.0.proj.weight",
    "encoder.conv_block_0.proj.bias": "_model.first_layer.0.proj.bias",
    "encoder.conv_block_0.conv.weight": "_model.encoder.0.weight",
    "encoder.conv_block_0.conv.bias": "_model.encoder.0.bias",
    "encoder.conv_block_0.batch_norm.weight": "_model.encoder.1.weight",
    "encoder.conv_block_0.batch_norm.bias": "_model.encoder.1.bias",
    "encoder.conv_block_0.batch_norm.running_mean": "_model.encoder.1.running_mean",
    "encoder.conv_block_0.batch_norm.running_var": "_model.encoder.1.running_var",
    "encoder.conv_block_0.batch_norm.num_batches_tracked": "_model.encoder.1.num_batches_tracked",
    "encoder.conv_block_1.dw_conv.weight": "_model.encoder.3.0.dw_conv.0.weight",
    "encoder.conv_block_1.dw_conv.bias": "_model.encoder.3.0.dw_conv.0.bias",
    "encoder.conv_block_1.pw_conv.weight": "_model.encoder.3.0.pw_conv.0.weight",
    "encoder.conv_block_1.pw_conv.bias": "_model.encoder.3.0.pw_conv.0.bias",
    "encoder.conv_block_1.proj.weight": "_model.encoder.3.0.proj.weight",
    "encoder.conv_block_1.proj.bias": "_model.encoder.3.0.proj.bias",
    "encoder.conv_block_1.conv.weight": "_model.encoder.4.weight",
    "encoder.conv_block_1.conv.bias": "_model.encoder.4.bias",
    "encoder.conv_block_1.batch_norm.weight": "_model.encoder.5.weight",
    "encoder.conv_block_1.batch_norm.bias": "_model.encoder.5.bias",
    "encoder.conv_block_1.batch_norm.running_mean": "_model.encoder.5.running_mean",
    "encoder.conv_block_1.batch_norm.running_var": "_model.encoder.5.running_var",
    "encoder.conv_block_1.batch_norm.num_batches_tracked": "_model.encoder.5.num_batches_tracked",
    "encoder.conv_block_2.dw_conv.weight": "_model.encoder.7.0.dw_conv.0.weight",
    "encoder.conv_block_2.dw_conv.bias": "_model.encoder.7.0.dw_conv.0.bias",
    "encoder.conv_block_2.pw_conv.weight": "_model.encoder.7.0.pw_conv.0.weight",
    "encoder.conv_block_2.pw_conv.bias": "_model.encoder.7.0.pw_conv.0.bias",
    "encoder.conv_block_2.conv.weight": "_model.encoder.8.weight",
    "encoder.conv_block_2.conv.bias": "_model.encoder.8.bias",
    "encoder.conv_block_2.batch_norm.weight": "_model.encoder.9.weight",
    "encoder.conv_block_2.batch_norm.bias": "_model.encoder.9.bias",
    "encoder.conv_block_2.batch_norm.running_mean": "_model.encoder.9.running_mean",
    "encoder.conv_block_2.batch_norm.running_var": "_model.encoder.9.running_var",
    "encoder.conv_block_2.batch_norm.num_batches_tracked": "_model.encoder.9.num_batches_tracked",
    "encoder.conv_block_3.dw_conv.weight": "_model.encoder.11.0.dw_conv.0.weight",
    "encoder.conv_block_3.dw_conv.bias": "_model.encoder.11.0.dw_conv.0.bias",
    "encoder.conv_block_3.pw_conv.weight": "_model.encoder.11.0.pw_conv.0.weight",
    "encoder.conv_block_3.pw_conv.bias": "_model.encoder.11.0.pw_conv.0.bias",
    "encoder.conv_block_3.proj.weight": "_model.encoder.11.0.proj.weight",
    "encoder.conv_block_3.proj.bias": "_model.encoder.11.0.proj.bias",
    "encoder.conv_block_3.conv.weight": "_model.encoder.12.weight",
    "encoder.conv_block_3.conv.bias": "_model.encoder.12.bias",
    "encoder.conv_block_3.batch_norm.weight": "_model.encoder.13.weight",
    "encoder.conv_block_3.batch_norm.bias": "_model.encoder.13.bias",
    "encoder.conv_block_3.batch_norm.running_mean": "_model.encoder.13.running_mean",
    "encoder.conv_block_3.batch_norm.running_var": "_model.encoder.13.running_var",
    "encoder.conv_block_3.batch_norm.num_batches_tracked": "_model.encoder.13.num_batches_tracked",
    "decoder.rnn.weight_ih_l0": "_model.decoder.rnn.weight_ih_l0",
    "decoder.rnn.weight_hh_l0": "_model.decoder.rnn.weight_hh_l0",
    "decoder.rnn.bias_ih_l0": "_model.decoder.rnn.bias_ih_l0",
    "decoder.rnn.bias_hh_l0": "_model.decoder.rnn.bias_hh_l0",
    "decoder.rnn.weight_ih_l1": "_model.decoder.rnn.weight_ih_l1",
    "decoder.rnn.weight_hh_l1": "_model.decoder.rnn.weight_hh_l1",
    "decoder.rnn.bias_ih_l1": "_model.decoder.rnn.bias_ih_l1",
    "decoder.rnn.bias_hh_l1": "_model.decoder.rnn.bias_hh_l1",
    "decoder.conv1d.weight": "_model.decoder.decoder.1.weight",
    "decoder.conv1d.bias": "_model.decoder.decoder.1.bias",
}
sampling_rate = 16000

new_model = VadModel(sampling_rate)
new_model = new_model.eval()

# The 8 kHz weights live under the _model_8k prefix in the original state_dict.
if sampling_rate == 8000:
    model_mapping = {
        key: value.replace("_model", "_model_8k")
        for key, value in model_mapping.items()
    }

# Load the original JIT model from torch.hub and copy its weights into the reimplementation.
model, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad:v4.0stable",
    model="silero_vad",
    force_reload=True,
)
model = model.eval()

new_model.load_state_dict(
    {key: model.state_dict()[model_mapping[key]] for key in model_mapping.keys()}
)
# Export the encoder and the stateful decoder as separate ONNX graphs.
torch.onnx.export(
    new_model.encoder,
    {"x": torch.randn(10, 512, dtype=torch.float32)},
    f=f"encoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input"],
    dynamic_axes={"input": {0: "batch", 1: "num_samples"}},
)

torch.onnx.export(
    new_model.decoder,
    dict(
        input=torch.randn(1, 64, 2, dtype=torch.float32),
        h=torch.rand((2, 1, 64), dtype=torch.float32),
        c=torch.rand((2, 1, 64), dtype=torch.float32),
    ),
    f=f"decoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input", "h", "c"],
    dynamic_axes={
        "input": {0: "batch", 2: "num_samples"},
        "h": {1: "batch"},
        "c": {1: "batch"},
    },
)
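Not part of the original gist: a quick sanity check of the exported graphs against the PyTorch modules could look roughly like the following, assuming onnxruntime is installed (file names and tensor shapes follow the export calls above).

# Sketch only: compare the exported ONNX graphs with the PyTorch modules.
import numpy as np
import onnxruntime

x = torch.randn(1, 512, dtype=torch.float32)
h, c = new_model.get_initial_states(1, device=torch.device("cpu"))

enc_sess = onnxruntime.InferenceSession(
    f"encoder_v4_{sampling_rate // 1000}khz.onnx", providers=["CPUExecutionProvider"]
)
dec_sess = onnxruntime.InferenceSession(
    f"decoder_v4_{sampling_rate // 1000}khz.onnx", providers=["CPUExecutionProvider"]
)

enc_onnx = enc_sess.run(None, {"input": x.numpy()})[0]
dec_onnx = dec_sess.run(None, {"input": enc_onnx, "h": h.numpy(), "c": c.numpy()})[0]

with torch.no_grad():
    enc_torch = new_model.encoder(x)
    dec_torch, _, _ = new_model.decoder(enc_torch, h, c)

print(np.abs(enc_onnx - enc_torch.numpy()).max())
print(np.abs(dec_onnx - dec_torch.numpy()).max())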
Author
I'm really not sure. All I see is that the inputs are the chunks and the states are updated through h and c:
https://github.com/snakers4/silero-vad/blob/915dd3d639b8333a52e001af095f87c5b7f1e0ac/utils_vad.py#L86
(I'm not talking about the internal concatenation of x_feature and its normalized version, by the way.)
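For illustration only (this is not from the thread), that pattern, fixed-size chunks going in with h and c carried across calls, could be written like this with the classes from the gist above:

# Illustrative sketch of chunk-by-chunk inference with carried LSTM states,
# reusing new_model and single_step_forward from the gist.
audio = torch.randn(16000)                    # 1 s of placeholder 16 kHz audio
h, c = new_model.get_initial_states(batch_size=1, device=audio.device)
probs = []
with torch.no_grad():
    for start in range(0, audio.numel() - 512 + 1, 512):
        chunk = audio[start : start + 512].unsqueeze(0)    # (1, 512)
        out, h, c = new_model.single_step_forward(chunk, h, c)
        probs.append(out.squeeze())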
I haven't checked the decoder part yet, but the VadModel class should be something like the following:
class VadModel(nn.Module):
    def __init__(self):
        super(VadModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def get_initial_states(self, batch_size: int, device: torch.device):
        h = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        c = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        return h, c

    def forward(self, x):
        bsz = x.size(0)
        h, c = self.get_initial_states(bsz, device=x.device)
        x = self.encoder(x).reshape(bsz, -1, 64, 2)
        decoder_outputs = []
        for window in x.unbind(1):
            out, h, c = self.decoder(window, h, c)
            decoder_outputs.append(out)
        out = torch.stack(decoder_outputs, dim=1).squeeze(-1)
        return out.mean(dim=-1)

and then the outputs of the encoder can be validated as follows:
# sample input
x = torch.randn(1, 512, dtype=torch.float32)
# validate encoder outputs
jit_out = model.forward(x)
x0 = model.feature_extractor(x)
norm = model.adaptive_normalization(x0)
x1 = torch.cat([x0, norm], 1)
x2 = model.first_layer(x1)
enc = model.encoder(x2)
vad_enc = new_model.encoder(x)
assert torch.allclose(enc, vad_enc, atol=1e-5)
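An end-to-end check could be added in the same spirit; the sketch below is not from the original comment and assumes the published JIT model's usual model(chunk, sampling_rate) call signature:

# Sketch: compare the reimplementation against the JIT reference end to end.
x = torch.randn(1, 512, dtype=torch.float32)
with torch.no_grad():
    ref = model(x, 16000)      # JIT reference speech probability
    ours = new_model(x)        # reimplementation output, shape (batch, num_windows)
print((ref.flatten() - ours.flatten()).abs().max())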
Author
I revised the implementation again and fixed your notes; it now also supports a variable num_samples and 8 kHz.
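A short usage sketch of those two options (the chunk sizes below are only illustrative, not values required by the model):

with torch.no_grad():
    # 16 kHz: 1536 samples per item split into three 512-sample windows.
    probs_16k = new_model(torch.randn(2, 1536), num_samples=512)

    # 8 kHz variant: weights would need to be loaded via the _model_8k mapping shown above.
    model_8k = VadModel(sampling_rate=8000).eval()
    probs_8k = model_8k(torch.randn(2, 768), num_samples=256)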
It does concatenate the context; you can refer to the ONNX example in the original repo and to the faster-whisper implementation before v5.
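For readers unfamiliar with that pattern, "concatenating the context" generally means prepending the tail of the previous chunk to the current one before inference. The sketch below is purely illustrative; the context size is a placeholder, not a value taken from either repository:

# Illustrative only: carry the tail of the previous chunk as context.
context_size = 64                  # hypothetical number of carried-over samples
chunk_size = 512
audio = torch.randn(16000)         # placeholder 16 kHz audio
context = torch.zeros(context_size)

h, c = new_model.get_initial_states(batch_size=1, device=audio.device)
probs = []
with torch.no_grad():
    for start in range(0, audio.numel() - chunk_size + 1, chunk_size):
        chunk = audio[start : start + chunk_size]
        window = torch.cat([context, chunk]).unsqueeze(0)        # (1, context + chunk)
        out, h, c = new_model.single_step_forward(window, h, c)
        probs.append(out.squeeze())
        context = chunk[-context_size:]                          # carry the tail forward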