Commit 6e6c1e4

zhangguozhu and mgoin authored and committed
[Frontend] Optimize beam search loop by sorting and then splicing (vllm-project#19347)
Signed-off-by: zhangguozhu <zhangguozhu@360.cn>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: zhangguozhu <zhangguozhu@360.cn>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
1 parent 5c21e2e commit 6e6c1e4

File tree

1 file changed: +70 -33 lines changed

vllm/entrypoints/openai/serving_engine.py

Lines changed: 70 additions & 33 deletions
@@ -10,6 +10,7 @@
 from http import HTTPStatus
 from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
 
+import numpy as np
 import torch
 from fastapi import Request
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
@@ -389,8 +390,9 @@ async def beam_search(
 
         sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
 
+        logprobs_num = 2 * beam_width
         beam_search_params = SamplingParams(
-            logprobs=2 * beam_width,
+            logprobs=logprobs_num,
             max_tokens=1,
             temperature=temperature,
         )
@@ -443,40 +445,75 @@ async def beam_search(
             output = [x[0] for x in await asyncio.gather(*tasks)]
 
             new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
+            # Store all new tokens generated by beam
+            all_beams_token_id = []
+            # Store the cumulative probability of all tokens
+            # generated by beam search
+            all_beams_logprob = []
+            # Iterate through all beam inference results
+            for i, result in enumerate(output):
+                current_beam = all_beams[i]
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        if token_id == eos_token_id and not ignore_eos:
-                            completed.append(
-                                BeamSearchSequence(
-                                    tokens=current_beam.tokens + [token_id]
-                                    if include_stop_str_in_output
-                                    else current_beam.tokens,
-                                    logprobs=current_beam.logprobs + [logprobs],
-                                    cum_logprob=current_beam.cum_logprob
-                                    + logprob_obj.logprob,
-                                    finish_reason="stop",
-                                    stop_reason=eos_token_id,
-                                )
-                            )
-                        else:
-                            new_beams.append(
-                                BeamSearchSequence(
-                                    tokens=current_beam.tokens + [token_id],
-                                    logprobs=current_beam.logprobs + [logprobs],
-                                    lora_request=current_beam.lora_request,
-                                    cum_logprob=current_beam.cum_logprob
-                                    + logprob_obj.logprob,
-                                    multi_modal_data=current_beam.multi_modal_data,
-                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
-                                )
-                            )
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
+                    all_beams_token_id.extend(list(logprobs.keys()))
+                    all_beams_logprob.extend(
+                        [
+                            current_beam.cum_logprob + obj.logprob
+                            for obj in logprobs.values()
+                        ]
+                    )
+
+            # Handle the token for the end of sentence (EOS)
+            all_beams_token_id = np.array(all_beams_token_id)
+            all_beams_logprob = np.array(all_beams_logprob)
+
+            if not ignore_eos:
+                # Get the index position of eos token in all generated results
+                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
+                for idx in eos_idx:
+                    current_beam = all_beams[idx // logprobs_num]
+                    result = output[idx // logprobs_num]
+                    assert result.outputs[0].logprobs is not None
+                    logprobs_entry = result.outputs[0].logprobs[0]
+                    completed.append(
+                        BeamSearchSequence(
+                            tokens=current_beam.tokens + [eos_token_id]
+                            if include_stop_str_in_output
+                            else current_beam.tokens,
+                            logprobs=current_beam.logprobs + [logprobs_entry],
+                            cum_logprob=float(all_beams_logprob[idx]),
+                            finish_reason="stop",
+                            stop_reason=eos_token_id,
+                        )
+                    )
+                # After processing, set the log probability of the eos condition
+                # to negative infinity.
+                all_beams_logprob[eos_idx] = -np.inf
+
+            # Processing non-EOS tokens
+            # Get indices of the top beam_width probabilities
+            topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
+                :beam_width
+            ]
+
+            for idx in topn_idx:
+                current_beam = all_beams[idx // logprobs_num]
+                result = output[idx // logprobs_num]
+                token_id = int(all_beams_token_id[idx])
+                assert result.outputs[0].logprobs is not None
+                logprobs_entry = result.outputs[0].logprobs[0]
+                new_beams.append(
+                    BeamSearchSequence(
+                        tokens=current_beam.tokens + [token_id],
+                        logprobs=current_beam.logprobs + [logprobs_entry],
+                        lora_request=current_beam.lora_request,
+                        cum_logprob=float(all_beams_logprob[idx]),
+                        multi_modal_data=current_beam.multi_modal_data,
+                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
+                    )
+                )
+
+            all_beams = new_beams
 
         completed.extend(all_beams)
         sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
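
The speedup comes from replacing the per-token Python loop plus a full sort of every candidate with flat NumPy arrays: EOS candidates are recorded as completed and masked to -inf, and np.argpartition picks the beam_width best survivors without ordering the rest. Below is a minimal standalone sketch of that selection step; the candidate arrays, values, and EOS_TOKEN_ID are made-up stand-ins for what the server builds from the engine output, not vLLM APIs.

import numpy as np

beam_width = 2
EOS_TOKEN_ID = 2  # hypothetical eos id for this sketch

# Flattened candidates, one entry per (beam, proposed token) pair, in the
# same order the diff builds them with extend().
cand_token_ids = np.array([5, 2, 7, 2, 9, 2])
cand_cum_logprobs = np.array([-1.2, -0.3, -2.5, -0.9, -1.8, -4.0])

# EOS candidates become completed sequences; mask them so they cannot be
# selected again as live beams.
eos_idx = np.where(cand_token_ids == EOS_TOKEN_ID)[0]
cand_cum_logprobs[eos_idx] = -np.inf

# argpartition places the beam_width smallest negated log-probs (i.e. the
# largest log-probs) in the first beam_width slots without a full sort.
top_idx = np.argpartition(np.negative(cand_cum_logprobs), beam_width)[:beam_width]
print(sorted(int(i) for i in top_idx))  # [0, 4] -> tokens 5 and 9 survive

The surviving indices map back to their source beam via idx // logprobs_num, as in the diff, and the length-penalty-based sort_beams_key is still applied once to the completed set on the final context line of the hunk.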
