Commit 60532bf

avoid kv cache manager race conditions
Signed-off-by: Qiang Xu <qiangx@nvidia.com>
Parent: bed20f7

File tree: 1 file changed (+24, -20 lines)

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 24 additions & 20 deletions
(columns: original file line number | new file line number | diff change)
@@ -214,9 +214,9 @@ def __init__(self,
214214
self.responses = {}
215215
self.result_wait_queues = {}
216216

217-
self.sm_disagg_request_lock = threading.Lock()
218-
self.ctx_request_cv = threading.Condition(self.sm_disagg_request_lock)
219-
self.gen_request_cv = threading.Condition(self.sm_disagg_request_lock)
217+
self.sm_disagg_lock = threading.Lock()
218+
self.ctx_request_cv = threading.Condition(self.sm_disagg_lock)
219+
self.gen_request_cv = threading.Condition(self.sm_disagg_lock)
220220

221221
# kv cache events
222222
self.kv_cache_manager = self.resource_manager.resource_managers.get(
@@ -1536,7 +1536,7 @@ def _executor_loop_sm_disagg_ctx(self, stream):
15361536
self.executor_request_queue.
15371537
get_new_active_requests_queue_latency())
15381538

1539-
with self.sm_disagg_request_lock:
1539+
with self.sm_disagg_lock:
15401540
ctx_requests = get_context_requests(self.active_requests)
15411541
if self.is_shutdown and len(ctx_requests) == 0 \
15421542
and self.executor_request_queue.get_waiting_queue_size() == 0:
@@ -1559,7 +1559,8 @@ def _executor_loop_sm_disagg_ctx(self, stream):
15591559

15601560
if scheduled_batch.batch_size > 0 or (
15611561
self.enable_attention_dp and self.dist.tp_size > 1):
1562-
self.resource_manager.prepare_resources(scheduled_batch)
1562+
with self.sm_disagg_lock:
1563+
self.resource_manager.prepare_resources(scheduled_batch)
15631564

15641565
with torch.cuda.stream(stream):
15651566
batch_outputs = self._forward_step(
@@ -1569,23 +1570,25 @@ def _executor_loop_sm_disagg_ctx(self, stream):
15691570
# To avoid long sync time in critical section below
15701571
sample_state.sampler_event.synchronize()
15711572

1572-
with self.sm_disagg_request_lock:
1573+
with self.sm_disagg_lock:
15731574
self._update_request_states(scheduled_batch)
15741575
self._update_requests(sample_state,
15751576
self.resource_manager)
15761577
self._handle_canceled_requests()
15771578
finished_requests = self._handle_responses()
1578-
self.ctx_request_cv.notify()
15791579

1580-
attn_metadata = getattr(self.ctx_model_engine,
1581-
'attn_metadata', None)
1582-
kv_cache_dtype_byte_size = getattr(
1583-
self.ctx_model_engine, 'kv_cache_dtype_byte_size', None)
1584-
self.resource_manager.update_resources(
1585-
scheduled_batch, attn_metadata,
1586-
kv_cache_dtype_byte_size)
1587-
if self.enable_kv_cache_events:
1588-
self._add_kv_cache_events()
1580+
attn_metadata = getattr(self.ctx_model_engine,
1581+
'attn_metadata', None)
1582+
kv_cache_dtype_byte_size = getattr(
1583+
self.ctx_model_engine, 'kv_cache_dtype_byte_size',
1584+
None)
1585+
self.resource_manager.update_resources(
1586+
scheduled_batch, attn_metadata,
1587+
kv_cache_dtype_byte_size)
1588+
if self.enable_kv_cache_events:
1589+
self._add_kv_cache_events()
1590+
1591+
self.ctx_request_cv.notify()
15891592

15901593
if self.enable_iter_perf_stats and sample_state is not None:
15911594
iter_stats.iter_counter = self.ctx_model_engine.iter_counter
@@ -1621,7 +1624,7 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream):
16211624
num_new_active_requests=0,
16221625
new_active_requests_queue_latency_ms=0)
16231626

1624-
with self.sm_disagg_request_lock:
1627+
with self.sm_disagg_lock:
16251628
self._pad_attention_dp_dummy_request()
16261629

16271630
gen_requests = get_generation_requests(self.active_requests)
@@ -1641,7 +1644,8 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream):
16411644
self._pause_requests(scheduled_batch.paused_requests)
16421645

16431646
if scheduled_batch.batch_size > 0:
1644-
self.resource_manager.prepare_resources(scheduled_batch)
1647+
with self.sm_disagg_lock:
1648+
self.resource_manager.prepare_resources(scheduled_batch)
16451649

16461650
# The generation requests that just finished context phase
16471651
# needs to be in front of the batch due to the assumptions
@@ -1662,7 +1666,7 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream):
16621666
self.previous_batch.sample_state.sampler_event.synchronize(
16631667
)
16641668

1665-
with self.sm_disagg_request_lock:
1669+
with self.sm_disagg_lock:
16661670
if self.previous_batch is not None:
16671671
self._update_requests(
16681672
self.previous_batch.sample_state)
@@ -1672,7 +1676,7 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream):
16721676
scheduled_batch, batch_outputs)
16731677
assert sample_state is not None, "Sampling failed"
16741678

1675-
with self.sm_disagg_request_lock:
1679+
with self.sm_disagg_lock:
16761680
self._update_request_states(scheduled_batch)
16771681

16781682
if self.previous_batch is not None:

Commit comments: 0