
paged attention NOT working with Qwen Models #39525

@NickNickGo

Description

System Info

  • transformers version: 4.53.2
  • Platform: Linux-5.10.225-213.878.amzn2.x86_64-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.33.4
  • Safetensors version: 0.5.3
  • Accelerate version: 1.4.0
  • Accelerate config: not found
  • DeepSpeed version: 0.17.2
  • PyTorch version (accelerator?): 2.6.0 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

I'm trying to train a Qwen model with a long context (4096 tokens) using paged attention. As soon as I switch the attention implementation to paged, generation errors out immediately.

Here is the traceback:

Traceback (most recent call last):
File "/mnt/task_runtime/mango/train/grpo/train_grpo_token_space_efficient_reasoning.py", line 251, in <module>
main()
File "/mnt/task_runtime/mango/train/grpo/train_grpo_token_space_efficient_reasoning.py", line 247, in main
trainer.train()
File "/miniconda/lib/python3.10/site-packages/transformers/trainer.py", line 2206, in train
return inner_training_loop(
File "/miniconda/lib/python3.10/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
File "/miniconda/lib/python3.10/site-packages/transformers/trainer.py", line 3743, in training_step
inputs = self._prepare_inputs(inputs)
File "/miniconda/lib/python3.10/site-packages/trl/extras/profiling.py", line 98, in wrapper
return func(self, *args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py", line 990, in _prepare_inputs
generation_batch = self._generate_and_score_completions(generation_batch)
File "/miniconda/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py", line 1174, in _generate_and_score_completions
prompt_completion_ids = unwrapped_model.generate(
File "/miniconda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2625, in generate
result = self._sample(
File "/miniconda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3606, in _sample
outputs = self(**model_inputs, return_dict=True)
File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
return model_forward(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/miniconda/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
return func(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
output = func(self, *args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 570, in forward
outputs: BaseModelOutputWithPast = self.model(
File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
output = func(self, *args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 458, in forward
layer_outputs = decoder_layer(
File "/miniconda/lib/python3.10/site-packages/transformers/modeling_layers.py", line 83, in __call__
return super().__call__(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 262, in forward
hidden_states, self_attn_weights = self.self_attn(
File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 217, in forward
attn_output, attn_weights = attention_interface(
File "/miniconda/lib/python3.10/site-packages/transformers/integrations/flash_paged.py", line 47, in paged_attention_forward
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
AttributeError: 'NoneType' object has no attribute 'update'
(The same traceback is repeated with a [rank0]: prefix; omitted here since it is identical.)
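From the failing frame, `paged_attention_forward` calls `cache.update(...)`, but in this code path no paged cache was ever created, so `cache` is `None`. A minimal sketch of that failure mode (the class and function names below are illustrative stand-ins, not the real transformers internals):

```python
# Illustrative stand-ins for the real internals; only the None-cache behavior
# is meant to match what the traceback shows.
class PagedCacheSketch:
    """Toy paged KV cache: stores (k, v) per layer and returns them."""

    def __init__(self):
        self.store = {}

    def update(self, k, v, layer_idx, **kwargs):
        self.store.setdefault(layer_idx, []).append((k, v))
        return k, v


def paged_attention_forward_sketch(k, v, layer_idx, cache=None):
    # When model.generate() runs without a paged cache having been set up,
    # `cache` is None and this line raises exactly the AttributeError above.
    k, v = cache.update(k, v, layer_idx)
    return k, v
```

So this looks less like a kernel problem and more like the standard generate() path never initializing the paged cache that `paged_attention_forward` expects.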

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

train_grpo.py

import argparse
import os
import sys

import torch
import torch.distributed as dist
import wandb

sys.path.append("/mnt/task_runtime/")
from datasets import load_dataset

try:
    from trl import GRPOConfig, GRPOTrainer
except ImportError:
    from trl.trl import GRPOConfig, GRPOTrainer


def get_world_size():
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1  # fallback for non-distributed runs


def sanitize_for_wandb(name):
    return name.replace("/", "-").replace(":", "-")

def get_args():
    parser = argparse.ArgumentParser(description="Train with GRPO")

    parser.add_argument(
        "--model_name",
        type=str,
        default="Qwen/Qwen2-0.5B-Instruct",
        required=True,
        help="Model name or path",
    )
    parser.add_argument(
        "--ckpt_path", type=str, default=None, help="Optional checkpoint to resume from"
    )
    parser.add_argument(
        "--rewards",
        type=str,
        default="efficient_reasoning",
        help="Comma-separated reward functions",
    )
    parser.add_argument(
        "--reward_weights",
        type=str,
        default=None,
        help="Comma-separated weights for each reward function",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        default="output_token_space_efficient_reasoning",
        help="Output directory for training artifacts",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="Dataset name (e.g., gsm8k, trl-lib/tldr)",
    )
    parser.add_argument(
        "--dataset_split",
        type=int,
        default=1024,
        help="Subset of training data to use (optional)",
    )
    parser.add_argument(
        "--use_peft",
        action="store_true",
        help="Only train LoRA adapters",
    )
    parser.add_argument(
        "--use_quant",
        action="store_true",
        help="Load base model in 4/8 bits",
    )
    parser.add_argument(
        "--use_fsdp",
        action="store_true",
        help="Use Fully Sharded Data Parallel (FSDP) for distributed training",
    )
    parser.add_argument(
        "--use_attention",
        type=str,
        default="flash_attention_2",
        help="Attention implementation to use (e.g., flash_attention_2, default)",
    )
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=1024,
        help="Maximum context length",
    )
    parser.add_argument(
        "--dataset_difficulty",
        type=str,
        default=None,
        help="Difficulty level of the dataset (e.g., easy, medium, hard)",
    )
    parser.add_argument(
        "--expt_tag",
        type=str,
        default="",
        help="Experiment tag for wandb",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size per device for training",
    )
    parser.add_argument(
        "--use_unsloth",
        action="store_true",
        help="Use Unsloth for fast inference",
    )

    return parser.parse_args()

# init_distributed, get_model, get_grpo_dataset, rank0_print, custom_wrap_policy_qwen,
# and REWARD_FUNC_MAP are project-local helpers (available via the sys.path entry above).
def main():
    args = get_args()
    init_distributed()
    torch.backends.cuda.enable_mem_efficient_sdp(True)

    model = get_model(
        args.model_name,
        use_peft=args.use_peft,
        use_quant=args.use_quant,
        use_attention=args.use_attention,
        ckpt_path=args.ckpt_path,
        use_unsloth=args.use_unsloth,
    )

    train_dataset, val_datasets, eval_datasets = get_grpo_dataset(
        dataset_name=args.dataset_name,
        n_shot=0,  # zero-shot training
    )

    rank0_print(f"Using datasets {train_dataset} and {val_datasets}")

    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        project_name = sanitize_for_wandb(
            f"{args.expt_tag}_GRPO_{'PEFT' if args.use_peft else ''}_{args.model_name}_{args.dataset_name if args.dataset_name else args.dataset_difficulty}"
        )
        wandb.init(
            project=project_name,  # change to your wandb project
            name=args.expt_tag + args.output_dir,  # optional run name
        )

    per_device_train_batch_size = args.per_device_train_batch_size
    num_processes = get_world_size()
    gradient_accumulation_steps = 1

    effective_batch_size = (
        per_device_train_batch_size * num_processes * gradient_accumulation_steps
    )

    for name, param in model.named_parameters():
        if param.requires_grad:
            rank0_print(f"Trainable parameter: {name}, shape: {param.shape}")

    fsdp_config = dict(
        fsdp="full_shard auto_wrap",
        auto_wrap_policy=custom_wrap_policy_qwen,
    )
    training_args = GRPOConfig(
        output_dir=args.output_dir,
        logging_steps=10,
        num_train_epochs=3,
        per_device_train_batch_size=per_device_train_batch_size,
        save_steps=100,
        report_to="wandb",
        eval_steps=1000 if val_datasets else None,
        eval_strategy="steps" if val_datasets else "no",
        num_generations=effective_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_eval_batch_size=per_device_train_batch_size,
        max_completion_length=args.max_context_length,
        fsdp_config=fsdp_config if args.use_fsdp else None,
        # bf16=True,  # Enable bf16 if supported
        # fp16=True,  # Enable fp16 for training
    )

    rewards = [REWARD_FUNC_MAP[reward] for reward in args.rewards.split(",")]
    if args.reward_weights:
        reward_weights = list(map(float, args.reward_weights.split(",")))
        training_args.reward_weights = reward_weights

    # model.gradient_checkpointing_enable()
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=rewards,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_datasets[0] if val_datasets else None,
    )
    trainer.train()


if __name__ == "__main__":
    main()

Expected behavior

Specifying the attention implementation as paged attention should not change anything: training should proceed exactly as it does with flash_attention_2.
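In the meantime I am working around it by validating the requested implementation before building the model. This is just a sketch for my own script; the contents of SUPPORTED_FOR_GENERATE are my assumption about what works with generate() here, not an authoritative transformers list:

```python
# Workaround sketch: fall back to flash_attention_2 when the requested attention
# implementation is one that the generate()-based GRPO path cannot drive.
# SUPPORTED_FOR_GENERATE is my own assumption, not a transformers API.
SUPPORTED_FOR_GENERATE = {"eager", "sdpa", "flash_attention_2"}


def resolve_attention_impl(requested: str) -> str:
    if requested not in SUPPORTED_FOR_GENERATE:
        print(
            f"warning: {requested!r} fails with generate(); "
            "falling back to flash_attention_2"
        )
        return "flash_attention_2"
    return requested
```

In my script I pass `resolve_attention_impl(args.use_attention)` into the model loader instead of the raw flag.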
