-
-
Save Fhrozen/60b38bd3ee23492d28b602a0c9f92217 to your computer and use it in GitHub Desktop.
| #!/usr/bin/env python3 | |
| """Convert TTS to ONNX | |
| Using ESPnet. | |
| Test command: | |
| python convert_tts2onnx.py --tts-tag espnet/kan-bayashi_ljspeech_vits | |
| """ | |
| import argparse | |
| import logging | |
| import sys | |
| import numpy as np | |
| import torch | |
| import time | |
| from typing import Dict | |
| from typing import Optional | |
| from espnet2.bin.tts_inference import Text2Speech | |
| from espnet2.utils.types import str_or_none | |
| import torch.nn.functional as F | |
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the TTS-to-ONNX conversion script.

    Returns:
        The parser, exposing a single required ``--tts-tag`` option whose
        value is available as ``args.tts_tag`` after parsing.
    """
    parser = argparse.ArgumentParser(
        # A non-empty description so that ``--help`` says what the tool does.
        description="Convert an ESPnet TTS model to ONNX.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--tts-tag",
        required=True,
        type=str,
        help="TTS tag (or Directory) for model located at huggingface/zenodo/local",
    )
    return parser
| ### Add this at espnet2/gan_tts/vits/vits.py | |
def inference_onnx(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    durations: Optional[torch.Tensor] = None,
    noise_scale: float = 0.667,
    noise_scale_dur: float = 0.8,
    alpha: float = 1.0,
    max_len: Optional[int] = None,
    use_teacher_forcing: bool = False,
) -> torch.Tensor:
    """Run inference for ONNX.

    Written as a method: ``self`` must provide ``self.generator.inference``.
    Unlike a dict-returning inference API, this returns a bare tensor so it
    can stand in for ``forward`` during ONNX export.

    Args:
        text: Forwarded unchanged to ``self.generator.inference``.
        text_lengths: Forwarded unchanged to ``self.generator.inference``.
        sids: Optional tensor; reshaped to shape ``(1,)`` when given.
        spembs: Optional tensor; forwarded unchanged.
        lids: Optional tensor; reshaped to shape ``(1,)`` when given.
        durations: Optional tensor; reshaped to ``(1, 1, -1)`` when given and
            forwarded as the ``dur`` keyword.
        noise_scale: Forwarded unchanged.
        noise_scale_dur: Forwarded unchanged.
        alpha: Forwarded unchanged.
        max_len: Forwarded unchanged.
        use_teacher_forcing: Must be ``False``; teacher forcing is not
            implemented for this entry point.

    Returns:
        The first value returned by ``self.generator.inference`` (named
        ``wav`` here), flattened to a 1-D tensor.

    Raises:
        NotImplementedError: If ``use_teacher_forcing`` is true.
    """
    if sids is not None:
        sids = sids.view(1)
    if lids is not None:
        lids = lids.view(1)
    if durations is not None:
        durations = durations.view(1, 1, -1)

    if use_teacher_forcing:
        raise NotImplementedError(
            "Teacher forcing is not supported by inference_onnx."
        )

    # inference
    wav, _, _ = self.generator.inference(
        text=text,
        text_lengths=text_lengths,
        sids=sids,
        spembs=spembs,
        lids=lids,
        dur=durations,
        noise_scale=noise_scale,
        noise_scale_dur=noise_scale_dur,
        alpha=alpha,
        max_len=max_len,
    )
    return wav.view(-1)
def test_onnx():
    """Smoke-test the exported ``tts_model.onnx`` with ONNX Runtime.

    Preprocesses a fixed sentence, feeds it to the model as ``input_text``
    and logs the session's input/output names and the type of the result.

    Relies on the module-level ``preprocessing`` callable, which the
    ``__main__`` section assigns before calling this function.
    """
    logging.info('Test ONNX')
    # Imported lazily so onnxruntime is only required when the test runs.
    import onnxruntime as ort

    this_text = 'Hello world, how are you doing'
    this_text = preprocessing("<dummy>", dict(text=this_text))['text']
    # Add a leading axis of size 1 in front of the preprocessed array.
    this_text = this_text[None]

    # ONNX Runtime >= 1.9 requires an explicit provider list on builds that
    # ship more than the default CPU provider.
    ort_sess = ort.InferenceSession(
        'tts_model.onnx',
        providers=ort.get_available_providers(),
    )
    inname = [node.name for node in ort_sess.get_inputs()]
    outname = [node.name for node in ort_sess.get_outputs()]
    logging.info("inputs name: %s || outputs name: %s", inname, outname)
    outputs = ort_sess.run(None, {'input_text': this_text})
    logging.info(type(outputs))
if __name__ == "__main__":
    # Script entry point: load a pretrained ESPnet TTS model, synthesize a
    # test utterance, export the model to ``tts_model.onnx`` and smoke-test
    # the exported file. Variables stay at module level on purpose:
    # ``test_onnx`` reads the global ``preprocessing`` assigned below.
    import types

    # Logger
    parser = get_parser()
    args = parser.parse_args()
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(
        filename='onnx.log',
        encoding='utf-8',
        level=logging.INFO,
        format=logfmt,
    )

    # Load Pretrained model and testing wav generation
    logging.info("Preparing pretrained model from: %s", args.tts_tag)
    text2speech = Text2Speech.from_pretrained(
        model_tag=str_or_none(args.tts_tag),
        vocoder_tag=None,
        # Fall back to CPU so the script also runs on hosts without CUDA.
        device="cuda" if torch.cuda.is_available() else "cpu",
        # Only for Tacotron 2 & Transformer
        threshold=0.5,
        # Only for Tacotron 2
        minlenratio=0.0,
        maxlenratio=10.0,
        use_att_constraint=False,
        backward_window=1,
        forward_window=3,
        # Only for FastSpeech & FastSpeech2 & VITS
        speed_control_alpha=1.0,
        # Only for VITS
        noise_scale=0.667,
        noise_scale_dur=0.8,
    )

    text = 'Hello world'
    logging.info("Generating test wav using the sequence: %s", text)
    with torch.no_grad():
        start = time.time()
        wav = text2speech(text)["wav"]
    # Real-time factor: synthesis time divided by generated audio duration.
    rtf = (time.time() - start) / (len(wav) / text2speech.fs)
    logging.info("RTF = %5f", rtf)

    # Prepare modules for conversion
    logging.info("Generate ONNX models")
    with torch.no_grad():
        device = text2speech.device
        preprocessing = text2speech.preprocess_fn
        model_tts = text2speech.tts

        # Replace forward with inference to avoid problems at ONNX generation.
        # Use the model's own ``inference_onnx`` when ESPnet has been patched
        # to provide one; otherwise bind the function defined in this file,
        # so the script does not fail with
        # "'VITS' object has no attribute 'inference_onnx'".
        onnx_forward = getattr(model_tts, "inference_onnx", None)
        if onnx_forward is None:
            onnx_forward = types.MethodType(inference_onnx, model_tts)
        model_tts.forward = onnx_forward

        # Preprocessing data
        preproc_text = preprocessing("<dummy>", dict(text=text))['text']
        preproc_text = torch.from_numpy(preproc_text).to(device).unsqueeze(0)
        text_lengths = torch.tensor(
            [preproc_text.size(1)],
            dtype=torch.long,
            device=preproc_text.device,
        )
        wav = model_tts(preproc_text, text_lengths)
        logging.info(wav.shape)
        inputs = (preproc_text, text_lengths)

        # Generate TTS Model
        torch.onnx.export(
            model_tts,
            inputs,
            'tts_model.onnx',
            export_params=True,
            opset_version=13,
            do_constant_folding=True,
            verbose=True,
            input_names=['input_text'],
            output_names=['wav'],
            # Distinct symbolic names: reusing one name for both axes would
            # declare the text length and the waveform length to be equal.
            dynamic_axes={
                'input_text': {
                    1: 'text_length'
                },
                'wav': {
                    0: 'wav_length'
                }
            }
        )

    test_onnx()
    sys.exit(0)
@Fhrozen I am getting this error:
AttributeError Traceback (most recent call last)
/tmp/ipykernel_1834/548140431.py in <module>
5
6 # Replace forward with inference to avoid problems at ONNX generation
----> 7 model_tts.forward = model_tts.inference_onnx
8
9 # Preprocessing data
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
1176 return modules[name]
1177 raise AttributeError("'{}' object has no attribute '{}'".format(
-> 1178 type(self).__name__, name))
1179
1180 def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
AttributeError: 'VITS' object has no attribute 'inference_onnx'
Hi @Fhrozen did you get a chance to look into it?
@sciai-ai, Sorry, I need a little more time. I expect to implement the fixes for this by this weekend or, at the latest, next weekend.
Hi @Fhrozen did you get a chance to work on it?
Sorry for the late response. I am checking for solutions. I will update you once I am finished.
@Fhrozen Any updates?
Check this pls: https://github.com/Masao-Someki/espnet_onnx
@Fhrozen I also encountered this error. How did you solve it?
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where_165' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:497 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 19
L83-L97