
OOM when setting batchtype='token' and size mismatch when setting batchtype='example' #2674

@OedonLestrange42

Description


Notice: In order to resolve issues more efficiently, please raise issue following the template.

🐛 Bug

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Run the command (from the example finetune.sh; I only changed the number of dataset chunks and the batch type):
model_name_or_model_dir="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
...
torchrun $DISTRIBUTED_ARGS \
../FunASR/funasr/bin/train_ds.py \
++model="${model_name_or_model_dir}" \
++train_data_set_list="${train_data}" \
++valid_data_set_list="${val_data}" \
++dataset="AudioDatasetHotword" \
++dataset_conf.index_ds="IndexDSJsonl" \
++dataset_conf.batch_size=32 \
++dataset_conf.sort_size=500 \
++dataset_conf.batch_type="example" \
++dataset_conf.batch_sampler="BatchSampler" \
++dataset_conf.min_token_length=10 \
++dataset_conf.max_token_length=2000 \
++dataset_conf.int_pad_value=-1 \
++dataset_conf.float_pad_value=0.0 \
++dataset_conf.num_workers=4 \
++dataset_conf.data_split_num=${DATA_SPLIT_NUM} \
++train_conf.accum_grad=4 \
++train_conf.max_epoch=20 \
++train_conf.log_interval=1 \
++train_conf.use_fp16=true \
++train_conf.resume=true \
++train_conf.validate_interval=7000 \
++train_conf.save_checkpoint_interval=7000 \
++train_conf.keep_nbest_models=20 \
++train_conf.avg_nbest_model=10 \
++train_conf.use_deepspeed=false \
++deepspeed_config=${deepspeed_config} \
++scheduler=warmuplr \
++scheduler_conf.warmup_steps=1500 \
++optim_conf.lr=0.0002 \
++output_dir="${output_dir}" 2>&1\
  2. See error
Traceback (most recent call last):                                                                                                              
  File "/data/home/fj/workspace/wanghd_wkdir/../FunASR/funasr/bin/train_ds.py", line 244, in <module>                                           
    main_hydra()                                                                                                                                
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/hydra/main.py", line 94, in decorated_main                                  
    _run_hydra(                                                                                                                                 
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra                          
    _run_app(                                                                                                                                   
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app                            
    run_and_report(                                                                                                                             
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report                      
    raise ex                                                                                                                                    
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report                      
    return func()                                                                                                                               
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>                            
    lambda: hydra.run(                                                                                                                          
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run                                 
    _ = ret.return_value                                                                                                                        
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value                             
    raise self._return_value                                                                                                                    
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job                                  
    ret.return_value = task_function(task_cfg)                                                                                                  
  File "/data/home/fj/workspace/wanghd_wkdir/../FunASR/funasr/bin/train_ds.py", line 56, in main_hydra                                          
    main(**kwargs)                                                                                                                              
  File "/data/home/fj/workspace/wanghd_wkdir/../FunASR/funasr/bin/train_ds.py", line 177, in main                                               
    trainer.train_epoch(                                                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/train_utils/trainer_ds.py", line 603, in train_epoch                                              
    self.forward_step(model, batch, loss_dict=loss_dict)                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/train_utils/trainer_ds.py", line 670, in forward_step                                             
    retval = model(**batch)                                                                                                                     
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl               
    return self._call_impl(*args, **kwargs)                                                                                                     
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl                       
    return forward_call(*args, **kwargs)      
  File "/data/home/fj/workspace/FunASR/funasr/models/bicif_paraformer/model.py", line 184, in forward
    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
  File "/data/home/fj/workspace/FunASR/funasr/models/bicif_paraformer/model.py", line 96, in _calc_att_loss
    sematic_embeds, decoder_out_1st = self.sampler(
  File "/data/home/fj/workspace/FunASR/funasr/models/paraformer/model.py", line 350, in sampler
    decoder_outs = self.decoder(
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl               
    return self._call_impl(*args, **kwargs)                                                                                                     
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl                       
    return forward_call(*args, **kwargs)                                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/models/bicif_paraformer/model.py", line 184, in forward                                           
    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/models/bicif_paraformer/model.py", line 96, in _calc_att_loss                                     
    sematic_embeds, decoder_out_1st = self.sampler(                                                                                             
  File "/data/home/fj/workspace/FunASR/funasr/models/paraformer/model.py", line 350, in sampler                                                 
    decoder_outs = self.decoder(                                                                                                                
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl               
    return self._call_impl(*args, **kwargs)                                                                                                     
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl                       
    return forward_call(*args, **kwargs)                                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/models/e_paraformer/decoder.py", line 397, in forward                                             
    x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)                                                       
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl               
    return self._call_impl(*args, **kwargs)                                                                                                     
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl                       
    return forward_call(*args, **kwargs)                                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/models/transformer/utils/repeat.py", line 32, in forward                                          
    args = m(*args)                                                                                                                             
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl               
    return self._call_impl(*args, **kwargs)                                                                                                     
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl                       
    return forward_call(*args, **kwargs)                                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/models/e_paraformer/decoder.py", line 106, in forward                                             
    x, _ = self.self_attn(tgt, tgt_mask)                                                                                                        
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl               
    return self._call_impl(*args, **kwargs)                                                                                                     
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl                       
    return forward_call(*args, **kwargs)                                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/models/sanm/attention.py", line 518, in forward                                                   
    inputs = inputs * mask                                                                                                                      
RuntimeError: The size of tensor a (0) must match the size of tensor b (31) at non-singleton dimension 1
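
For what it's worth, the same broadcasting failure can be reproduced in isolation when one sample's length dimension collapses to zero before it reaches the inputs * mask line in funasr/models/sanm/attention.py. The shapes below are only my guess at the pattern, not values dumped from the failing batch:

import torch

# Hypothetical shapes: an input whose length dimension (dim 1) became 0,
# multiplied by a mask that still has 31 positions at dim 1.
inputs = torch.randn(2, 0, 512)
mask = torch.ones(2, 31, 1)

inputs * mask
# RuntimeError: The size of tensor a (0) must match the size of tensor b (31)
# at non-singleton dimension 1
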
  3. Bug 2: I ran the finetune with batch_type='token' and small dataset chunks (each chunk can only fill two batches):
torchrun $DISTRIBUTED_ARGS \
../FunASR/funasr/bin/train_ds.py \
++model="${model_name_or_model_dir}" \
++train_data_set_list="${train_data}" \
++valid_data_set_list="${val_data}" \
++dataset="AudioDatasetHotword" \
++dataset_conf.index_ds="IndexDSJsonl" \
++dataset_conf.batch_size=2000 \
++dataset_conf.sort_size=500 \
++dataset_conf.batch_type="token" \
++dataset_conf.batch_sampler="BatchSampler" \
++dataset_conf.min_token_length=10 \
++dataset_conf.max_token_length=2000 \
++dataset_conf.int_pad_value=-1 \
++dataset_conf.float_pad_value=0.0 \
++dataset_conf.num_workers=4 \
++dataset_conf.data_split_num=256 \
++train_conf.use_wandb=true \
++train_conf.accum_grad=4 \
++train_conf.max_epoch=20 \
++train_conf.log_interval=1 \
++train_conf.use_fp16=true \
++train_conf.resume=true \
++train_conf.validate_interval=7000 \
++train_conf.save_checkpoint_interval=7000 \
++train_conf.keep_nbest_models=20 \
++train_conf.avg_nbest_model=10 \
++train_conf.use_deepspeed=false \
++deepspeed_config=${deepspeed_config} \
++scheduler=warmuplr \
++scheduler_conf.warmup_steps=1500 \
++optim_conf.lr=0.0002 \
++output_dir="${output_dir}" 2>&1\
  4. See error:
[2025-09-18 16:54:47,998][root][INFO] - total_num of samplers: 310, ./train_jsonls_large/all_filtered_t1_train.list                             
[2025-09-18 16:54:47,998][root][INFO] - Train epoch: 0, rank: 0                                                                                 
                                                                                                                                                
[2025-09-18 16:54:48,006][root][INFO] - rank: 0, dataloader start from step: 0, batch_num: 2, after: 2                                          
[2025-09-18 16:54:48,136][root][INFO] - rank: 0, dataloader start from step: 0, batch_num: 2, after: 2  

...
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl                       
    return forward_call(*args, **kwargs)                                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/models/transformer/utils/repeat.py", line 32, in forward                                          
    args = m(*args)                                                                                                                             
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl               
    return self._call_impl(*args, **kwargs)                                                                                                     
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl                       
    return forward_call(*args, **kwargs)                                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/models/sanm/encoder.py", line 122, in forward                                                     
    self.self_attn(                                                                                                                             
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl               
    return self._call_impl(*args, **kwargs)                                                                                                     
  File "/home/fj/anaconda3/envs/funasr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl                       
    return forward_call(*args, **kwargs)                                                                                                        
  File "/data/home/fj/workspace/FunASR/funasr/models/sanm/attention.py", line 310, in forward                                                   
    att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)                                                                
  File "/data/home/fj/workspace/FunASR/funasr/models/sanm/attention.py", line 278, in forward_attention                                         
    attn = torch.softmax(scores, dim=-1).masked_fill(                                                                                           
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 86.00 MiB. GPU 0 has a total capacity of 23.68 GiB of which 25.88 MiB is free.
Including non-PyTorch memory, this process has 23.61 GiB memory in use. Of the allocated memory 22.45 GiB is allocated by PyTorch, and 874.34 MiB is reserved by PyTorch but unallocated.
If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
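
As a first mitigation I can try the allocator setting that the OOM message itself suggests. With torchrun this would normally be exported in the shell before launching; the Python form below is just a sketch of the same idea and of course does not reduce how much memory the batch actually needs:

import os

# Assumption: this must be set before the CUDA caching allocator is initialized,
# i.e. before the first CUDA allocation (or exported in the shell before torchrun).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # the allocator will now use expandable segments
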

Environment

  • OS : Ubuntu 20.04
  • FunASR Version : 1.2.7
  • ModelScope Version :
  • PyTorch Version : 2.6.0+cu124
  • How you installed funasr (pip, source): source
  • Python version: 3.10
  • GPU: NVIDIA GeForce RTX 3090
  • CUDA/cuDNN version : 12.4

Additional context

My training data consists of 79,000 wav files, each no longer than 10 seconds. I tried splitting the dataset into small chunks (only two 2000-token batches per chunk), but it kept triggering the OOM error, even though I made sure the GPU was completely empty (24529 MB available). Only when I split the dataset into tiny chunks that cannot even fill one 2000-token batch can I run the script.
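
For reference, the quickest way I know to confirm the free memory from inside the same environment is torch.cuda.mem_get_info(); a minimal snippet, in case it helps reproduce my numbers:

import torch

# Free / total device memory in MiB, as reported by the CUDA driver for GPU 0.
free_b, total_b = torch.cuda.mem_get_info(0)
print(f"free: {free_b / 2**20:.0f} MiB, total: {total_b / 2**20:.0f} MiB")
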

I suspected the OOM error was due to batch_type='token', which may require extra data loading and sorting, so I changed it to 'example'. But batch_type='example' then ran into the size mismatch error shown above.

Could you please tell me the device requirements for Paraformer finetuning, or did I miss something?
