diff --git a/deployment/exporters/acoustic_exporter.py b/deployment/exporters/acoustic_exporter.py index 849dae5d..2adefe9a 100644 --- a/deployment/exporters/acoustic_exporter.py +++ b/deployment/exporters/acoustic_exporter.py @@ -4,6 +4,7 @@ import onnx import onnxsim +import onnxslim import torch import yaml @@ -344,9 +345,11 @@ def _perform_spk_mix(self, spk_mix: Dict[str, float]): return spk_mix_embed def _optimize_fs2_aux_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto: - print(f'Running ONNX Simplifier on {self.fs2_aux_class_name}...') - fs2, check = onnxsim.simplify(fs2, include_subgraph=True) - assert check, 'Simplified ONNX model could not be validated' + # print(f'Running ONNX Simplifier on {self.fs2_aux_class_name}...') + # fs2, check = onnxsim.simplify(fs2, include_subgraph=True) + # assert check, 'Simplified ONNX model could not be validated' + print(f'Running OnnxSlim on {self.fs2_aux_class_name}...') + fs2 = onnxslim.slim(fs2) onnx_helper.model_reorder_io_list( fs2, 'input', target_name='languages', insert_after_name='tokens' diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 82808ec0..d0003dea 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -4,6 +4,7 @@ import onnx import onnxsim +import onnxslim import torch import yaml @@ -82,6 +83,8 @@ def __init__( if self.freeze_spk is not None: self.model.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1])) + self.use_melody_encoder = hparams['use_melody_encoder'] + def build_model(self) -> DiffSingerVarianceONNX: model = DiffSingerVarianceONNX( vocab_size=len(self.phoneme_dictionary), @@ -649,9 +652,11 @@ def _optimize_linguistic_graph(self, linguistic: onnx.ModelProto) -> onnx.ModelP 'encoder_out': (1, 'n_tokens', hparams['hidden_size']) } ) - print(f'Running ONNX Simplifier on {self.fs2_class_name}...') - linguistic, check = onnxsim.simplify(linguistic, include_subgraph=True) - assert check, 'Simplified ONNX model could not be validated' + # print(f'Running ONNX Simplifier on {self.fs2_class_name}...') + # linguistic, check = onnxsim.simplify(linguistic, include_subgraph=True) + # assert check, 'Simplified ONNX model could not be validated' + print(f'Running OnnxSlim on {self.fs2_class_name}...') + linguistic = onnxslim.slim(linguistic) onnx_helper.model_reorder_io_list( linguistic, 'input', target_name='languages', insert_after_name='tokens' @@ -678,8 +683,11 @@ def _optimize_merge_pitch_predictor_graph( onnx_helper.model_override_io_shapes( pitch_pre, output_shapes={'pitch_cond': (1, 'n_frames', hparams['hidden_size'])} ) - pitch_pre, check = onnxsim.simplify(pitch_pre, include_subgraph=True) - assert check, 'Simplified ONNX model could not be validated' + if self.use_melody_encoder: + pitch_pre = onnxslim.slim(pitch_pre) + else: + pitch_pre, check = onnxsim.simplify(pitch_pre, include_subgraph=True) + assert check, 'Simplified ONNX model could not be validated' onnx_helper.model_override_io_shapes( pitch_predictor, output_shapes={'pitch_pred': (1, 'n_frames')} diff --git a/requirements-onnx.txt b/requirements-onnx.txt index dda53148..4eea23b0 100644 --- a/requirements-onnx.txt +++ b/requirements-onnx.txt @@ -12,6 +12,7 @@ MonkeyType==23.3.0 numpy<2.0.0 onnx~=1.16.0 onnxsim~=0.4.36 +onnxslim praat-parselmouth==0.4.3 pyworld==0.3.4 PyYAML diff --git a/requirements.txt b/requirements.txt index 8f79e238..6a96ee6d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ MonkeyType==23.3.0 numpy<2.0.0 onnx~=1.16.0 onnxsim~=0.4.36 +onnxslim praat-parselmouth==0.4.3 pyworld==0.3.4 PyYAML