Commit 4a85394

tjtanaa authored and paulpak58 committed

[FEAT] Refactor ROPE into module (vllm-project#22192)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>

1 parent ec7eeff · commit 4a85394

15 files changed: +2111 −1967 lines

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 0 additions & 1967 deletions
This file was deleted.
Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Rotary Positional Embeddings."""
from typing import Any, Optional

import torch

from .base import RotaryEmbedding
from .deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
from .dual_chunk_rope import DualChunkRotaryEmbedding
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from .linear_scaling_rope import LinearScalingRotaryEmbedding
from .llama3_rope import Llama3RotaryEmbedding
from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding
from .ntk_scaling_rope import NTKScalingRotaryEmbedding
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding

_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None

    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dual_chunk_attention_args, dtype)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style, dtype,
                                              **extra_kwargs)
    elif not rope_scaling:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get('mixed_b', None)
            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                                   max_position, base,
                                                   is_neox_style,
                                                   scaling_factor, dtype,
                                                   mixed_b)
        elif scaling_type == "dynamic":
            if "alpha" in rope_scaling:
                scaling_alpha = rope_scaling["alpha"]
                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_alpha, dtype)
            elif "factor" in rope_scaling:
                scaling_factor = rope_scaling["factor"]
                rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor, dtype)
            else:
                raise ValueError("Dynamic rope scaling must contain either "
                                 "'alpha' or 'factor' field")
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
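
For reference, a minimal usage sketch of the get_rope factory above. This is not part of the diff: the import path (the refactored rotary_embedding package) and every numeric value are assumptions, chosen only to exercise the "llama3" branch and the _ROPE_DICT cache.

# Illustrative sketch only; import path and values are assumptions, not from the diff.
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

# Keys mirror the "llama3" branch of get_rope above.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=131072,
    base=500000.0,
    is_neox_style=True,
    rope_scaling=rope_scaling,
    dtype=torch.bfloat16,
)

# A second call with identical arguments hits the _ROPE_DICT cache and
# returns the same RotaryEmbedding instance instead of building a new one.
assert rope is get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=131072,
    base=500000.0,
    is_neox_style=True,
    rope_scaling=rope_scaling,
    dtype=torch.bfloat16,
)

Note that the list-to-tuple conversion at the top of get_rope is what keeps the cache key hashable when a config carries list values (for example the longrope short_factor/long_factor lists).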
