@MahmoudAshraf97
Last active August 15, 2024 13:40
Reference Implementation of Silero V4 VAD model
import torch
import torch.nn as nn
import torch.nn.functional as F
class STFT(nn.Module):
    # Magnitude spectrogram computed with a fixed conv1d basis
    # (the weights are loaded from the original model's forward_basis_buffer).
    def __init__(self, filter_length, hop_length):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.padding = nn.ReflectionPad1d((filter_length - hop_length) // 2)
        self.register_buffer(
            "forward_basis_buffer", torch.zeros([filter_length + 2, 1, filter_length])
        )

    def forward(self, input_data):
        input_data = self.padding(input_data).unsqueeze(1)
        forward_transform = F.conv1d(
            input_data, self.forward_basis_buffer, stride=self.hop_length
        )
        cutoff = self.filter_length // 2 + 1
        real_part = forward_transform[:, :cutoff]
        imag_part = forward_transform[:, cutoff:]
        magnitude = torch.sqrt(real_part.pow(2) + imag_part.pow(2))
        return magnitude
class AdaptiveAudioNormalization(nn.Module):
    def __init__(self):
        super(AdaptiveAudioNormalization, self).__init__()
        self.padding = nn.ReflectionPad1d(3)
        self.filter = self.gaussian_kernel(7, 1.5).reshape(1, 1, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        log_spec = torch.log1p(x * 1048576)
        if log_spec.ndim == 2:
            log_spec = log_spec.unsqueeze(0)
        mean_log_spec = log_spec.mean(dim=1, keepdim=True)
        mean_log_spec = self.padding(mean_log_spec)
        mean_log_spec = torch.conv1d(mean_log_spec, self.filter)
        mean_log_spec = mean_log_spec.mean(dim=-1, keepdim=True)
        normalized = log_spec - mean_log_spec
        return normalized

    @staticmethod
    def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
        kernel_range = torch.arange(-(size // 2), (size // 2) + 1)
        kernel = torch.exp(-0.5 * (kernel_range / sigma) ** 2)
        kernel = kernel / kernel.sum()
        return kernel
class ConvBlock(nn.Module):
    # Depthwise-separable convolution with a residual projection,
    # followed by a strided 1x1 conv, batch norm, and ReLU.
    def __init__(
        self,
        in_channels: int = 129,
        out_channels: int = 16,
        stride: int = 2,
        has_out_proj: bool = True,
    ):
        super(ConvBlock, self).__init__()
        self.relu = nn.ReLU()
        self.dw_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=5,
            padding=2,
            groups=in_channels,
        )
        self.pw_conv = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        if has_out_proj:
            self.proj = nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            )
        else:
            self.proj = nn.Identity()
        self.dropout = nn.Dropout(p=0.15)
        self.conv = nn.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
        )
        self.batch_norm = nn.BatchNorm1d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dw_conv(x)
        x = self.relu(x)
        x = self.pw_conv(x)
        x += self.proj(residual)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x
class Encoder(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(Encoder, self).__init__()
        assert sampling_rate in [
            8000,
            16000,
        ], "Supported sampling rates are 8000 and 16000"
        self.feature_extractor = STFT(256, 64)
        self.adaptive_normalization = AdaptiveAudioNormalization()
        self.conv_block_0 = ConvBlock(
            in_channels=258, out_channels=16, stride=2, has_out_proj=True
        )
        self.conv_block_1 = ConvBlock(
            in_channels=16, out_channels=32, stride=2, has_out_proj=True
        )
        self.conv_block_2 = ConvBlock(
            in_channels=32,
            out_channels=32,
            stride=2 if sampling_rate == 16000 else 1,
            has_out_proj=False,
        )
        self.conv_block_3 = ConvBlock(
            in_channels=32, out_channels=64, stride=1, has_out_proj=True
        )

    def forward(self, x):
        # The magnitude features and their normalized version are concatenated
        # along the channel dimension before the conv blocks.
        x_feature = self.feature_extractor(x)
        x_norm = self.adaptive_normalization(x_feature)
        x = torch.cat([x_feature, x_norm], 1)
        x = self.conv_block_0(x)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        return x
class Decoder(nn.Module):
    # Two-layer LSTM followed by a 1x1 conv and sigmoid, producing a
    # speech probability per frame.
    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn = torch.nn.LSTM(
            input_size=64, hidden_size=64, num_layers=2, batch_first=True, dropout=0.1
        )
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
        self.conv1d = torch.nn.Conv1d(in_channels=64, out_channels=1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, h, c):
        input, (h, c) = self.rnn(input.permute([0, 2, 1]), (h, c))
        input = self.dropout(input.permute([0, 2, 1]))
        input = self.relu(input)
        input = self.conv1d(input)
        input = self.sigmoid(input)
        return input, h, c
class VadModel(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(VadModel, self).__init__()
        self.encoder = Encoder(sampling_rate)
        self.decoder = Decoder()

    @staticmethod
    def get_initial_states(batch_size: int, device: torch.device):
        h = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        c = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        return h, c

    def single_step_forward(self, x, h, c):
        x = self.encoder(x)
        out, h, c = self.decoder(x, h, c)
        out = torch.mean(out, -1)
        return out, h, c

    def forward(self, x, num_samples=512):
        # Splits the input into windows of num_samples, encodes all windows at
        # once, then runs the decoder sequentially so the LSTM state is carried
        # across windows. Returns one speech probability per window.
        assert (
            x.ndim == 2
        ), "Input should be a 2D tensor with size (batch_size, num_samples)"
        num_audio = x.size(0)
        h, c = self.get_initial_states(num_audio, device=x.device)
        x = x.reshape(-1, num_samples)
        x = self.encoder(x)
        x = x.reshape(num_audio, -1, 64, x.size(-1))
        decoder_outputs = []
        for window in x.unbind(1):
            out, h, c = self.decoder(window, h, c)
            decoder_outputs.append(out)
        out = torch.stack(decoder_outputs, dim=1).squeeze(-1)
        return out.mean(dim=-1)
model_mapping = {
"encoder.feature_extractor.forward_basis_buffer": "_model.feature_extractor.forward_basis_buffer",
"encoder.conv_block_0.dw_conv.weight": "_model.first_layer.0.dw_conv.0.weight",
"encoder.conv_block_0.dw_conv.bias": "_model.first_layer.0.dw_conv.0.bias",
"encoder.conv_block_0.pw_conv.weight": "_model.first_layer.0.pw_conv.0.weight",
"encoder.conv_block_0.pw_conv.bias": "_model.first_layer.0.pw_conv.0.bias",
"encoder.conv_block_0.proj.weight": "_model.first_layer.0.proj.weight",
"encoder.conv_block_0.proj.bias": "_model.first_layer.0.proj.bias",
"encoder.conv_block_0.conv.weight": "_model.encoder.0.weight",
"encoder.conv_block_0.conv.bias": "_model.encoder.0.bias",
"encoder.conv_block_0.batch_norm.weight": "_model.encoder.1.weight",
"encoder.conv_block_0.batch_norm.bias": "_model.encoder.1.bias",
"encoder.conv_block_0.batch_norm.running_mean": "_model.encoder.1.running_mean",
"encoder.conv_block_0.batch_norm.running_var": "_model.encoder.1.running_var",
"encoder.conv_block_0.batch_norm.num_batches_tracked": "_model.encoder.1.num_batches_tracked",
"encoder.conv_block_1.dw_conv.weight": "_model.encoder.3.0.dw_conv.0.weight",
"encoder.conv_block_1.dw_conv.bias": "_model.encoder.3.0.dw_conv.0.bias",
"encoder.conv_block_1.pw_conv.weight": "_model.encoder.3.0.pw_conv.0.weight",
"encoder.conv_block_1.pw_conv.bias": "_model.encoder.3.0.pw_conv.0.bias",
"encoder.conv_block_1.proj.weight": "_model.encoder.3.0.proj.weight",
"encoder.conv_block_1.proj.bias": "_model.encoder.3.0.proj.bias",
"encoder.conv_block_1.conv.weight": "_model.encoder.4.weight",
"encoder.conv_block_1.conv.bias": "_model.encoder.4.bias",
"encoder.conv_block_1.batch_norm.weight": "_model.encoder.5.weight",
"encoder.conv_block_1.batch_norm.bias": "_model.encoder.5.bias",
"encoder.conv_block_1.batch_norm.running_mean": "_model.encoder.5.running_mean",
"encoder.conv_block_1.batch_norm.running_var": "_model.encoder.5.running_var",
"encoder.conv_block_1.batch_norm.num_batches_tracked": "_model.encoder.5.num_batches_tracked",
"encoder.conv_block_2.dw_conv.weight": "_model.encoder.7.0.dw_conv.0.weight",
"encoder.conv_block_2.dw_conv.bias": "_model.encoder.7.0.dw_conv.0.bias",
"encoder.conv_block_2.pw_conv.weight": "_model.encoder.7.0.pw_conv.0.weight",
"encoder.conv_block_2.pw_conv.bias": "_model.encoder.7.0.pw_conv.0.bias",
"encoder.conv_block_2.conv.weight": "_model.encoder.8.weight",
"encoder.conv_block_2.conv.bias": "_model.encoder.8.bias",
"encoder.conv_block_2.batch_norm.weight": "_model.encoder.9.weight",
"encoder.conv_block_2.batch_norm.bias": "_model.encoder.9.bias",
"encoder.conv_block_2.batch_norm.running_mean": "_model.encoder.9.running_mean",
"encoder.conv_block_2.batch_norm.running_var": "_model.encoder.9.running_var",
"encoder.conv_block_2.batch_norm.num_batches_tracked": "_model.encoder.9.num_batches_tracked",
"encoder.conv_block_3.dw_conv.weight": "_model.encoder.11.0.dw_conv.0.weight",
"encoder.conv_block_3.dw_conv.bias": "_model.encoder.11.0.dw_conv.0.bias",
"encoder.conv_block_3.pw_conv.weight": "_model.encoder.11.0.pw_conv.0.weight",
"encoder.conv_block_3.pw_conv.bias": "_model.encoder.11.0.pw_conv.0.bias",
"encoder.conv_block_3.proj.weight": "_model.encoder.11.0.proj.weight",
"encoder.conv_block_3.proj.bias": "_model.encoder.11.0.proj.bias",
"encoder.conv_block_3.conv.weight": "_model.encoder.12.weight",
"encoder.conv_block_3.conv.bias": "_model.encoder.12.bias",
"encoder.conv_block_3.batch_norm.weight": "_model.encoder.13.weight",
"encoder.conv_block_3.batch_norm.bias": "_model.encoder.13.bias",
"encoder.conv_block_3.batch_norm.running_mean": "_model.encoder.13.running_mean",
"encoder.conv_block_3.batch_norm.running_var": "_model.encoder.13.running_var",
"encoder.conv_block_3.batch_norm.num_batches_tracked": "_model.encoder.13.num_batches_tracked",
"decoder.rnn.weight_ih_l0": "_model.decoder.rnn.weight_ih_l0",
"decoder.rnn.weight_hh_l0": "_model.decoder.rnn.weight_hh_l0",
"decoder.rnn.bias_ih_l0": "_model.decoder.rnn.bias_ih_l0",
"decoder.rnn.bias_hh_l0": "_model.decoder.rnn.bias_hh_l0",
"decoder.rnn.weight_ih_l1": "_model.decoder.rnn.weight_ih_l1",
"decoder.rnn.weight_hh_l1": "_model.decoder.rnn.weight_hh_l1",
"decoder.rnn.bias_ih_l1": "_model.decoder.rnn.bias_ih_l1",
"decoder.rnn.bias_hh_l1": "_model.decoder.rnn.bias_hh_l1",
"decoder.conv1d.weight": "_model.decoder.decoder.1.weight",
"decoder.conv1d.bias": "_model.decoder.decoder.1.bias",
}
sampling_rate = 16000
new_model = VadModel(sampling_rate)
new_model = new_model.eval()
if sampling_rate == 8000:
    model_mapping = {
        key: value.replace("_model", "_model_8k")
        for key, value in model_mapping.items()
    }
model, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad:v4.0stable",
    model="silero_vad",
    force_reload=True,
)
model = model.eval()
new_model.load_state_dict(
    {key: model.state_dict()[model_mapping[key]] for key in model_mapping.keys()}
)
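# Optional sanity check (a sketch, not part of the original gist): compare the
# ported model against the loaded JIT model on a single chunk. This assumes the
# JIT model's forward signature is model(x, sr), as used in silero-vad's
# utils_vad.py, and that a 512-sample chunk is valid for the chosen rate.
with torch.no_grad():
    chunk = torch.randn(1, 512, dtype=torch.float32)
    assert torch.allclose(
        new_model(chunk).squeeze(), model(chunk, sampling_rate).squeeze(), atol=1e-5
    )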
torch.onnx.export(
    new_model.encoder,
    {"x": torch.randn(10, 512, dtype=torch.float32)},
    f=f"encoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input"],
    dynamic_axes={"input": {0: "batch", 1: "num_samples"}},
)
torch.onnx.export(
    new_model.decoder,
    dict(
        input=torch.randn(1, 64, 2, dtype=torch.float32),
        h=torch.rand((2, 1, 64), dtype=torch.float32),
        c=torch.rand((2, 1, 64), dtype=torch.float32),
    ),
    f=f"decoder_v4_{sampling_rate // 1000}khz.onnx",
    input_names=["input", "h", "c"],
    dynamic_axes={
        "input": {0: "batch", 2: "num_samples"},
        "h": {1: "batch"},
        "c": {1: "batch"},
    },
)
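# A minimal sketch (not part of the original gist; assumes onnxruntime is
# installed) of running the two exported graphs: the encoder takes raw
# 512-sample windows, the decoder takes the encoder output plus the LSTM
# states h and c.
import numpy as np
import onnxruntime as ort

enc_sess = ort.InferenceSession(f"encoder_v4_{sampling_rate // 1000}khz.onnx")
dec_sess = ort.InferenceSession(f"decoder_v4_{sampling_rate // 1000}khz.onnx")

audio = np.random.randn(1, 512).astype(np.float32)  # hypothetical single window
h = np.zeros((2, 1, 64), dtype=np.float32)
c = np.zeros((2, 1, 64), dtype=np.float32)

features = enc_sess.run(None, {"input": audio})[0]
prob, h, c = dec_sess.run(None, {"input": features, "h": h, "c": c})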
@MahmoudAshraf97 (Author)

Fixed the error in L174. The difference in the STFT comes from the original implementation's parameters; strictly speaking this is not an actual STFT, but I kept the naming convention of the original models.

@ozancaglayan

Thanks!

@ozancaglayan

I think there are still some glitches in V4: single_step_forward does not seem to be defined, and neither is self.num_samples.

@MahmoudAshraf97 (Author)

Fixed them. These methods were ported directly from V5; I didn't test them thoroughly because they are not used. I only used the forward method.

@ozancaglayan

Thanks a lot. I don't think V4 has the context concatenated to the input, does it? Are the forward method, and the way you export the model at the end with an input size of (512+64), still correct?

@MahmoudAshraf97 (Author)

It does concatenate the context; you can refer to the ONNX example in the original repo and to the faster-whisper implementation before V5.

@ozancaglayan commented Aug 15, 2024

I'm really not sure. All I see is that the inputs are the chunks and the states are updated through h and c:
https://github.com/snakers4/silero-vad/blob/915dd3d639b8333a52e001af095f87c5b7f1e0ac/utils_vad.py#L86

(I'm not talking about the internal concatenation of x_feature and its normalized version btw)
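For reference, that stateful pattern maps onto the gist's single_step_forward roughly as follows (a sketch only, assuming 16 kHz audio whose length is a multiple of 512 samples):

    h, c = new_model.get_initial_states(batch_size=1, device=torch.device("cpu"))
    audio = torch.randn(1, 512 * 10)  # hypothetical audio, a multiple of 512 samples
    probs = []
    with torch.no_grad():
        for chunk in audio.split(512, dim=-1):
            prob, h, c = new_model.single_step_forward(chunk, h, c)
            probs.append(prob)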

@ozancaglayan commented Aug 15, 2024

I haven't checked the decoder part yet, but the VadModel class should be something like the following:

class VadModel(nn.Module):
    def __init__(self):
        super(VadModel, self).__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

    def get_initial_states(self, batch_size: int, device: torch.device):
        h = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        c = torch.zeros((2, batch_size, 64), dtype=torch.float32, device=device)
        return h, c

    def forward(self, x):
        bsz = x.size(0)
        h, c = self.get_initial_states(bsz, device=x.device)

        x = self.encoder(x).reshape(bsz, -1, 64, 2)

        decoder_outputs = []
        for window in x.unbind(1):
            out, h, c = self.decoder(window, h, c)
            decoder_outputs.append(out)

        out = torch.stack(decoder_outputs, dim=1).squeeze(-1)

        return out.mean(dim=-1)

and then the outputs of the encoder can be validated as follows:

    # sample input
    x = torch.randn(1, 512, dtype=torch.float32)

    # validate encoder outputs
    jit_out = model.forward(x)
    x0 = model.feature_extractor(x)
    norm = model.adaptive_normalization(x0)
    x1 = torch.cat([x0, norm], 1)
    x2 = model.first_layer(x1)
    enc = model.encoder(x2)

    vad_enc = new_model.encoder(x)
    assert torch.allclose(enc, vad_enc, atol=1e-5)

@MahmoudAshraf97 (Author)

I revised the implementation again and addressed your notes; I also added support for a variable num_samples and for 8 kHz.
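A minimal usage sketch of the revised forward under those options (assumptions: 8 kHz audio, 256-sample windows, and that the 8 kHz weights have been loaded via the "_model_8k" mapping above):

    model_8k = VadModel(sampling_rate=8000).eval()
    audio_8k = torch.randn(2, 768)  # hypothetical batch, three 256-sample windows each
    with torch.no_grad():
        probs = model_8k(audio_8k, num_samples=256)  # shape (2, 3): one probability per window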
