
Wav2Vec2Model can't be exported with torch.onnx.export() #4116

@Tomas542

Description

🐛 Describe the bug

Exporting a Wav2Vec2 model by following this PyTorch guide fails. Something goes wrong in the Wav2Vec2 forward call, related to layer_drop. The relevant part of the log:

x = self.encoder(x, lengths)
x = self.transformer(x, attention_mask=mask)
if not (self.training and torch.rand(1).item() <= self.layer_drop):

This happens even though all dropout rates are 0 by default.

Code to reproduce:

from torchaudio.models import wav2vec2_xlsr_300m
import torch

torch_model = wav2vec2_xlsr_300m()
example_inputs = (torch.rand(1, 64600),)
onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True)
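The guard from the log, `if not (self.training and torch.rand(1).item() <= self.layer_drop):`, uses Python's short-circuit `and`. `torch.rand(1).item()` is evaluated whenever `self.training` is True, whatever the value of `layer_drop`. `torch.nn.Module` instances also start in training mode, and the repro never calls `.eval()`. So a dropout rate of 0 alone does not remove the random, data-dependent `.item()` call from the traced path. That is a plausible but unconfirmed cause of the export failure. The pure-Python sketch below uses a hypothetical `ToyLayerDropStack` class, with `random.random()` standing in for `torch.rand(1).item()`. It shows that the random draw happens in training mode even with `layer_drop = 0.0`, and is skipped entirely after `eval()`.

```python
import random


class ToyLayerDropStack:
    """Hypothetical stand-in for an encoder stack with LayerDrop (not torchaudio code).

    Mimics the guard seen in the log:
        if not (self.training and torch.rand(1).item() <= self.layer_drop):
    with random.random() standing in for torch.rand(1).item().
    """

    def __init__(self, num_layers: int, layer_drop: float = 0.0) -> None:
        self.num_layers = num_layers
        self.layer_drop = layer_drop
        self.training = True  # like torch.nn.Module, start in training mode
        self.random_draws = 0  # counts how often the random value is sampled

    def eval(self) -> "ToyLayerDropStack":
        # Mirrors nn.Module.eval(): switch off training mode and return self.
        self.training = False
        return self

    def _draw(self) -> float:
        self.random_draws += 1
        return random.random()

    def forward(self, x: float) -> float:
        for _ in range(self.num_layers):
            # `and` short-circuits: _draw() runs only while self.training is True.
            if not (self.training and self._draw() <= self.layer_drop):
                x = x + 1.0  # stand-in for applying one layer
        return x


train_model = ToyLayerDropStack(num_layers=4)  # layer_drop defaults to 0.0
train_model.forward(0.0)
print(train_model.random_draws)  # 4: sampled once per layer despite layer_drop == 0.0

eval_model = ToyLayerDropStack(num_layers=4).eval()
print(eval_model.forward(0.0))  # 4.0: every layer is applied
print(eval_model.random_draws)  # 0: the random draw never happens
```

If this reading is correct, calling `torch_model.eval()` before `torch.onnx.export(...)` would make the real model take the same short-circuit path. That workaround is an assumption and has not been verified against torchaudio 2.8.0.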

Versions

PyTorch version: 2.8.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0

Versions of relevant libraries:
[pip3] numpy==2.3.3
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.8.0
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.8.0
[pip3] triton==3.4.0
[pip3] tritonclient==2.36.0
[conda] numpy 2.3.3 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.3 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] torch 2.8.0 pypi_0 pypi
[conda] torch-tb-profiler 0.4.3 pypi_0 pypi
[conda] torchaudio 2.8.0 pypi_0 pypi
[conda] triton 3.4.0 pypi_0 pypi
[conda] tritonclient 2.36.0 pypi_0 pypi

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
