import torch
import torch.nn as nn
import torch.nn.functional as F


class STFT(nn.Module):
    # Conv1d-based STFT-like magnitude extractor; the forward basis is copied
    # from the original Silero checkpoint via model_mapping below.
    def __init__(self, filter_length, hop_length):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.padding = nn.ReflectionPad1d((filter_length - hop_length) // 2)
        self.register_buffer(
            "forward_basis_buffer", torch.zeros([filter_length + 2, 1, filter_length])
        )

    def forward(self, input_data):
        input_data = self.padding(input_data).unsqueeze(1)
        forward_transform = F.conv1d(
            input_data, self.forward_basis_buffer, stride=self.hop_length
        )
        cutoff = self.filter_length // 2 + 1
        real_part = forward_transform[:, :cutoff]
        imag_part = forward_transform[:, cutoff:]
        magnitude = torch.sqrt(real_part.pow(2) + imag_part.pow(2))
        return magnitude


class AdaptiveAudioNormalization(nn.Module):
    # Subtracts a Gaussian-smoothed mean from the log spectrogram.
    def __init__(self):
        super(AdaptiveAudioNormalization, self).__init__()
        self.padding = nn.ReflectionPad1d(3)
        self.filter = self.gaussian_kernel(7, 1.5).reshape(1, 1, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        log_spec = torch.log1p(x * 1048576)
        if log_spec.ndim == 2:
            log_spec = log_spec.unsqueeze(0)
        mean_log_spec = log_spec.mean(dim=1, keepdim=True)
        mean_log_spec = self.padding(mean_log_spec)
        mean_log_spec = torch.conv1d(mean_log_spec, self.filter)
        mean_log_spec = mean_log_spec.mean(dim=-1, keepdim=True)
        normalized = log_spec - mean_log_spec
        return normalized

    @staticmethod
    def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
        kernel_range = torch.arange(-(size // 2), (size // 2) + 1)
        kernel = torch.exp(-0.5 * (kernel_range / sigma) ** 2)
        kernel = kernel / kernel.sum()
        return kernel


class ConvBlock(nn.Module):
    # Depthwise-separable convolution block with a residual projection,
    # followed by a strided 1x1 convolution and batch norm.
    def __init__(
        self,
        in_channels: int = 129,
        out_channels: int = 16,
        stride: int = 2,
        has_out_proj: bool = True,
    ):
        super(ConvBlock, self).__init__()
        self.relu = nn.ReLU()
        self.dw_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=5,
            padding=2,
            groups=in_channels,
        )
        self.pw_conv = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        if has_out_proj:
            self.proj = nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            )
        else:
            self.proj = nn.Identity()
        self.dropout = nn.Dropout(p=0.15)
        self.conv = nn.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
        )
        self.batch_norm = nn.BatchNorm1d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dw_conv(x)
        x = self.relu(x)
        x = self.pw_conv(x)
        x += self.proj(residual)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x


class Encoder(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(Encoder, self).__init__()
        assert sampling_rate in [
            8000,
            16000,
        ], "Supported sampling rates are 8000 and 16000"
        self.feature_extractor = STFT(256, 64)
        self.adaptive_normalization = AdaptiveAudioNormalization()
        self.conv_block_0 = ConvBlock(
            in_channels=258, out_channels=16, stride=2, has_out_proj=True
        )
        self.conv_block_1 = ConvBlock(
            in_channels=16, out_channels=32, stride=2, has_out_proj=True
        )
        self.conv_block_2 = ConvBlock(
            in_channels=32,
            out_channels=32,
            stride=2 if sampling_rate == 16000 else 1,
            has_out_proj=False,
        )
        self.conv_block_3 = ConvBlock(
            in_channels=32, out_channels=64, stride=1, has_out_proj=True
        )

    def forward(self, x):
        x_feature = self.feature_extractor(x)
        x_norm = self.adaptive_normalization(x_feature)
        x = torch.cat([x_feature, x_norm], 1)
        x = self.conv_block_0(x)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        return x


class Decoder(nn.Module):
    # Two-layer LSTM followed by a 1x1 convolution and sigmoid.
    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn = torch.nn.LSTM(
            input_size=64, hidden_size=64, num_layers=2, batch_first=True, dropout=0.1
        )
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
        self.conv1d = torch.nn.Conv1d(in_channels=64, out_channels=1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, h, c):
        input, (h, c) = self.rnn(input.permute([0, 2, 1]), (h, c))
        input = self.dropout(input.permute([0, 2, 1]))
        input = self.relu(input)
        input = self.conv1d(input)
        input = self.sigmoid(input)
        return input, h, c


class VadModel(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(VadModel, self).__init__()
        self.encoder = Encoder(sampling_rate)
        self.decoder = Decoder()

    @staticmethod
    def get_initial_states(batch_size: int, device: torch.device):
        h = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        c = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        return h, c

    def single_step_forward(self, x, h, c):
        x = self.encoder(x)
        out, h, c = self.decoder(x, h, c)
        out = torch.mean(out, -1)
        return out, h, c

    def forward(self, x, num_samples=512):
        assert (
            x.ndim == 2
        ), "Input should be a 2D tensor with size (batch_size, num_samples)"
        num_audio = x.size(0)
        h, c = self.get_initial_states(num_audio, device=x.device)
        x = x.reshape(-1, num_samples)
        x = self.encoder(x)
        x = x.reshape(num_audio, -1, 64, x.size(-1))
        decoder_outputs = []
        for window in x.unbind(1):
            out, h, c = self.decoder(window, h, c)
            decoder_outputs.append(out)
        out = torch.stack(decoder_outputs, dim=1).squeeze(-1)
        return out.mean(dim=-1)


# Mapping from this re-implementation's parameter names to the parameter names
# in the original Silero VAD v4 JIT checkpoint.
model_mapping = {
    "encoder.feature_extractor.forward_basis_buffer": "_model.feature_extractor.forward_basis_buffer",
    "encoder.conv_block_0.dw_conv.weight": "_model.first_layer.0.dw_conv.0.weight",
    "encoder.conv_block_0.dw_conv.bias": "_model.first_layer.0.dw_conv.0.bias",
    "encoder.conv_block_0.pw_conv.weight": "_model.first_layer.0.pw_conv.0.weight",
    "encoder.conv_block_0.pw_conv.bias": "_model.first_layer.0.pw_conv.0.bias",
    "encoder.conv_block_0.proj.weight": "_model.first_layer.0.proj.weight",
    "encoder.conv_block_0.proj.bias": "_model.first_layer.0.proj.bias",
    "encoder.conv_block_0.conv.weight": "_model.encoder.0.weight",
    "encoder.conv_block_0.conv.bias": "_model.encoder.0.bias",
    "encoder.conv_block_0.batch_norm.weight": "_model.encoder.1.weight",
    "encoder.conv_block_0.batch_norm.bias": "_model.encoder.1.bias",
    "encoder.conv_block_0.batch_norm.running_mean": "_model.encoder.1.running_mean",
    "encoder.conv_block_0.batch_norm.running_var": "_model.encoder.1.running_var",
    "encoder.conv_block_0.batch_norm.num_batches_tracked": "_model.encoder.1.num_batches_tracked",
    "encoder.conv_block_1.dw_conv.weight": "_model.encoder.3.0.dw_conv.0.weight",
    "encoder.conv_block_1.dw_conv.bias": "_model.encoder.3.0.dw_conv.0.bias",
    "encoder.conv_block_1.pw_conv.weight": "_model.encoder.3.0.pw_conv.0.weight",
    "encoder.conv_block_1.pw_conv.bias": "_model.encoder.3.0.pw_conv.0.bias",
    "encoder.conv_block_1.proj.weight": "_model.encoder.3.0.proj.weight",
    "encoder.conv_block_1.proj.bias": "_model.encoder.3.0.proj.bias",
    "encoder.conv_block_1.conv.weight": "_model.encoder.4.weight",
    "encoder.conv_block_1.conv.bias": "_model.encoder.4.bias",
    "encoder.conv_block_1.batch_norm.weight": "_model.encoder.5.weight",
    "encoder.conv_block_1.batch_norm.bias": "_model.encoder.5.bias",
    "encoder.conv_block_1.batch_norm.running_mean": "_model.encoder.5.running_mean",
    "encoder.conv_block_1.batch_norm.running_var": "_model.encoder.5.running_var",
    "encoder.conv_block_1.batch_norm.num_batches_tracked": "_model.encoder.5.num_batches_tracked",
    "encoder.conv_block_2.dw_conv.weight": "_model.encoder.7.0.dw_conv.0.weight",
    "encoder.conv_block_2.dw_conv.bias": "_model.encoder.7.0.dw_conv.0.bias",
    "encoder.conv_block_2.pw_conv.weight": "_model.encoder.7.0.pw_conv.0.weight",
    "encoder.conv_block_2.pw_conv.bias": "_model.encoder.7.0.pw_conv.0.bias",
    "encoder.conv_block_2.conv.weight": "_model.encoder.8.weight",
    "encoder.conv_block_2.conv.bias": "_model.encoder.8.bias",
    "encoder.conv_block_2.batch_norm.weight": "_model.encoder.9.weight",
    "encoder.conv_block_2.batch_norm.bias": "_model.encoder.9.bias",
    "encoder.conv_block_2.batch_norm.running_mean": "_model.encoder.9.running_mean",
    "encoder.conv_block_2.batch_norm.running_var": "_model.encoder.9.running_var",
    "encoder.conv_block_2.batch_norm.num_batches_tracked": "_model.encoder.9.num_batches_tracked",
    "encoder.conv_block_3.dw_conv.weight": "_model.encoder.11.0.dw_conv.0.weight",
    "encoder.conv_block_3.dw_conv.bias": "_model.encoder.11.0.dw_conv.0.bias",
    "encoder.conv_block_3.pw_conv.weight": "_model.encoder.11.0.pw_conv.0.weight",
    "encoder.conv_block_3.pw_conv.bias": "_model.encoder.11.0.pw_conv.0.bias",
    "encoder.conv_block_3.proj.weight": "_model.encoder.11.0.proj.weight",
    "encoder.conv_block_3.proj.bias": "_model.encoder.11.0.proj.bias",
    "encoder.conv_block_3.conv.weight": "_model.encoder.12.weight",
    "encoder.conv_block_3.conv.bias": "_model.encoder.12.bias",
    "encoder.conv_block_3.batch_norm.weight": "_model.encoder.13.weight",
    "encoder.conv_block_3.batch_norm.bias": "_model.encoder.13.bias",
    "encoder.conv_block_3.batch_norm.running_mean": "_model.encoder.13.running_mean",
    "encoder.conv_block_3.batch_norm.running_var": "_model.encoder.13.running_var",
    "encoder.conv_block_3.batch_norm.num_batches_tracked": "_model.encoder.13.num_batches_tracked",
    "decoder.rnn.weight_ih_l0": "_model.decoder.rnn.weight_ih_l0",
    "decoder.rnn.weight_hh_l0": "_model.decoder.rnn.weight_hh_l0",
    "decoder.rnn.bias_ih_l0": "_model.decoder.rnn.bias_ih_l0",
    "decoder.rnn.bias_hh_l0": "_model.decoder.rnn.bias_hh_l0",
    "decoder.rnn.weight_ih_l1": "_model.decoder.rnn.weight_ih_l1",
    "decoder.rnn.weight_hh_l1": "_model.decoder.rnn.weight_hh_l1",
    "decoder.rnn.bias_ih_l1": "_model.decoder.rnn.bias_ih_l1",
    "decoder.rnn.bias_hh_l1": "_model.decoder.rnn.bias_hh_l1",
    "decoder.conv1d.weight": "_model.decoder.decoder.1.weight",
    "decoder.conv1d.bias": "_model.decoder.decoder.1.bias",
}

sampling_rate = 16000
new_model = VadModel(sampling_rate)
new_model = new_model.eval()

# The 8 kHz weights live under the "_model_8k" prefix in the original checkpoint.
if sampling_rate == 8000:
    model_mapping = {
        key: value.replace("_model", "_model_8k")
        for key, value in model_mapping.items()
    }

# Load the original JIT model and copy its weights into the re-implementation.
model, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad:v4.0stable",
    model="silero_vad",
    force_reload=True,
)
model = model.eval()

new_model.load_state_dict(
    {key: model.state_dict()[model_mapping[key]] for key in model_mapping.keys()}
)

# Export the encoder and decoder to separate ONNX graphs.
torch.onnx.export(
    new_model.encoder,
    {"x": torch.randn(10, 512, dtype=torch.float32)},
    f=f"encoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input"],
    dynamic_axes={"input": {0: "batch", 1: "num_samples"}},
)

torch.onnx.export(
    new_model.decoder,
    dict(
        input=torch.randn(1, 64, 2, dtype=torch.float32),
        h=torch.rand((2, 1, 64), dtype=torch.float32),
        c=torch.rand((2, 1, 64), dtype=torch.float32),
    ),
    f=f"decoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input", "h", "c"],
    dynamic_axes={
        "input": {0: "batch", 2: "num_samples"},
        "h": {1: "batch"},
        "c": {1: "batch"},
    },
)
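As a quick sanity check of the exported graphs, the two ONNX files can be loaded with onnxruntime and chained manually. This is only a minimal sketch (not part of the gist), assuming onnxruntime is installed; the output order follows the return values of Encoder and Decoder above.

import numpy as np
import onnxruntime as ort

# Load the files produced by the torch.onnx.export calls above (16 kHz case).
enc_sess = ort.InferenceSession("encoder_v4_16khz.onnx")
dec_sess = ort.InferenceSession("decoder_v4_16khz.onnx")

# One 512-sample chunk of audio plus zero-initialized LSTM states.
audio = np.random.randn(1, 512).astype(np.float32)
h = np.zeros((2, 1, 64), dtype=np.float32)
c = np.zeros((2, 1, 64), dtype=np.float32)

features = enc_sess.run(None, {"input": audio})[0]              # (1, 64, T)
prob, h, c = dec_sess.run(None, {"input": features, "h": h, "c": c})
print(prob.squeeze())                                           # speech probability per frame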
There are also some tiny differences between the v4 and v5 implementations regarding the STFT class.
Fixed the error in L174. The difference in the STFT comes from the original implementation's parameters; strictly speaking this is not an actual STFT, but I kept the naming convention of the original models.
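For intuition, the forward basis used by such a conv-based "STFT" is essentially the real and imaginary rows of the DFT matrix stacked along the output channels. The sketch below only illustrates the expected shape and is an assumption about the structure; the actual buffer (including any windowing or scaling) is loaded from the Silero checkpoint through model_mapping, so it may differ from this illustration.

import math
import torch

filter_length = 256
cutoff = filter_length // 2 + 1                       # 129 retained frequency bins
n = torch.arange(filter_length, dtype=torch.float32)
k = torch.arange(cutoff, dtype=torch.float32).unsqueeze(1)
real_basis = torch.cos(2 * math.pi * k * n / filter_length)
imag_basis = -torch.sin(2 * math.pi * k * n / filter_length)
# (2 * cutoff, 1, filter_length) == (258, 1, 256), same shape as forward_basis_buffer
forward_basis = torch.cat([real_basis, imag_basis], dim=0).unsqueeze(1)
print(forward_basis.shape)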
Thanks!
I think there are still some glitches in V4: single_step_forward does not seem to be defined, and neither is self.num_samples.
Fixed them. These methods were ported directly from V5, and I didn't test them thoroughly because they are not used; I only used the forward method.
Thanks a lot. I don't think v4 has the context concatenated to the input, does it? Are the forward method and the way you export the model at the end with an input size of (512+64) still correct?
It does concatenate the context; you can refer to the ONNX example in the original repo and to faster-whisper before the v5 implementation.
I'm really not sure. All I see is that inputs are the chunks and the states are updated through h and c
https://github.com/snakers4/silero-vad/blob/915dd3d639b8333a52e001af095f87c5b7f1e0ac/utils_vad.py#L86
(I'm not talking about the internal concatenation of x_feature and its normalized version btw)
I didn't check the decoder part yet, but the VadModel class should be something like the following:
class VadModel(nn.Module):
    def __init__(self):
        super(VadModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def get_initial_states(self, batch_size: int, device: torch.device):
        h = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        c = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        return h, c

    def forward(self, x):
        bsz = x.size(0)
        h, c = self.get_initial_states(bsz, device=x.device)
        x = self.encoder(x).reshape(bsz, -1, 64, 2)
        decoder_outputs = []
        for window in x.unbind(1):
            out, h, c = self.decoder(window, h, c)
            decoder_outputs.append(out)
        out = torch.stack(decoder_outputs, dim=1).squeeze(-1)
        return out.mean(dim=-1)

and then the outputs of the encoder can be validated as follows:
# sample input
x = torch.randn(1, 512, dtype=torch.float32)

# validate encoder outputs
jit_out = model.forward(x)
x0 = model.feature_extractor(x)
norm = model.adaptive_normalization(x0)
x1 = torch.cat([x0, norm], 1)
x2 = model.first_layer(x1)
enc = model.encoder(x2)
vad_enc = new_model.encoder(x)
assert torch.allclose(enc, vad_enc, atol=1e-5)

I revised the implementation again and addressed your notes; it now also supports variable num_samples and 8 kHz.
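For reference, here is a rough usage sketch of the revised forward (an assumption based on the gist code above, not part of the original thread): the input length just has to be a multiple of num_samples, and passing sampling_rate=8000 builds the 8 kHz variant.

import torch

model_16k = VadModel(sampling_rate=16000).eval()
audio = torch.randn(2, 4 * 512)                 # (batch, total_samples), a multiple of 512
with torch.no_grad():
    probs = model_16k(audio, num_samples=512)   # one speech probability per 512-sample window
print(probs.shape)                              # torch.Size([2, 4])

model_8k = VadModel(sampling_rate=8000).eval()  # conv_block_2 uses stride 1 for 8 kHz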
Many thanks! I see that around line 174 there are some more statements after the return, where the variable x is not defined.