Reference Implementation of Silero V4 VAD model

import torch
import torch.nn as nn
import torch.nn.functional as F


class STFT(nn.Module):
    def __init__(self, filter_length, hop_length):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.padding = nn.ReflectionPad1d((filter_length - hop_length) // 2)
        self.register_buffer(
            "forward_basis_buffer", torch.zeros([filter_length + 2, 1, filter_length])
        )

    def forward(self, input_data):
        input_data = self.padding(input_data).unsqueeze(1)
        forward_transform = F.conv1d(
            input_data, self.forward_basis_buffer, stride=self.hop_length
        )
        cutoff = self.filter_length // 2 + 1
        real_part = forward_transform[:, :cutoff]
        imag_part = forward_transform[:, cutoff:]
        magnitude = torch.sqrt(real_part.pow(2) + imag_part.pow(2))
        return magnitude
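
# The forward basis is overwritten with the checkpoint's buffer further below.
# For a standalone STFT one would expect it to hold the stacked real and
# imaginary DFT rows; a sketch of that construction (an assumption, not taken
# from the checkpoint, and unused below):
def make_fourier_basis(filter_length: int = 256) -> torch.Tensor:
    # Shape (filter_length + 2, 1, filter_length): real rows, then imaginary rows.
    fourier = torch.fft.fft(torch.eye(filter_length))
    cutoff = filter_length // 2 + 1
    return torch.cat([fourier.real[:cutoff], fourier.imag[:cutoff]]).unsqueeze(1)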


class AdaptiveAudioNormalization(nn.Module):
    def __init__(self):
        super(AdaptiveAudioNormalization, self).__init__()
        self.padding = nn.ReflectionPad1d(3)
        # Non-persistent buffer: it must follow .to(device), but it is not part
        # of the original checkpoint, so it must stay out of the state dict.
        self.register_buffer(
            "filter", self.gaussian_kernel(7, 1.5).reshape(1, 1, -1), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        log_spec = torch.log1p(x * 1048576)
        if log_spec.ndim == 2:
            log_spec = log_spec.unsqueeze(0)
        mean_log_spec = log_spec.mean(dim=1, keepdim=True)
        mean_log_spec = self.padding(mean_log_spec)
        mean_log_spec = torch.conv1d(mean_log_spec, self.filter)
        mean_log_spec = mean_log_spec.mean(dim=-1, keepdim=True)
        normalized = log_spec - mean_log_spec
        return normalized

    @staticmethod
    def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
        kernel_range = torch.arange(-(size // 2), (size // 2) + 1)
        kernel = torch.exp(-0.5 * (kernel_range / sigma) ** 2)
        kernel = kernel / kernel.sum()
        return kernel
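
# The normalization log-compresses the magnitudes (log1p(x * 2**20)) and then
# subtracts a single per-utterance offset: the log spectrum averaged over
# frequency, smoothed along time with a 7-tap Gaussian (sigma = 1.5), and
# finally averaged over time.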


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int = 129,
        out_channels: int = 16,
        stride: int = 2,
        has_out_proj: bool = True,
    ):
        super(ConvBlock, self).__init__()
        self.relu = nn.ReLU()
        self.dw_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=5,
            padding=2,
            groups=in_channels,
        )
        self.pw_conv = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        if has_out_proj:
            self.proj = nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            )
        else:
            self.proj = nn.Identity()
        self.dropout = nn.Dropout(p=0.15)
        self.conv = nn.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
        )
        self.batch_norm = nn.BatchNorm1d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dw_conv(x)
        x = self.relu(x)
        x = self.pw_conv(x)
        x += self.proj(residual)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x
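
# ConvBlock is a depthwise-separable residual unit: a 5-tap depthwise conv and
# a 1x1 pointwise conv, with a 1x1 projection (or identity) on the skip path,
# followed by dropout and a strided 1x1 conv + BatchNorm for downsampling.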


class Encoder(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(Encoder, self).__init__()
        assert sampling_rate in [
            8000,
            16000,
        ], "Supported sampling rates are 8000 and 16000"
        self.feature_extractor = STFT(256, 64)
        self.adaptive_normalization = AdaptiveAudioNormalization()
        self.conv_block_0 = ConvBlock(
            in_channels=258, out_channels=16, stride=2, has_out_proj=True
        )
        self.conv_block_1 = ConvBlock(
            in_channels=16, out_channels=32, stride=2, has_out_proj=True
        )
        self.conv_block_2 = ConvBlock(
            in_channels=32,
            out_channels=32,
            stride=2 if sampling_rate == 16000 else 1,
            has_out_proj=False,
        )
        self.conv_block_3 = ConvBlock(
            in_channels=32, out_channels=64, stride=1, has_out_proj=True
        )

    def forward(self, x):
        # 129 magnitude bins plus their normalized copy -> 258 input channels.
        x_feature = self.feature_extractor(x)
        x_norm = self.adaptive_normalization(x_feature)
        x = torch.cat([x_feature, x_norm], 1)
        x = self.conv_block_0(x)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        return x
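
# At 16 kHz, a 512-sample window yields 8 STFT frames (hop 64, after reflection
# padding of 96 samples per side), and the three strided blocks reduce those to
# a single frame, so the encoder maps (batch, 512) -> (batch, 64, 1).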


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn = torch.nn.LSTM(
            input_size=64, hidden_size=64, num_layers=2, batch_first=True, dropout=0.1
        )
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
        self.conv1d = torch.nn.Conv1d(in_channels=64, out_channels=1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, h, c):
        # LSTM expects (batch, time, features); conv expects (batch, channels, time).
        input, (h, c) = self.rnn(input.permute([0, 2, 1]), (h, c))
        input = self.dropout(input.permute([0, 2, 1]))
        input = self.relu(input)
        input = self.conv1d(input)
        input = self.sigmoid(input)
        return input, h, c


class VadModel(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(VadModel, self).__init__()
        self.encoder = Encoder(sampling_rate)
        self.decoder = Decoder()

    @staticmethod
    def get_initial_states(batch_size: int, device: torch.device):
        h = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        c = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        return h, c

    def single_step_forward(self, x, h, c):
        x = self.encoder(x)
        out, h, c = self.decoder(x, h, c)
        out = torch.mean(out, -1)
        return out, h, c

    def forward(self, x, num_samples=512):
        assert (
            x.ndim == 2
        ), "Input should be a 2D tensor with size (batch_size, num_samples)"
        assert (
            x.size(1) % num_samples == 0
        ), "Input length should be a multiple of num_samples"
        num_audio = x.size(0)
        h, c = self.get_initial_states(num_audio, device=x.device)
        # Encode all windows in one batch, then run the stateful decoder window
        # by window so the LSTM state carries across windows.
        x = x.reshape(-1, num_samples)
        x = self.encoder(x)
        x = x.reshape(num_audio, -1, 64, x.size(-1))
        decoder_outputs = []
        for window in x.unbind(1):
            out, h, c = self.decoder(window, h, c)
            decoder_outputs.append(out)
        out = torch.stack(decoder_outputs, dim=1).squeeze(-1)
        return out.mean(dim=-1)
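
# A sketch (not part of the original gist) of streaming inference with the
# stateful single-step API; assumes 16 kHz mono audio of shape (batch, T).
def streaming_probs(vad: VadModel, audio: torch.Tensor, chunk: int = 512) -> torch.Tensor:
    """Return one speech probability per `chunk` samples of audio."""
    h, c = vad.get_initial_states(audio.size(0), device=audio.device)
    probs = []
    with torch.no_grad():
        for window in audio.split(chunk, dim=1):
            if window.size(1) < chunk:  # zero-pad a trailing partial chunk
                window = F.pad(window, (0, chunk - window.size(1)))
            p, h, c = vad.single_step_forward(window, h, c)
            probs.append(p)
    return torch.cat(probs, dim=-1)  # (batch, num_chunks)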


model_mapping = {
    "encoder.feature_extractor.forward_basis_buffer": "_model.feature_extractor.forward_basis_buffer",
    "encoder.conv_block_0.dw_conv.weight": "_model.first_layer.0.dw_conv.0.weight",
    "encoder.conv_block_0.dw_conv.bias": "_model.first_layer.0.dw_conv.0.bias",
    "encoder.conv_block_0.pw_conv.weight": "_model.first_layer.0.pw_conv.0.weight",
    "encoder.conv_block_0.pw_conv.bias": "_model.first_layer.0.pw_conv.0.bias",
    "encoder.conv_block_0.proj.weight": "_model.first_layer.0.proj.weight",
    "encoder.conv_block_0.proj.bias": "_model.first_layer.0.proj.bias",
    "encoder.conv_block_0.conv.weight": "_model.encoder.0.weight",
    "encoder.conv_block_0.conv.bias": "_model.encoder.0.bias",
    "encoder.conv_block_0.batch_norm.weight": "_model.encoder.1.weight",
    "encoder.conv_block_0.batch_norm.bias": "_model.encoder.1.bias",
    "encoder.conv_block_0.batch_norm.running_mean": "_model.encoder.1.running_mean",
    "encoder.conv_block_0.batch_norm.running_var": "_model.encoder.1.running_var",
    "encoder.conv_block_0.batch_norm.num_batches_tracked": "_model.encoder.1.num_batches_tracked",
    "encoder.conv_block_1.dw_conv.weight": "_model.encoder.3.0.dw_conv.0.weight",
    "encoder.conv_block_1.dw_conv.bias": "_model.encoder.3.0.dw_conv.0.bias",
    "encoder.conv_block_1.pw_conv.weight": "_model.encoder.3.0.pw_conv.0.weight",
    "encoder.conv_block_1.pw_conv.bias": "_model.encoder.3.0.pw_conv.0.bias",
    "encoder.conv_block_1.proj.weight": "_model.encoder.3.0.proj.weight",
    "encoder.conv_block_1.proj.bias": "_model.encoder.3.0.proj.bias",
    "encoder.conv_block_1.conv.weight": "_model.encoder.4.weight",
    "encoder.conv_block_1.conv.bias": "_model.encoder.4.bias",
    "encoder.conv_block_1.batch_norm.weight": "_model.encoder.5.weight",
    "encoder.conv_block_1.batch_norm.bias": "_model.encoder.5.bias",
    "encoder.conv_block_1.batch_norm.running_mean": "_model.encoder.5.running_mean",
    "encoder.conv_block_1.batch_norm.running_var": "_model.encoder.5.running_var",
    "encoder.conv_block_1.batch_norm.num_batches_tracked": "_model.encoder.5.num_batches_tracked",
    "encoder.conv_block_2.dw_conv.weight": "_model.encoder.7.0.dw_conv.0.weight",
    "encoder.conv_block_2.dw_conv.bias": "_model.encoder.7.0.dw_conv.0.bias",
    "encoder.conv_block_2.pw_conv.weight": "_model.encoder.7.0.pw_conv.0.weight",
    "encoder.conv_block_2.pw_conv.bias": "_model.encoder.7.0.pw_conv.0.bias",
    "encoder.conv_block_2.conv.weight": "_model.encoder.8.weight",
    "encoder.conv_block_2.conv.bias": "_model.encoder.8.bias",
    "encoder.conv_block_2.batch_norm.weight": "_model.encoder.9.weight",
    "encoder.conv_block_2.batch_norm.bias": "_model.encoder.9.bias",
    "encoder.conv_block_2.batch_norm.running_mean": "_model.encoder.9.running_mean",
    "encoder.conv_block_2.batch_norm.running_var": "_model.encoder.9.running_var",
    "encoder.conv_block_2.batch_norm.num_batches_tracked": "_model.encoder.9.num_batches_tracked",
    "encoder.conv_block_3.dw_conv.weight": "_model.encoder.11.0.dw_conv.0.weight",
    "encoder.conv_block_3.dw_conv.bias": "_model.encoder.11.0.dw_conv.0.bias",
    "encoder.conv_block_3.pw_conv.weight": "_model.encoder.11.0.pw_conv.0.weight",
    "encoder.conv_block_3.pw_conv.bias": "_model.encoder.11.0.pw_conv.0.bias",
    "encoder.conv_block_3.proj.weight": "_model.encoder.11.0.proj.weight",
    "encoder.conv_block_3.proj.bias": "_model.encoder.11.0.proj.bias",
    "encoder.conv_block_3.conv.weight": "_model.encoder.12.weight",
    "encoder.conv_block_3.conv.bias": "_model.encoder.12.bias",
    "encoder.conv_block_3.batch_norm.weight": "_model.encoder.13.weight",
    "encoder.conv_block_3.batch_norm.bias": "_model.encoder.13.bias",
    "encoder.conv_block_3.batch_norm.running_mean": "_model.encoder.13.running_mean",
    "encoder.conv_block_3.batch_norm.running_var": "_model.encoder.13.running_var",
    "encoder.conv_block_3.batch_norm.num_batches_tracked": "_model.encoder.13.num_batches_tracked",
    "decoder.rnn.weight_ih_l0": "_model.decoder.rnn.weight_ih_l0",
    "decoder.rnn.weight_hh_l0": "_model.decoder.rnn.weight_hh_l0",
    "decoder.rnn.bias_ih_l0": "_model.decoder.rnn.bias_ih_l0",
    "decoder.rnn.bias_hh_l0": "_model.decoder.rnn.bias_hh_l0",
    "decoder.rnn.weight_ih_l1": "_model.decoder.rnn.weight_ih_l1",
    "decoder.rnn.weight_hh_l1": "_model.decoder.rnn.weight_hh_l1",
    "decoder.rnn.bias_ih_l1": "_model.decoder.rnn.bias_ih_l1",
    "decoder.rnn.bias_hh_l1": "_model.decoder.rnn.bias_hh_l1",
    "decoder.conv1d.weight": "_model.decoder.decoder.1.weight",
    "decoder.conv1d.bias": "_model.decoder.decoder.1.bias",
}
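# The values are parameter names inside the original JIT checkpoint; for the
# 8 kHz model they live under "_model_8k" instead of "_model" (handled below).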

sampling_rate = 16000
new_model = VadModel(sampling_rate)
new_model = new_model.eval()

if sampling_rate == 8000:
    model_mapping = {
        key: value.replace("_model", "_model_8k")
        for key, value in model_mapping.items()
    }

model, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad:v4.0stable",
    model="silero_vad",
    force_reload=True,
)
model = model.eval()

original_state = model.state_dict()
new_model.load_state_dict(
    {key: original_state[value] for key, value in model_mapping.items()}
)

# Export with positional example inputs (an args tuple is the unambiguous form
# for torch.onnx.export).
torch.onnx.export(
    new_model.encoder,
    (torch.randn(10, 512, dtype=torch.float32),),
    f=f"encoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input"],
    dynamic_axes={"input": {0: "batch", 1: "num_samples"}},
)

torch.onnx.export(
    new_model.decoder,
    (
        torch.randn(1, 64, 2, dtype=torch.float32),
        torch.rand((2, 1, 64), dtype=torch.float32),
        torch.rand((2, 1, 64), dtype=torch.float32),
    ),
    f=f"decoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input", "h", "c"],
    dynamic_axes={
        "input": {0: "batch", 2: "num_samples"},
        "h": {1: "batch"},
        "c": {1: "batch"},
    },
)
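
# A quick parity check of the exported graphs against the PyTorch modules
# (a sketch; assumes numpy and onnxruntime are installed).
import numpy as np
import onnxruntime as ort

enc_sess = ort.InferenceSession(
    f"encoder_v4_{sampling_rate // 1000}khz.onnx", providers=["CPUExecutionProvider"]
)
dec_sess = ort.InferenceSession(
    f"decoder_v4_{sampling_rate // 1000}khz.onnx", providers=["CPUExecutionProvider"]
)

x = torch.randn(10, 512, dtype=torch.float32)
with torch.no_grad():
    enc_torch = new_model.encoder(x)
enc_onnx = enc_sess.run(None, {"input": x.numpy()})[0]
print("encoder max abs diff:", np.abs(enc_onnx - enc_torch.numpy()).max())

inp = torch.randn(1, 64, 2, dtype=torch.float32)
h0, c0 = VadModel.get_initial_states(1, torch.device("cpu"))
with torch.no_grad():
    dec_torch, _, _ = new_model.decoder(inp, h0, c0)
dec_onnx = dec_sess.run(None, {"input": inp.numpy(), "h": h0.numpy(), "c": c0.numpy()})[0]
print("decoder max abs diff:", np.abs(dec_onnx - dec_torch.numpy()).max())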
Author
I revised the implementation again and fixed your notes; it now also supports variable num_samples and 8 kHz.
I didn't check the decoder part yet, but the VadModel class should be something like the one above, and the outputs can then be validated against the original model as follows:
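A minimal sketch of such an end-to-end check (assumptions: the original JIT model accepts (chunk, sampling_rate) calls, returns one probability per chunk, and exposes reset_states()):

torch.manual_seed(0)
audio = torch.randn(1, 512 * 8)
with torch.no_grad():
    ours = new_model(audio)  # one probability per 512-sample window
    model.reset_states()
    ref = torch.cat(
        [model(chunk, sampling_rate) for chunk in audio.split(512, dim=1)], dim=-1
    )
print("max abs diff vs original:", (ours - ref).abs().max().item())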