Model file for SenseVoice SRT model
import time
import torch
from torch import nn
import torch.nn.functional as F
from typing import Iterable, Optional

from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
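# Model definition for FunASR's SenseVoiceSmall. The encoder and model classes
# below are registered with funasr's table registry ("SenseVoiceEncoderSmall"
# and "SenseVoiceSmall") so they can be instantiated from a config by name.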
class SinusoidalPositionEncoder(torch.nn.Module):
    """Sinusoidal position encoding computed on the fly from the input shape."""

    def __init__(self, d_model=80, dropout_rate=0.1):
        # Arguments are kept for interface compatibility; the encoding is derived
        # from the input shape in forward().
        super().__init__()

    def encode(
        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
    ):
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
            depth / 2 - 1
        )
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
        )
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
            inv_timescales, [1, 1, -1]
        )
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding
class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
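# SANM self-attention: standard multi-head attention augmented with an FSMN-style
# memory block (a depthwise Conv1d over the value projection). The layer output is
# the sum of the attention output and the FSMN memory (see forward()).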
class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer with an FSMN memory block.

    Args:
        n_head (int): The number of heads.
        in_feat (int): Input feature dimension for the fused q/k/v projection.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): FSMN memory-block kernel size.
        sanm_shfit (int): Left shift of the FSMN memory block.
    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        """Construct a MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Asymmetric padding; sanm_shfit shifts the memory block to the left.
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Transform query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Value tensor before head split (#batch, time2, n_feat), used by the FSMN block.
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = -float("inf")
            scores = scores.masked_fill(mask, min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention plus the FSMN memory.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory
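    # forward_chunk is the streaming variant: cached keys/values from previous
    # chunks are concatenated before attention, and the cache is updated with the
    # current chunk (minus the trailing chunk_size[2] frames), keeping at most
    # look_back * chunk_size[1] frames unless look_back == -1.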
    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute scaled dot product attention for one streaming chunk.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            cache (dict): Cached "k"/"v" tensors from previous chunks.

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
            dict: Updated cache.
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        if chunk_size is not None and look_back > 0 or look_back == -1:
            if cache is not None:
                k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
                v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)

                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                if look_back != -1:
                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
            else:
                cache_tmp = {
                    "k": k_h[:, :, : -(chunk_size[2]), :],
                    "v": v_h[:, :, : -(chunk_size[2]), :],
                }
                cache = cache_tmp
        fsmn_memory = self.forward_fsmn(v, None)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache
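# LayerNorm that normalizes in float32 and casts the result back to the input
# dtype, which keeps mixed-precision (fp16/bf16) training and inference stable.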
class LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
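# sequence_mask builds a per-utterance length mask. Illustrative example:
#   sequence_mask(torch.tensor([2, 3]), maxlen=4)
#   -> tensor([[1., 1., 0., 0.],
#              [1., 1., 1., 0.]])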
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()
    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            The remaining inputs (cache and the chunk masks) are passed through unchanged.
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features for one streaming chunk.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            cache (dict): Attention key/value cache from previous chunks.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            dict: Updated cache.
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)

        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, cache
@tables.register("encoder_classes", "SenseVoiceEncoderSmall")
class SenseVoiceEncoderSmall(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        tp_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size

        self.embed = SinusoidalPositionEncoder()
        self.normalize_before = normalize_before

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        # The first layer (encoders0) projects from input_size to output_size;
        # every later layer keeps output_size.
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        self.encoders0 = nn.ModuleList(
            [
                EncoderLayerSANM(
                    input_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(1)
            ]
        )
        self.encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(num_blocks - 1)
            ]
        )
        self.tp_encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(tp_blocks)
            ]
        )

        self.after_norm = LayerNorm(output_size)
        self.tp_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Embed positions in tensor."""
        masks = sequence_mask(ilens, device=ilens.device)[:, None, :]

        xs_pad *= self.output_size() ** 0.5
        xs_pad = self.embed(xs_pad)

        # forward encoder1
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        xs_pad = self.after_norm(xs_pad)

        # forward encoder2
        olens = masks.squeeze(1).sum(1).int()
        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        xs_pad = self.tp_norm(xs_pad)
        return xs_pad, olens
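# SenseVoiceSmall is a non-autoregressive CTC model. At encode time, four query
# embeddings (language, two event/emotion queries, and text-normalization) are
# prepended to the speech features; during training the first four encoder output
# frames are supervised with a "rich token" cross-entropy loss and the remaining
# frames with CTC (see forward()).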
@tables.register("model_classes", "SenseVoiceSmall")
class SenseVoiceSmall(nn.Module):
    """Non-autoregressive CTC model with rich-token (language/event/emotion/text-norm) prediction."""

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        ctc_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        if ctc_conf is None:
            ctc_conf = {}
        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.error_calculator = None

        self.ctc = ctc

        self.length_normalized_loss = length_normalized_loss
        self.encoder_output_size = encoder_output_size

        # Special-token id maps: language, text-normalization and emotion tags used
        # for the query embeddings and the rich-token targets.
        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.embed = torch.nn.Embedding(
            7 + len(self.lid_dict) + len(self.textnorm_dict), input_size
        )
        self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}

        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

    @staticmethod
    def from_pretrained(model: str = None, **kwargs):
        from funasr import AutoModel

        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
        return model, kwargs
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Encoder + CTC/rich-token losses.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)

        loss_ctc, cer_ctc = None, None
        loss_rich, acc_rich = None, None
        stats = dict()

        # The first 4 output frames correspond to the prepended query tokens and are
        # trained with a rich-token CE loss; the remaining frames are trained with CTC.
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
        )
        loss_rich, acc_rich = self._calc_rich_ce_loss(encoder_out[:, :4, :], text[:, :4])

        loss = loss_ctc + loss_rich
        # Collect total loss stats
        stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
        stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
        stats["loss"] = torch.clone(loss.detach()) if loss is not None else None
        stats["acc_rich"] = acc_rich

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        # Data augmentation
        if self.specaug is not None and self.training:
            speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Language query: taken from the first target token; dropped to the "auto"
        # id (0) with probability 0.2 so the model also learns language detection.
        lids = torch.LongTensor(
            [
                [self.lid_int_dict[int(lid)] if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict else 0]
                for lid in text[:, 0]
            ]
        ).to(speech.device)
        language_query = self.embed(lids)

        # Text-normalization query: taken from the fourth target token.
        styles = torch.LongTensor(
            [[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]
        ).to(speech.device)
        style_query = self.embed(styles)
        speech = torch.cat((style_query, speech), dim=1)
        speech_lengths += 1

        # Event and emotion queries (embedding ids 1 and 2); the final input is
        # [language, event, emotion, text-norm] prepended to the speech features.
        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
            speech.size(0), 1, 1
        )
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        speech_lengths += 3

        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)

        return encoder_out, encoder_out_lens
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rich_ce_loss(
        self,
        encoder_out: torch.Tensor,
        ys_pad: torch.Tensor,
    ):
        # Reuse the CTC output projection to score the rich tokens, then apply
        # label-smoothing cross-entropy against the first 4 target tokens.
        decoder_out = self.ctc.ctc_lo(encoder_out)
        loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
        acc_rich = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_pad.contiguous(),
            ignore_label=self.ignore_id,
        )

        return loss_rich, acc_rich
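    # inference() runs the full pipeline: load audio, extract fbank features,
    # prepend the language / event-emotion / text-norm query embeddings, run the
    # encoder, and greedily decode the CTC logits (argmax, collapse repeats, drop
    # blanks) into text via the tokenizer.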
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = ["wav_file_tmp_name"],
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":
            # Inputs are already fbank features.
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # Extract fbank features.
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        language = kwargs.get("language", "auto")
        language_query = self.embed(
            torch.LongTensor(
                [[self.lid_dict[language] if language in self.lid_dict else 0]]
            ).to(speech.device)
        ).repeat(speech.size(0), 1, 1)

        use_itn = kwargs.get("use_itn", False)
        textnorm = kwargs.get("text_norm", None)
        if textnorm is None:
            textnorm = "withitn" if use_itn else "woitn"
        textnorm_query = self.embed(
            torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
        ).repeat(speech.size(0), 1, 1)
        speech = torch.cat((textnorm_query, speech), dim=1)
        speech_lengths += 1

        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
            speech.size(0), 1, 1
        )
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        speech_lengths += 3

        # Encoder
        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # Greedy CTC decoding over the encoder output.
        ctc_logits = self.ctc.log_softmax(encoder_out)
        if kwargs.get("ban_emo_unk", False):
            ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf")

        results = []
        b, n, d = encoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        if len(key) < b:
            key = key * b

        for i in range(b):
            x = ctc_logits[i, : encoder_out_lens[i].item(), :]
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)

            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer["1best_recog"]

            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()

            # Change integer-ids to tokens
            text = tokenizer.decode(token_int)

            result_i = {"key": key[i], "text": text}
            results.append(result_i)

            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = text

        return results, meta_data
    def export(self, **kwargs):
        from export_meta import export_rebuild_model

        if "max_seq_len" not in kwargs:
            kwargs["max_seq_len"] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models
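if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). The model id
    # "iic/SenseVoiceSmall" and the wav path are placeholders; adjust them to your
    # setup. This mirrors the funasr AutoModel-based loading that from_pretrained()
    # above wraps.
    m, m_kwargs = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall", device="cpu")
    m.eval()
    res, meta = m.inference(
        data_in="example.wav",  # any 16 kHz audio file readable by funasr
        language="auto",        # or "zh", "en", "yue", "ja", "ko"
        use_itn=False,
        **m_kwargs,             # supplies tokenizer, frontend, device, etc.
    )
    print(res)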