Merged
116 changes: 61 additions & 55 deletions vllm/v1/core/sched/scheduler.py
@@ -38,6 +38,7 @@
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext

logger = init_logger(__name__)

@@ -259,49 +260,52 @@ def schedule(self) -> SchedulerOutput:
continue

# Schedule newly needed KV blocks for the request.
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)

if new_blocks is not None:
# The request can be scheduled.
break

# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
with record_function_or_nullcontext("schedule: allocate_slots"):
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[preempted_req.request_id]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
req_index -= 1
else:
preempted_req = self.running.pop()

self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
)
if new_blocks is not None:
# The request can be scheduled.
break

self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
break
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[
preempted_req.request_id
]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
req_index -= 1
else:
preempted_req = self.running.pop()

self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
)

self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
break

if new_blocks is None:
# Cannot schedule this request.
Expand Down Expand Up @@ -599,13 +603,14 @@ def schedule(self) -> SchedulerOutput:
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id
)
)
)

# Construct the scheduler output.
new_reqs_data = [
@@ -614,13 +619,14 @@ def schedule(self) -> SchedulerOutput:
)
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)

# Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear()
Expand Down Expand Up @@ -649,8 +655,8 @@ def schedule(self) -> SchedulerOutput:
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta

self._update_after_schedule(scheduler_output)
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
return scheduler_output

def _update_after_schedule(
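Note: all of the spans above use `record_function_or_nullcontext` imported from `vllm/v1/utils.py`. Below is a minimal sketch of what such a helper can look like, assuming it wraps `torch.profiler.record_function` behind an opt-in flag; the flag name and exact implementation are assumptions, not part of this diff.

```python
# Hypothetical sketch of record_function_or_nullcontext, for context only;
# the real helper lives in vllm/v1/utils.py and may be gated differently.
import contextlib
import os

import torch


def record_function_or_nullcontext(name: str) -> contextlib.AbstractContextManager:
    """Wrap a region in torch.profiler.record_function(name) when custom
    profiling scopes are enabled, otherwise return a no-op context."""
    # Assumed opt-in flag; the name of vLLM's actual gate may differ.
    if os.environ.get("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0") == "1":
        return torch.profiler.record_function(name)
    return contextlib.nullcontext()
```

With the flag unset the wrapper is just `contextlib.nullcontext()`, so wrapping the scheduler hot path unconditionally should add negligible overhead.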
117 changes: 73 additions & 44 deletions vllm/v1/engine/core.py
@@ -61,6 +61,7 @@
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -315,17 +316,21 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)

engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("core step: schedule"):
scheduler_output = self.scheduler.schedule()

with record_function_or_nullcontext("core step: execute_model"):
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)

with record_function_or_nullcontext("core step: update_from_output"):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)

return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

@@ -363,32 +368,49 @@ def step_with_batch_queue(
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
scheduler_output = self.scheduler.schedule()
with record_function_or_nullcontext(
"core step_with_batch_queue: execute_model"
):
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0

if scheduler_output.pending_structured_output_tokens:
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
with record_function_or_nullcontext(
"core step_with_batch_queue: pending_structured_output_tokens"
):
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return
# (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with record_function_or_nullcontext(
"core step_with_batch_queue: get_grammar_bitmask"
):
# We aren't waiting for any tokens, get any grammar
# output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(
scheduler_output
)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()

if exec_result is None:
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
with record_function_or_nullcontext(
"core step_with_batch_queue: sample_tokens"
):
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)
@@ -408,27 +430,34 @@
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False

# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()

engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()
with record_function_or_nullcontext(
"core step_with_batch_queue: update_from_output"
):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)

# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
with record_function_or_nullcontext(
"core step_with_batch_queue: deferred_scheduler_output"
):
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
batch_queue.appendleft((future, deferred_scheduler_output))

return engine_core_outputs, model_executed

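For context, assuming the helper resolves to `torch.profiler.record_function` when enabled, labels such as `core step: schedule` show up as named CPU ranges in an ordinary `torch.profiler` trace. The snippet below is illustrative only and does not import vLLM.

```python
# Illustrative only: shows how a record_function range named like the
# spans above ("core step: schedule") appears in a torch.profiler trace.
import torch
from torch.profiler import ProfilerActivity, profile

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.profiler.record_function("core step: schedule"):
        # Stand-in workload; in vLLM this would be the scheduler call.
        sum(i * i for i in range(100_000))

# The named range can be inspected in the summary table or in a Chrome trace.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
prof.export_chrome_trace("engine_step_trace.json")
```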
37 changes: 21 additions & 16 deletions vllm/v1/engine/llm_engine.py
@@ -36,6 +36,7 @@
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)
@@ -280,28 +281,32 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]:
return []

# 1) Get EngineCoreOutput from the EngineCore.
outputs = self.engine_core.get_output()
with record_function_or_nullcontext("llm_genine step: get_output"):
outputs = self.engine_core.get_output()

# 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
with record_function_or_nullcontext("llm_genine step: process_outputs"):
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
with record_function_or_nullcontext("llm_genine step: abort_requests"):
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

# 4) Record stats
if self.logger_manager is not None and outputs.scheduler_stats is not None:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()
with record_function_or_nullcontext("llm_genine step: record_stats"):
Collaborator: Typo here and above

Contributor Author: fixed in #28584

if self.logger_manager is not None and outputs.scheduler_stats is not None:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()

return processed_outputs.request_outputs

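To see the instrumented `step()` path end to end, a loop like the one below can drive the engine. This is a generic sketch against the public `LLMEngine` surface (`from_engine_args`, `add_request`, `has_unfinished_requests`, `step`); the model name and argument spellings are assumptions rather than anything tested in this PR.

```python
# Hedged usage sketch: drive LLMEngine.step() in a loop so the
# "llm_genine step: ..." ranges above are entered once per iteration.
# Model name and argument spellings are assumptions for illustration.
from vllm import EngineArgs, SamplingParams
from vllm.v1.engine.llm_engine import LLMEngine

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("req-0", "Hello, my name is", SamplingParams(max_tokens=16))

while engine.has_unfinished_requests():
    # Each call runs: get_output -> process_outputs -> abort_requests -> record stats.
    for output in engine.step():
        if output.finished:
            print(output.outputs[0].text)
```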