System Info
- transformers version: 4.53.2
- Platform: Linux-5.10.225-213.878.amzn2.x86_64-x86_64-with-glibc2.31
- Python version: 3.10.13
- Huggingface_hub version: 0.33.4
- Safetensors version: 0.5.3
- Accelerate version: 1.4.0
- Accelerate config: not found
- DeepSpeed version: 0.17.2
- PyTorch version (accelerator?): 2.6.0 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA A100-SXM4-80GB
Who can help?
I'm trying to train a Qwen model with a long context (4096) using paged attention. As soon as I switch the attention implementation to paged attention, training errors out immediately.
Here is the traceback:
Traceback (most recent call last):
  File "/mnt/task_runtime/mango/train/grpo/train_grpo_token_space_efficient_reasoning.py", line 251, in <module>
    main()
  File "/mnt/task_runtime/mango/train/grpo/train_grpo_token_space_efficient_reasoning.py", line 247, in main
    trainer.train()
  File "/miniconda/lib/python3.10/site-packages/transformers/trainer.py", line 2206, in train
    return inner_training_loop(
  File "/miniconda/lib/python3.10/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/miniconda/lib/python3.10/site-packages/transformers/trainer.py", line 3743, in training_step
    inputs = self._prepare_inputs(inputs)
  File "/miniconda/lib/python3.10/site-packages/trl/extras/profiling.py", line 98, in wrapper
    return func(self, *args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py", line 990, in _prepare_inputs
    generation_batch = self._generate_and_score_completions(generation_batch)
  File "/miniconda/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py", line 1174, in _generate_and_score_completions
    prompt_completion_ids = unwrapped_model.generate(
  File "/miniconda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2625, in generate
    result = self._sample(
  File "/miniconda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3606, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/miniconda/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 570, in forward
    outputs: BaseModelOutputWithPast = self.model(
  File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 458, in forward
    layer_outputs = decoder_layer(
  File "/miniconda/lib/python3.10/site-packages/transformers/modeling_layers.py", line 83, in __call__
    return super().__call__(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 262, in forward
    hidden_states, self_attn_weights = self.self_attn(
  File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 217, in forward
    attn_output, attn_weights = attention_interface(
  File "/miniconda/lib/python3.10/site-packages/transformers/integrations/flash_paged.py", line 47, in paged_attention_forward
    k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
AttributeError: 'NoneType' object has no attribute 'update'
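Looking at the last frame: cache is None inside paged_attention_forward (transformers/integrations/flash_paged.py, line 47). My reading is that the paged attention path expects a pre-allocated paged KV cache to arrive through the attention kwargs, but the plain model.generate() call that TRL uses for rollouts never creates one, so the very first forward pass dies on cache.update(...).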
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
train_grpo.py
import sys
import argparse
import os

import torch
import torch.distributed as dist  # used by get_world_size() and the rank checks below
import wandb

sys.path.append("/mnt/task_runtime/")
# Project-local helpers (init_distributed, get_model, get_grpo_dataset,
# rank0_print, REWARD_FUNC_MAP, custom_wrap_policy_qwen) live under
# /mnt/task_runtime/; their import lines are omitted here.
from datasets import load_dataset

try:
    from trl import GRPOConfig, GRPOTrainer
except ImportError:
    from trl.trl import GRPOConfig, GRPOTrainer


def get_world_size():
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1  # fallback for a non-distributed run


def sanitize_for_wandb(name):
    return name.replace("/", "-").replace(":", "-")
def get_args():
    parser = argparse.ArgumentParser(description="Train with GRPO")
    parser.add_argument(
        "--model_name",
        type=str,
        default="Qwen/Qwen2-0.5B-Instruct",
        required=True,
        help="Model name or path",
    )
    parser.add_argument(
        "--ckpt_path", type=str, default=None, help="Optional checkpoint to resume from"
    )
    parser.add_argument(
        "--rewards",
        type=str,
        default="efficient_reasoning",
        help="Comma-separated reward functions",
    )
    # add reward weights argument
    parser.add_argument(
        "--reward_weights",
        type=str,
        default=None,
        help="Comma-separated weights for each reward function",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        default="output_token_space_efficient_reasoning",
        help="Output directory for training artifacts",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="Dataset name (e.g., gsm8k, trl-lib/tldr)",
    )
    parser.add_argument(
        "--dataset_split",
        type=int,
        default=1024,
        help="Subset of training data to use (optional)",
    )
    parser.add_argument(
        "--use_peft",
        action="store_true",
        help="Only train LoRA adapters",
    )
    parser.add_argument(
        "--use_quant",
        action="store_true",
        help="Load the base model in 4/8 bits",
    )
    parser.add_argument(
        "--use_fsdp",
        action="store_true",
        help="Use Fully Sharded Data Parallel (FSDP) for distributed training",
    )
    parser.add_argument(
        "--use_attention",
        type=str,
        default="flash_attention_2",
        help="Attention implementation to use (e.g., flash_attention_2, default)",
    )
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=1024,
        help="Maximum context length",
    )
    parser.add_argument(
        "--dataset_difficulty",
        type=str,
        default=None,
        help="Difficulty level of the dataset (e.g., easy, medium, hard)",
    )
    parser.add_argument(
        "--expt_tag",
        type=str,
        default="",
        help="Experiment tag for wandb",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size per device for training",
    )
    parser.add_argument(
        "--use_unsloth",
        action="store_true",
        help="Use Unsloth for fast inference",
    )
    return parser.parse_args()
def main():
    args = get_args()
    init_distributed()
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    model = get_model(
        args.model_name,
        use_peft=args.use_peft,
        use_quant=args.use_quant,
        use_attention=args.use_attention,
        ckpt_path=args.ckpt_path,
        use_unsloth=args.use_unsloth,
    )
    train_dataset, val_datasets, eval_datasets = get_grpo_dataset(
        dataset_name=args.dataset_name,
        n_shot=0,  # set to 0 for zero-shot training
    )
    rank0_print(f"Using datasets {train_dataset} and {val_datasets}")
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        project_name = sanitize_for_wandb(
            f"{args.expt_tag}_GRPO_{'PEFT' if args.use_peft else ''}_{args.model_name}_{args.dataset_name if args.dataset_name else args.dataset_difficulty}"
        )
        wandb.init(
            project=project_name,  # change to your wandb project
            name=args.expt_tag + args.output_dir,  # optional run name
        )
    per_device_train_batch_size = args.per_device_train_batch_size
    num_processes = get_world_size()
    gradient_accumulation_steps = 1
    effective_batch_size = (
        per_device_train_batch_size * num_processes * gradient_accumulation_steps
    )
    for name, param in model.named_parameters():
        if param.requires_grad:
            rank0_print(f"Trainable parameter: {name}, shape: {param.shape}")
    fsdp_config = dict(
        fsdp="full_shard auto_wrap",
        auto_wrap_policy=custom_wrap_policy_qwen,
    )
    training_args = GRPOConfig(
        output_dir=args.output_dir,
        logging_steps=10,
        num_train_epochs=3,
        per_device_train_batch_size=per_device_train_batch_size,
        save_steps=100,
        report_to="wandb",
        eval_steps=1000 if val_datasets else None,
        eval_strategy="steps" if val_datasets else "no",
        num_generations=effective_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_eval_batch_size=per_device_train_batch_size,
        max_completion_length=args.max_context_length,
        fsdp_config=fsdp_config if args.use_fsdp else None,
        # bf16=True,  # enable bf16 if supported
        # fp16=True,  # enable fp16 for training
    )
    rewards = [REWARD_FUNC_MAP[reward] for reward in args.rewards.split(",")]
    if args.reward_weights:
        reward_weights = list(map(float, args.reward_weights.split(",")))
        training_args.reward_weights = reward_weights
    # model.gradient_checkpointing_enable()
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=rewards,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_datasets[0] if val_datasets else None,
    )
    trainer.train()


if __name__ == "__main__":
    main()
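The trainer shouldn't even be needed to hit this; if I'm reading the trace right, any plain generate() call under the paged implementation fails the same way. A minimal sketch (the "paged_attention" implementation string and the small Qwen checkpoint are my assumptions, not my actual training setup):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed: "paged_attention" is the attn_implementation value that routes
# through integrations/flash_paged.py; any small Qwen model triggers it.
model_id = "Qwen/Qwen2-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="paged_attention",  # same switch as --use_attention
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
# Expected to raise: AttributeError: 'NoneType' object has no attribute 'update'
# inside paged_attention_forward, since generate() never builds a paged KV cache.
output = model.generate(**inputs, max_new_tokens=16)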
Expected behavior
Specifying the attention implementation as paged attention should not change anything; training should behave the same as it does with flash_attention_2.
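In the meantime, training runs if I leave generation on flash_attention_2 (the script's default). A sketch of the fallback, reusing the project-local get_model helper from the script above with only the flag value changed:

model = get_model(
    args.model_name,
    use_peft=args.use_peft,
    use_quant=args.use_quant,
    use_attention="flash_attention_2",  # sidesteps the paged_attention crash
    ckpt_path=args.ckpt_path,
    use_unsloth=args.use_unsloth,
)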