Commit 27c3409

zhuohan123 authored and xuebwang-amd committed
[Misc] FlattenLogprobs -> FlatLogprobs (vllm-project#28335)
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 541ce7a commit 27c3409

File tree

tests/samplers/test_logprobs.py
tests/test_logprobs.py
vllm/envs.py
vllm/logprobs.py

4 files changed: +43 -47 lines changed

tests/samplers/test_logprobs.py

Lines changed: 6 additions & 10 deletions
@@ -4,7 +4,7 @@
 import pytest
 
 from vllm import SamplingParams
-from vllm.logprobs import FlattenLogprobs
+from vllm.logprobs import FlatLogprobs
 
 MODELS = ["distilbert/distilgpt2"]
 MAX_TOKENS = 5
@@ -16,17 +16,17 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("greedy", [True, False])
-@pytest.mark.parametrize("flatten_logprobs", [True, False])
+@pytest.mark.parametrize("flat_logprobs", [True, False])
 def test_ranks(
     vllm_runner,
     model,
     dtype,
     greedy,
-    flatten_logprobs,
+    flat_logprobs,
     example_prompts,
     monkeypatch: pytest.MonkeyPatch,
 ):
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
     with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
         tokenizer = vllm_model.llm.get_tokenizer()
         example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@@ -44,12 +44,8 @@ def test_ranks(
         decode_tokens, _, decode_logprobs, prompt_logprobs = result
 
         # Ensure the return type of logprobs is accurate
-        assert isinstance(
-            prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
-        )
-        assert isinstance(
-            decode_logprobs, FlattenLogprobs if flatten_logprobs else list
-        )
+        assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list)
+        assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list)
 
         ########################
         # Check prompt logprobs
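
Note: the rewritten asserts make the contract explicit: the container type is selected purely by the environment flag. A minimal consumer-side sketch in Python (hedged: top_token_ids is a hypothetical helper written for illustration, relying only on the sequence protocol visible in this diff):

from vllm.logprobs import SampleLogprobs


def top_token_ids(sample_logprobs: SampleLogprobs) -> list[int]:
    # Hypothetical helper: works whether sample_logprobs is a FlatLogprobs
    # or a plain list, since both yield one {token_id: Logprob} mapping per
    # generated position when iterated.
    return [
        max(position, key=lambda token_id: position[token_id].logprob)
        for position in sample_logprobs
    ]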

tests/test_logprobs.py

Lines changed: 20 additions & 20 deletions
@@ -5,7 +5,7 @@
 import pytest
 
 from vllm.logprobs import (
-    FlattenLogprobs,
+    FlatLogprobs,
     Logprob,
     LogprobsOnePosition,
     append_logprobs_for_next_position,
@@ -14,8 +14,8 @@
 )
 
 
-def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
+def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
 
     prompt_logprobs = create_prompt_logprobs()
     assert isinstance(prompt_logprobs, list)
@@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0
 
 
-def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
+def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
 
     prompt_logprobs = create_prompt_logprobs()
-    assert isinstance(prompt_logprobs, FlattenLogprobs)
+    assert isinstance(prompt_logprobs, FlatLogprobs)
     assert prompt_logprobs.start_indices == [0]
     assert prompt_logprobs.end_indices == [0]
     assert len(prompt_logprobs.token_ids) == 0
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert prompt_logprobs[0] == dict()
 
     sample_logprobs = create_sample_logprobs()
-    assert isinstance(sample_logprobs, FlattenLogprobs)
+    assert isinstance(sample_logprobs, FlatLogprobs)
     assert len(sample_logprobs.start_indices) == 0
     assert len(sample_logprobs.end_indices) == 0
     assert len(sample_logprobs.token_ids) == 0
@@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0
 
 
-def test_append_logprobs_for_next_position_none_flatten(
+def test_append_logprobs_for_next_position_none_flat(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
     logprobs = create_sample_logprobs()
     append_logprobs_for_next_position(
         logprobs,
@@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
     ]
 
 
-def test_append_logprobs_for_next_position_flatten(
+def test_append_logprobs_for_next_position_flat(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
     logprobs = create_sample_logprobs()
     append_logprobs_for_next_position(
         logprobs,
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
         rank=11,
         num_logprobs=-1,
     )
-    assert isinstance(logprobs, FlattenLogprobs)
+    assert isinstance(logprobs, FlatLogprobs)
     assert logprobs.start_indices == [0, 1]
     assert logprobs.end_indices == [1, 3]
     assert logprobs.token_ids == [1, 2, 3]
@@ -129,8 +129,8 @@ def test_append_logprobs_for_next_position_flatten(
     }
 
 
-def test_flatten_logprobs_append() -> None:
-    logprobs = FlattenLogprobs()
+def test_flat_logprobs_append() -> None:
+    logprobs = FlatLogprobs()
     logprobs.append(LOGPROBS_ONE_POSITION_0)
     logprobs.append(LOGPROBS_ONE_POSITION_1)
     assert logprobs.start_indices == [0, 1]
@@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
     assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]
 
 
-def test_flatten_logprobs_extend() -> None:
-    logprobs = FlattenLogprobs()
+def test_flat_logprobs_extend() -> None:
+    logprobs = FlatLogprobs()
     # Extend with list[LogprobsOnePosition]
     logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
     assert logprobs.start_indices == [0, 3]
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
     assert logprobs.ranks == [40, 50, 60, 10]
     assert logprobs.decoded_tokens == ["40", "50", "60", "10"]
 
-    other_logprobs = FlattenLogprobs()
+    other_logprobs = FlatLogprobs()
     other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
-    # Extend with another FlattenLogprobs
+    # Extend with another FlatLogprobs
     logprobs.extend(other_logprobs)
     assert logprobs.start_indices == [0, 3, 4, 6]
     assert logprobs.end_indices == [3, 4, 6, 7]
@@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
     assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]
 
 
-def test_flatten_logprobs_access() -> None:
-    logprobs = FlattenLogprobs()
+def test_flat_logprobs_access() -> None:
+    logprobs = FlatLogprobs()
     logprobs.extend(
         [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
     )
vllm/envs.py

Lines changed: 4 additions & 4 deletions
@@ -223,7 +223,7 @@
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
-    VLLM_FLATTEN_LOGPROBS: bool = False
+    VLLM_FLAT_LOGPROBS: bool = False
 
 
 def get_default_cache_root():
@@ -1481,11 +1481,11 @@ def get_vllm_port() -> int | None:
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
-    # Flag to enable FlattenLogprobs, whose GC overhead is significantly smaller than
+    # Flag to enable FlatLogprobs, whose GC overhead is significantly smaller than
     # the original list[dict[int, Logprob]] approach.
     # When enabled, PromptLogprobs and SampleLogprobs are populated as
-    # FlattenLogprobs.
-    "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
+    # FlatLogprobs.
+    "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
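
The renamed flag keeps the same semantics: it is parsed via bool(int(...)), so "1" opts in and "0" (the default) keeps the legacy list[dict[int, Logprob]] containers. A usage sketch (one assumption on timing: the variable should be set before vLLM reads its environment configuration):

import os

# Opt in to the flat containers; before this commit the variable was
# named VLLM_FLATTEN_LOGPROBS.
os.environ["VLLM_FLAT_LOGPROBS"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="distilbert/distilgpt2")
params = SamplingParams(max_tokens=5, logprobs=2)
outputs = llm.generate(["The capital of France is"], params)
# Each request's logprobs are now backed by FlatLogprobs instead of
# list[dict[int, Logprob]].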

vllm/logprobs.py

Lines changed: 13 additions & 13 deletions
@@ -30,16 +30,16 @@ class Logprob:
 
 
 @dataclass
-class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
+class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
     """
-    Flatten logprobs of a request into multiple primitive type lists.
+    Flattened logprobs of a request, stored in multiple primitive type lists.
 
     Compared to list[dict[int, Logprob]], this data structure reduces GC
     overhead significantly, as it flattens the logprob information for
     all positions and ranks into multiple primitive type lists (i.e.
     logprobs, token_ids, ranks per token_id, decoded_tokens).
     So regardless of the sequence length and top_logprobs setup,
-    FlattenLogprobs would only introduce a constant amount of objects.
+    FlatLogprobs would only introduce a constant number of objects.
 
     As each position might contain a different number of ranks,
     start_indices_per_position would be used to access the logprob ranges
@@ -107,7 +107,7 @@ def __len__(self) -> int:
     def __getitem__(self, position: int) -> LogprobsOnePosition: ...
 
     @overload
-    def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
+    def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...
 
     def __getitem__(self, index: int | slice):
         """Extracts logprobs of a given position or slice"""
@@ -123,7 +123,7 @@ def __getitem__(self, index: int | slice):
         elif isinstance(index, slice):
             min_index = self.start_indices[index][0]
             max_index = self.end_indices[index][-1]
-            return FlattenLogprobs(
+            return FlatLogprobs(
                 # Shift updated start_indices and end_indices to
                 # be 0-indexed
                 start_indices=[i - min_index for i in self.start_indices[index]],
@@ -137,13 +137,13 @@ def __getitem__(self, index: int | slice):
         raise TypeError(f"Invalid index type: {type(index)}")
 
     def __setitem__(self, item, value) -> None:
-        raise TypeError("Cannot set logprobs in FlattenLogprobs")
+        raise TypeError("Cannot set logprobs in FlatLogprobs")
 
     def __delitem__(self, item) -> None:
-        raise TypeError("Cannot delete logprobs from FlattenLogprobs")
+        raise TypeError("Cannot delete logprobs from FlatLogprobs")
 
     def insert(self, item) -> None:
-        raise TypeError("Cannot insert logprobs to FlattenLogprobs")
+        raise TypeError("Cannot insert logprobs to FlatLogprobs")
 
     def __iter__(self) -> Iterator[LogprobsOnePosition]:
         """
@@ -156,22 +156,22 @@ def __iter__(self) -> Iterator[LogprobsOnePosition]:
 
 # {token_id -> logprob} per each sequence group. None if the corresponding
 # sequence group doesn't require prompt logprob.
-PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
+PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
 # {token_id -> logprob} for each sequence group.
-SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
+SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
 
 
 def create_prompt_logprobs() -> PromptLogprobs:
     """Creates a container to store prompt logprobs for a request"""
-    logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
    # NOTE: logprob of first prompt token is None.
     logprobs.append(None)
     return logprobs
 
 
 def create_sample_logprobs() -> SampleLogprobs:
     """Creates a container to store decode logprobs for a request"""
-    return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
@@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
     topk_ranks = range(1, num_logprobs + 1)
     ranks = itertools.chain((rank,), topk_ranks)
 
-    if isinstance(request_logprobs, FlattenLogprobs):
+    if isinstance(request_logprobs, FlatLogprobs):
         request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
     else:
         request_logprobs.append(
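
For reviewers skimming past the rename: the flat layout itself is unchanged. A self-contained sketch of the index bookkeeping (FlatLayoutDemo is illustrative only, not the vLLM class; it reproduces the start_indices/end_indices values asserted in tests/test_logprobs.py above):

from dataclasses import dataclass, field


@dataclass
class FlatLayoutDemo:
    # Position i owns the flat slice [start_indices[i], end_indices[i]).
    start_indices: list[int] = field(default_factory=list)
    end_indices: list[int] = field(default_factory=list)
    token_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)

    def append_position(self, ids: list[int], lps: list[float]) -> None:
        # One index entry per position, however many ranks it carries,
        # so the object count stays constant per position.
        self.start_indices.append(len(self.token_ids))
        self.token_ids.extend(ids)
        self.logprobs.extend(lps)
        self.end_indices.append(len(self.token_ids))

    def position(self, i: int) -> dict[int, float]:
        s, e = self.start_indices[i], self.end_indices[i]
        return dict(zip(self.token_ids[s:e], self.logprobs[s:e]))


demo = FlatLayoutDemo()
demo.append_position([1], [-0.1])           # one rank at position 0
demo.append_position([2, 3], [-0.2, -0.3])  # two ranks at position 1
assert demo.start_indices == [0, 1] and demo.end_indices == [1, 3]
assert demo.token_ids == [1, 2, 3]
assert demo.position(1) == {2: -0.2, 3: -0.3}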
