55import pytest
66
77from vllm .logprobs import (
8- FlattenLogprobs ,
8+ FlatLogprobs ,
99 Logprob ,
1010 LogprobsOnePosition ,
1111 append_logprobs_for_next_position ,
1414)
1515
1616
17- def test_create_logprobs_non_flatten (monkeypatch : pytest .MonkeyPatch ) -> None :
18- monkeypatch .setenv ("VLLM_FLATTEN_LOGPROBS " , "0" )
17+ def test_create_logprobs_non_flat (monkeypatch : pytest .MonkeyPatch ) -> None :
18+ monkeypatch .setenv ("VLLM_FLAT_LOGPROBS " , "0" )
1919
2020 prompt_logprobs = create_prompt_logprobs ()
2121 assert isinstance (prompt_logprobs , list )
@@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
2828 assert len (sample_logprobs ) == 0
2929
3030
31- def test_create_logprobs_flatten (monkeypatch : pytest .MonkeyPatch ) -> None :
32- monkeypatch .setenv ("VLLM_FLATTEN_LOGPROBS " , "1" )
31+ def test_create_logprobs_flat (monkeypatch : pytest .MonkeyPatch ) -> None :
32+ monkeypatch .setenv ("VLLM_FLAT_LOGPROBS " , "1" )
3333
3434 prompt_logprobs = create_prompt_logprobs ()
35- assert isinstance (prompt_logprobs , FlattenLogprobs )
35+ assert isinstance (prompt_logprobs , FlatLogprobs )
3636 assert prompt_logprobs .start_indices == [0 ]
3737 assert prompt_logprobs .end_indices == [0 ]
3838 assert len (prompt_logprobs .token_ids ) == 0
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
4444 assert prompt_logprobs [0 ] == dict ()
4545
4646 sample_logprobs = create_sample_logprobs ()
47- assert isinstance (sample_logprobs , FlattenLogprobs )
47+ assert isinstance (sample_logprobs , FlatLogprobs )
4848 assert len (sample_logprobs .start_indices ) == 0
4949 assert len (sample_logprobs .end_indices ) == 0
5050 assert len (sample_logprobs .token_ids ) == 0
@@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
5454 assert len (sample_logprobs ) == 0
5555
5656
57- def test_append_logprobs_for_next_position_none_flatten (
57+ def test_append_logprobs_for_next_position_none_flat (
5858 monkeypatch : pytest .MonkeyPatch ,
5959) -> None :
60- monkeypatch .setenv ("VLLM_FLATTEN_LOGPROBS " , "0" )
60+ monkeypatch .setenv ("VLLM_FLAT_LOGPROBS " , "0" )
6161 logprobs = create_sample_logprobs ()
6262 append_logprobs_for_next_position (
6363 logprobs ,
@@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
8585 ]
8686
8787
88- def test_append_logprobs_for_next_position_flatten (
88+ def test_append_logprobs_for_next_position_flat (
8989 monkeypatch : pytest .MonkeyPatch ,
9090) -> None :
91- monkeypatch .setenv ("VLLM_FLATTEN_LOGPROBS " , "1" )
91+ monkeypatch .setenv ("VLLM_FLAT_LOGPROBS " , "1" )
9292 logprobs = create_sample_logprobs ()
9393 append_logprobs_for_next_position (
9494 logprobs ,
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
106106 rank = 11 ,
107107 num_logprobs = - 1 ,
108108 )
109- assert isinstance (logprobs , FlattenLogprobs )
109+ assert isinstance (logprobs , FlatLogprobs )
110110 assert logprobs .start_indices == [0 , 1 ]
111111 assert logprobs .end_indices == [1 , 3 ]
112112 assert logprobs .token_ids == [1 , 2 , 3 ]
@@ -129,8 +129,8 @@ def test_append_logprobs_for_next_position_flatten(
129129}
130130
131131
132- def test_flatten_logprobs_append () -> None :
133- logprobs = FlattenLogprobs ()
132+ def test_flat_logprobs_append () -> None :
133+ logprobs = FlatLogprobs ()
134134 logprobs .append (LOGPROBS_ONE_POSITION_0 )
135135 logprobs .append (LOGPROBS_ONE_POSITION_1 )
136136 assert logprobs .start_indices == [0 , 1 ]
@@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
149149 assert logprobs .decoded_tokens == ["10" , "20" , "30" , "40" , "50" , "60" ]
150150
151151
152- def test_flatten_logprobs_extend () -> None :
153- logprobs = FlattenLogprobs ()
152+ def test_flat_logprobs_extend () -> None :
153+ logprobs = FlatLogprobs ()
154154 # Extend with list[LogprobsOnePosition]
155155 logprobs .extend ([LOGPROBS_ONE_POSITION_2 , LOGPROBS_ONE_POSITION_0 ])
156156 assert logprobs .start_indices == [0 , 3 ]
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
160160 assert logprobs .ranks == [40 , 50 , 60 , 10 ]
161161 assert logprobs .decoded_tokens == ["40" , "50" , "60" , "10" ]
162162
163- other_logprobs = FlattenLogprobs ()
163+ other_logprobs = FlatLogprobs ()
164164 other_logprobs .extend ([LOGPROBS_ONE_POSITION_1 , LOGPROBS_ONE_POSITION_0 ])
165- # Extend with another FlattenLogprobs
165+ # Extend with another FlatLogprobs
166166 logprobs .extend (other_logprobs )
167167 assert logprobs .start_indices == [0 , 3 , 4 , 6 ]
168168 assert logprobs .end_indices == [3 , 4 , 6 , 7 ]
@@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
172172 assert logprobs .decoded_tokens == ["40" , "50" , "60" , "10" , "20" , "30" , "10" ]
173173
174174
175- def test_flatten_logprobs_access () -> None :
176- logprobs = FlattenLogprobs ()
175+ def test_flat_logprobs_access () -> None :
176+ logprobs = FlatLogprobs ()
177177 logprobs .extend (
178178 [LOGPROBS_ONE_POSITION_1 , LOGPROBS_ONE_POSITION_2 , LOGPROBS_ONE_POSITION_0 ]
179179 )
0 commit comments