import torch
import torch.nn as nn
import torch.nn.functional as F


class STFT(nn.Module):
    def __init__(self, filter_length, hop_length):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.padding = nn.ReflectionPad1d((filter_length - hop_length) // 2)
        self.register_buffer(
            "forward_basis_buffer", torch.zeros([filter_length + 2, 1, filter_length])
        )

    def forward(self, input_data):
        input_data = self.padding(input_data).unsqueeze(1)
        forward_transform = F.conv1d(
            input_data, self.forward_basis_buffer, stride=self.hop_length
        )
        cutoff = self.filter_length // 2 + 1
        real_part = forward_transform[:, :cutoff]
        imag_part = forward_transform[:, cutoff:]
        magnitude = torch.sqrt(real_part.pow(2) + imag_part.pow(2))
        return magnitude


class AdaptiveAudioNormalization(nn.Module):
    def __init__(self):
        super(AdaptiveAudioNormalization, self).__init__()
        self.padding = nn.ReflectionPad1d(3)
        self.filter = self.gaussian_kernel(7, 1.5).reshape(1, 1, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        log_spec = torch.log1p(x * 1048576)
        if log_spec.ndim == 2:
            log_spec = log_spec.unsqueeze(0)
        mean_log_spec = log_spec.mean(dim=1, keepdim=True)
        mean_log_spec = self.padding(mean_log_spec)
        mean_log_spec = torch.conv1d(mean_log_spec, self.filter)
        mean_log_spec = mean_log_spec.mean(dim=-1, keepdim=True)
        normalized = log_spec - mean_log_spec
        return normalized

    @staticmethod
    def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
        kernel_range = torch.arange(-(size // 2), (size // 2) + 1)
        kernel = torch.exp(-0.5 * (kernel_range / sigma) ** 2)
        kernel = kernel / kernel.sum()
        return kernel


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int = 129,
        out_channels: int = 16,
        stride: int = 2,
        has_out_proj: bool = True,
    ):
        super(ConvBlock, self).__init__()
        self.relu = nn.ReLU()
        self.dw_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=5,
            padding=2,
            groups=in_channels,
        )
        self.pw_conv = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        if has_out_proj:
            self.proj = nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            )
        else:
            self.proj = nn.Identity()
        self.dropout = nn.Dropout(p=0.15)
        self.conv = nn.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
        )
        self.batch_norm = nn.BatchNorm1d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dw_conv(x)
        x = self.relu(x)
        x = self.pw_conv(x)
        x += self.proj(residual)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x


class Encoder(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(Encoder, self).__init__()
        assert sampling_rate in [
            8000,
            16000,
        ], "Supported sampling rates are 8000 and 16000"
        self.feature_extractor = STFT(256, 64)
        self.adaptive_normalization = AdaptiveAudioNormalization()
        self.conv_block_0 = ConvBlock(
            in_channels=258, out_channels=16, stride=2, has_out_proj=True
        )
        self.conv_block_1 = ConvBlock(
            in_channels=16, out_channels=32, stride=2, has_out_proj=True
        )
        self.conv_block_2 = ConvBlock(
            in_channels=32,
            out_channels=32,
            stride=2 if sampling_rate == 16000 else 1,
            has_out_proj=False,
        )
        self.conv_block_3 = ConvBlock(
            in_channels=32, out_channels=64, stride=1, has_out_proj=True
        )

    def forward(self, x):
        x_feature = self.feature_extractor(x)
        x_norm = self.adaptive_normalization(x_feature)
        x = torch.cat([x_feature, x_norm], 1)
        x = self.conv_block_0(x)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        return x


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn = torch.nn.LSTM(
            input_size=64, hidden_size=64, num_layers=2, batch_first=True, dropout=0.1
        )
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
        self.conv1d = torch.nn.Conv1d(in_channels=64, out_channels=1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, h, c):
        input, (h, c) = self.rnn(input.permute([0, 2, 1]), (h, c))
        input = self.dropout(input.permute([0, 2, 1]))
        input = self.relu(input)
        input = self.conv1d(input)
        input = self.sigmoid(input)
        return input, h, c


class VadModel(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(VadModel, self).__init__()
        self.encoder = Encoder(sampling_rate)
        self.decoder = Decoder()

    @staticmethod
    def get_initial_states(batch_size: int, device: torch.device):
        h = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        c = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        return h, c

    def single_step_forward(self, x, h, c):
        x = self.encoder(x)
        out, h, c = self.decoder(x, h, c)
        out = torch.mean(out, -1)
        return out, h, c

    def forward(self, x, num_samples=512):
        assert (
            x.ndim == 2
        ), "Input should be a 2D tensor with size (batch_size, num_samples)"
        num_audio = x.size(0)
        h, c = self.get_initial_states(num_audio, device=x.device)
        x = x.reshape(-1, num_samples)
        x = self.encoder(x)
        x = x.reshape(num_audio, -1, 64, x.size(-1))
        decoder_outputs = []
        for window in x.unbind(1):
            out, h, c = self.decoder(window, h, c)
            decoder_outputs.append(out)
        out = torch.stack(decoder_outputs, dim=1).squeeze(-1)
        return out.mean(dim=-1)


model_mapping = {
    "encoder.feature_extractor.forward_basis_buffer": "_model.feature_extractor.forward_basis_buffer",
    "encoder.conv_block_0.dw_conv.weight": "_model.first_layer.0.dw_conv.0.weight",
    "encoder.conv_block_0.dw_conv.bias": "_model.first_layer.0.dw_conv.0.bias",
    "encoder.conv_block_0.pw_conv.weight": "_model.first_layer.0.pw_conv.0.weight",
    "encoder.conv_block_0.pw_conv.bias": "_model.first_layer.0.pw_conv.0.bias",
    "encoder.conv_block_0.proj.weight": "_model.first_layer.0.proj.weight",
    "encoder.conv_block_0.proj.bias": "_model.first_layer.0.proj.bias",
    "encoder.conv_block_0.conv.weight": "_model.encoder.0.weight",
    "encoder.conv_block_0.conv.bias": "_model.encoder.0.bias",
    "encoder.conv_block_0.batch_norm.weight": "_model.encoder.1.weight",
    "encoder.conv_block_0.batch_norm.bias": "_model.encoder.1.bias",
    "encoder.conv_block_0.batch_norm.running_mean": "_model.encoder.1.running_mean",
    "encoder.conv_block_0.batch_norm.running_var": "_model.encoder.1.running_var",
    "encoder.conv_block_0.batch_norm.num_batches_tracked": "_model.encoder.1.num_batches_tracked",
    "encoder.conv_block_1.dw_conv.weight": "_model.encoder.3.0.dw_conv.0.weight",
    "encoder.conv_block_1.dw_conv.bias": "_model.encoder.3.0.dw_conv.0.bias",
    "encoder.conv_block_1.pw_conv.weight": "_model.encoder.3.0.pw_conv.0.weight",
    "encoder.conv_block_1.pw_conv.bias": "_model.encoder.3.0.pw_conv.0.bias",
    "encoder.conv_block_1.proj.weight": "_model.encoder.3.0.proj.weight",
    "encoder.conv_block_1.proj.bias": "_model.encoder.3.0.proj.bias",
    "encoder.conv_block_1.conv.weight": "_model.encoder.4.weight",
    "encoder.conv_block_1.conv.bias": "_model.encoder.4.bias",
    "encoder.conv_block_1.batch_norm.weight": "_model.encoder.5.weight",
    "encoder.conv_block_1.batch_norm.bias": "_model.encoder.5.bias",
    "encoder.conv_block_1.batch_norm.running_mean": "_model.encoder.5.running_mean",
    "encoder.conv_block_1.batch_norm.running_var": "_model.encoder.5.running_var",
    "encoder.conv_block_1.batch_norm.num_batches_tracked": "_model.encoder.5.num_batches_tracked",
    "encoder.conv_block_2.dw_conv.weight": "_model.encoder.7.0.dw_conv.0.weight",
    "encoder.conv_block_2.dw_conv.bias": "_model.encoder.7.0.dw_conv.0.bias",
    "encoder.conv_block_2.pw_conv.weight": "_model.encoder.7.0.pw_conv.0.weight",
    "encoder.conv_block_2.pw_conv.bias": "_model.encoder.7.0.pw_conv.0.bias",
    "encoder.conv_block_2.conv.weight": "_model.encoder.8.weight",
    "encoder.conv_block_2.conv.bias": "_model.encoder.8.bias",
    "encoder.conv_block_2.batch_norm.weight": "_model.encoder.9.weight",
    "encoder.conv_block_2.batch_norm.bias": "_model.encoder.9.bias",
    "encoder.conv_block_2.batch_norm.running_mean": "_model.encoder.9.running_mean",
    "encoder.conv_block_2.batch_norm.running_var": "_model.encoder.9.running_var",
    "encoder.conv_block_2.batch_norm.num_batches_tracked": "_model.encoder.9.num_batches_tracked",
    "encoder.conv_block_3.dw_conv.weight": "_model.encoder.11.0.dw_conv.0.weight",
    "encoder.conv_block_3.dw_conv.bias": "_model.encoder.11.0.dw_conv.0.bias",
    "encoder.conv_block_3.pw_conv.weight": "_model.encoder.11.0.pw_conv.0.weight",
    "encoder.conv_block_3.pw_conv.bias": "_model.encoder.11.0.pw_conv.0.bias",
    "encoder.conv_block_3.proj.weight": "_model.encoder.11.0.proj.weight",
    "encoder.conv_block_3.proj.bias": "_model.encoder.11.0.proj.bias",
    "encoder.conv_block_3.conv.weight": "_model.encoder.12.weight",
    "encoder.conv_block_3.conv.bias": "_model.encoder.12.bias",
    "encoder.conv_block_3.batch_norm.weight": "_model.encoder.13.weight",
    "encoder.conv_block_3.batch_norm.bias": "_model.encoder.13.bias",
    "encoder.conv_block_3.batch_norm.running_mean": "_model.encoder.13.running_mean",
    "encoder.conv_block_3.batch_norm.running_var": "_model.encoder.13.running_var",
    "encoder.conv_block_3.batch_norm.num_batches_tracked": "_model.encoder.13.num_batches_tracked",
    "decoder.rnn.weight_ih_l0": "_model.decoder.rnn.weight_ih_l0",
    "decoder.rnn.weight_hh_l0": "_model.decoder.rnn.weight_hh_l0",
    "decoder.rnn.bias_ih_l0": "_model.decoder.rnn.bias_ih_l0",
    "decoder.rnn.bias_hh_l0": "_model.decoder.rnn.bias_hh_l0",
    "decoder.rnn.weight_ih_l1": "_model.decoder.rnn.weight_ih_l1",
    "decoder.rnn.weight_hh_l1": "_model.decoder.rnn.weight_hh_l1",
    "decoder.rnn.bias_ih_l1": "_model.decoder.rnn.bias_ih_l1",
    "decoder.rnn.bias_hh_l1": "_model.decoder.rnn.bias_hh_l1",
    "decoder.conv1d.weight": "_model.decoder.decoder.1.weight",
    "decoder.conv1d.bias": "_model.decoder.decoder.1.bias",
}

sampling_rate = 16000

new_model = VadModel(sampling_rate)
new_model = new_model.eval()

if sampling_rate == 8000:
    model_mapping = {
        key: value.replace("_model", "_model_8k")
        for key, value in model_mapping.items()
    }

model, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad:v4.0stable",
    model="silero_vad",
    force_reload=True,
)
model = model.eval()

new_model.load_state_dict(
    {key: model.state_dict()[model_mapping[key]] for key in model_mapping.keys()}
)

torch.onnx.export(
    new_model.encoder,
    {"x": torch.randn(10, 512, dtype=torch.float32)},
    f=f"encoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input"],
    dynamic_axes={"input": {0: "batch", 1: "num_samples"}},
)

torch.onnx.export(
    new_model.decoder,
    dict(
        input=torch.randn(1, 64, 2, dtype=torch.float32),
        h=torch.rand((2, 1, 64), dtype=torch.float32),
        c=torch.rand((2, 1, 64), dtype=torch.float32),
    ),
    f=f"decoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input", "h", "c"],
    dynamic_axes={
        "input": {0: "batch", 2: "num_samples"},
        "h": {1: "batch"},
        "c": {1: "batch"},
    },
)
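The export above writes the two ONNX files but doesn't round-trip them. As a quick sanity check, a minimal sketch (assuming onnxruntime is installed; the "input" name matches the input_names passed to torch.onnx.export above) could compare the exported encoder against the PyTorch module:

import numpy as np
import onnxruntime

# Dummy batch of 10 chunks of 512 samples, matching the dummy input used for the export.
x = torch.randn(10, 512, dtype=torch.float32)

# Reference output from the PyTorch encoder that was just exported.
with torch.no_grad():
    ref = new_model.encoder(x).numpy()

# Run the exported encoder with onnxruntime and compare.
sess = onnxruntime.InferenceSession(
    f"encoder_v4_{sampling_rate // 1000}khz.onnx", providers=["CPUExecutionProvider"]
)
(onnx_out,) = sess.run(None, {"input": x.numpy()})
assert np.allclose(ref, onnx_out, atol=1e-5)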
I think there are still some glitches in V4: single_step_forward does not seem to be defined, and neither is self.num_samples.
Fixed them. These methods were ported directly from V5; I didn't test them thoroughly because they aren't used, I only used the forward method.
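For reference, this is roughly how those methods would be driven if you did want to stream chunk by chunk (an untested sketch, assuming 512-sample chunks at 16 kHz as in the export example and the new_model instance defined in the gist):

# Carry the LSTM states h and c across consecutive 512-sample chunks.
audio = torch.randn(1, 512 * 10)  # dummy mono audio, batch size 1
h, c = new_model.get_initial_states(batch_size=1, device=audio.device)
probs = []
with torch.no_grad():
    for chunk in audio.split(512, dim=-1):
        out, h, c = new_model.single_step_forward(chunk, h, c)
        probs.append(out)  # one speech probability per chunk, shape (1, 1)
probs = torch.cat(probs, dim=-1)  # (1, num_chunks)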
Thanks a lot. I don't think v4 has the context concatenated to the input, does it? Are the forward method and the way you export the model at the end, with an input size of (512+64), still correct?
It does concatenate the context; you can refer to the ONNX example in the original repo and in faster-whisper before the v5 implementation.
I'm really not sure. All I see is that the inputs are the chunks and the states are updated through h and c:
https://github.com/snakers4/silero-vad/blob/915dd3d639b8333a52e001af095f87c5b7f1e0ac/utils_vad.py#L86
(I'm not talking about the internal concatenation of x_feature and its normalized version btw)
I didn't check the decoder part yet, but the VadModel class should be something like the following:
class VadModel(nn.Module):
    def __init__(self):
        super(VadModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def get_initial_states(self, batch_size: int, device: torch.device):
        h = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        c = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        return h, c

    def forward(self, x):
        bsz = x.size(0)
        h, c = self.get_initial_states(bsz, device=x.device)
        x = self.encoder(x).reshape(bsz, -1, 64, 2)
        decoder_outputs = []
        for window in x.unbind(1):
            out, h, c = self.decoder(window, h, c)
            decoder_outputs.append(out)
        out = torch.stack(decoder_outputs, dim=1).squeeze(-1)
        return out.mean(dim=-1)

And then the outputs of the encoder can be validated as follows:
# sample input
x = torch.randn(1, 512, dtype=torch.float32)

# validate encoder outputs
jit_out = model.forward(x)
x0 = model.feature_extractor(x)
norm = model.adaptive_normalization(x0)
x1 = torch.cat([x0, norm], 1)
x2 = model.first_layer(x1)
enc = model.encoder(x2)

vad_enc = new_model.encoder(x)
assert torch.allclose(enc, vad_enc, atol=1e-5)

I revised the implementation again and fixed your notes; I also added support for variable num_samples and 8 kHz.
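A similar spot check could be run on the exported decoder as well, e.g. (again a sketch assuming onnxruntime; the input names match the input_names used in the decoder export):

import onnxruntime

# Dummy decoder inputs matching the export example: features of shape (batch, 64, time)
# and LSTM states of shape (num_layers=2, batch, 64).
feat = torch.randn(1, 64, 2, dtype=torch.float32)
h0, c0 = new_model.get_initial_states(batch_size=1, device=feat.device)

with torch.no_grad():
    ref_out, ref_h, ref_c = new_model.decoder(feat, h0, c0)

sess = onnxruntime.InferenceSession(
    f"decoder_v4_{sampling_rate // 1000}khz.onnx", providers=["CPUExecutionProvider"]
)
onnx_out, onnx_h, onnx_c = sess.run(
    None, {"input": feat.numpy(), "h": h0.numpy(), "c": c0.numpy()}
)
assert torch.allclose(ref_out, torch.from_numpy(onnx_out), atol=1e-5)
assert torch.allclose(ref_h, torch.from_numpy(onnx_h), atol=1e-5)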
Thanks!