System Info
- transformers version: 4.56.1
- Platform: Linux-5.10.0-1.oe.x86_64-x86_64-with-glibc2.34
- Python version: 3.12.11
- Huggingface_hub version: 0.35.3
- Safetensors version: 0.6.2
- Accelerate version: 1.10.1
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: yes
- Using GPU in script?: yes
- GPU type: NVIDIA H800
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
The code I used is as follows:
```python
from transformers import AutoModelForCausalLM
import torch

if __name__ == "__main__":
    torch.distributed.init_process_group("nccl")
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x22B-Instruct-v0.1",
        tp_plan="auto",
        tp_size=8,
        dtype="auto",
    )
    torch.distributed.destroy_process_group()
```
I run the script with the command:
```bash
torchrun --nproc-per-node=8 load.py
```
Expected behavior
Loading the model (~262 GB) across 8×H800 GPUs (80 GB each) should not lead to OOM: with tp_size=8, each rank should only need to hold roughly 262 / 8 ≈ 33 GB of weights.
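For reference, here is a minimal diagnostic sketch (not part of the original report) that logs per-rank memory right after loading, using the same launch command as above; it assumes `tp_plan`/`tp_size` behave as in the reproduction script and that local rank equals global rank on a single node:

```python
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM

if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    # Assumes single-node launch, so global rank maps to the local GPU index.
    torch.cuda.set_device(rank)

    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x22B-Instruct-v0.1",
        tp_plan="auto",
        tp_size=8,
        dtype="auto",
    )

    # Report how much of the 80 GB per H800 each rank actually holds;
    # a peak far above ~33 GiB would point at where the OOM comes from.
    alloc = torch.cuda.memory_allocated() / 2**30
    peak = torch.cuda.max_memory_allocated() / 2**30
    print(f"rank {rank}: allocated={alloc:.1f} GiB, peak={peak:.1f} GiB")

    dist.destroy_process_group()
```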