|
10 | 10 | from http import HTTPStatus |
11 | 11 | from typing import Any, ClassVar, Generic, TypeAlias, TypeVar |
12 | 12 |
|
| 13 | +import numpy as np |
13 | 14 | import torch |
14 | 15 | from fastapi import Request |
15 | 16 | from pydantic import BaseModel, ConfigDict, Field, TypeAdapter |
@@ -389,8 +390,9 @@ async def beam_search( |
389 | 390 |
|
390 | 391 | sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) |
391 | 392 |
|
| 393 | + logprobs_num = 2 * beam_width |
392 | 394 | beam_search_params = SamplingParams( |
393 | | - logprobs=2 * beam_width, |
| 395 | + logprobs=logprobs_num, |
394 | 396 | max_tokens=1, |
395 | 397 | temperature=temperature, |
396 | 398 | ) |
@@ -443,40 +445,75 @@ async def beam_search( |
443 | 445 | output = [x[0] for x in await asyncio.gather(*tasks)] |
444 | 446 |
|
445 | 447 | new_beams = [] |
446 | | - for i, current_beam in enumerate(all_beams): |
447 | | - result = output[i] |
448 | | - |
| 448 | + # Store all new token IDs generated by each beam |
| 449 | + all_beams_token_id = [] |
| 450 | + # Store the cumulative log probability of all tokens |
| 451 | + # generated by beam search |
| 452 | + all_beams_logprob = [] |
| 453 | + # Iterate through all beam inference results |
| 454 | + for i, result in enumerate(output): |
| 455 | + current_beam = all_beams[i] |
449 | 456 | if result.outputs[0].logprobs is not None: |
450 | 457 | logprobs = result.outputs[0].logprobs[0] |
451 | | - for token_id, logprob_obj in logprobs.items(): |
452 | | - if token_id == eos_token_id and not ignore_eos: |
453 | | - completed.append( |
454 | | - BeamSearchSequence( |
455 | | - tokens=current_beam.tokens + [token_id] |
456 | | - if include_stop_str_in_output |
457 | | - else current_beam.tokens, |
458 | | - logprobs=current_beam.logprobs + [logprobs], |
459 | | - cum_logprob=current_beam.cum_logprob |
460 | | - + logprob_obj.logprob, |
461 | | - finish_reason="stop", |
462 | | - stop_reason=eos_token_id, |
463 | | - ) |
464 | | - ) |
465 | | - else: |
466 | | - new_beams.append( |
467 | | - BeamSearchSequence( |
468 | | - tokens=current_beam.tokens + [token_id], |
469 | | - logprobs=current_beam.logprobs + [logprobs], |
470 | | - lora_request=current_beam.lora_request, |
471 | | - cum_logprob=current_beam.cum_logprob |
472 | | - + logprob_obj.logprob, |
473 | | - multi_modal_data=current_beam.multi_modal_data, |
474 | | - mm_processor_kwargs=current_beam.mm_processor_kwargs, |
475 | | - ) |
476 | | - ) |
477 | | - |
478 | | - sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) |
479 | | - all_beams = sorted_beams[:beam_width] |
| 458 | + all_beams_token_id.extend(list(logprobs.keys())) |
| 459 | + all_beams_logprob.extend( |
| 460 | + [ |
| 461 | + current_beam.cum_logprob + obj.logprob |
| 462 | + for obj in logprobs.values() |
| 463 | + ] |
| 464 | + ) |
| 465 | + |
| 466 | + # Handle the end-of-sentence (EOS) token |
| 467 | + all_beams_token_id = np.array(all_beams_token_id) |
| 468 | + all_beams_logprob = np.array(all_beams_logprob) |
| 469 | + |
| 470 | + if not ignore_eos: |
| 471 | + # Get the index position of eos token in all generated results |
| 472 | + eos_idx = np.where(all_beams_token_id == eos_token_id)[0] |
| 473 | + for idx in eos_idx: |
| 474 | + current_beam = all_beams[idx // logprobs_num] |
| 475 | + result = output[idx // logprobs_num] |
| 476 | + assert result.outputs[0].logprobs is not None |
| 477 | + logprobs_entry = result.outputs[0].logprobs[0] |
| 478 | + completed.append( |
| 479 | + BeamSearchSequence( |
| 480 | + tokens=current_beam.tokens + [eos_token_id] |
| 481 | + if include_stop_str_in_output |
| 482 | + else current_beam.tokens, |
| 483 | + logprobs=current_beam.logprobs + [logprobs_entry], |
| 484 | + cum_logprob=float(all_beams_logprob[idx]), |
| 485 | + finish_reason="stop", |
| 486 | + stop_reason=eos_token_id, |
| 487 | + ) |
| 488 | + ) |
| 489 | + # After processing, set the log probability of the EOS entries |
| 490 | + # to negative infinity so they are excluded from the |
| 491 | + # top-k selection below. |
| 491 | + all_beams_logprob[eos_idx] = -np.inf |
| 492 | + |
| 493 | + # Process non-EOS tokens |
| 494 | + # Get indices of the top beam_width probabilities |
| 495 | + topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[ |
| 496 | + :beam_width |
| 497 | + ] |
| 498 | + |
| 499 | + for idx in topn_idx: |
| 500 | + current_beam = all_beams[idx // logprobs_num] |
| 501 | + result = output[idx // logprobs_num] |
| 502 | + token_id = int(all_beams_token_id[idx]) |
| 503 | + assert result.outputs[0].logprobs is not None |
| 504 | + logprobs_entry = result.outputs[0].logprobs[0] |
| 505 | + new_beams.append( |
| 506 | + BeamSearchSequence( |
| 507 | + tokens=current_beam.tokens + [token_id], |
| 508 | + logprobs=current_beam.logprobs + [logprobs_entry], |
| 509 | + lora_request=current_beam.lora_request, |
| 510 | + cum_logprob=float(all_beams_logprob[idx]), |
| 511 | + multi_modal_data=current_beam.multi_modal_data, |
| 512 | + mm_processor_kwargs=current_beam.mm_processor_kwargs, |
| 513 | + ) |
| 514 | + ) |
| 515 | + |
| 516 | + all_beams = new_beams |
480 | 517 |
|
481 | 518 | completed.extend(all_beams) |
482 | 519 | sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) |
|
0 commit comments