@@ -214,9 +214,9 @@ def __init__(self,
214214 self .responses = {}
215215 self .result_wait_queues = {}
216216
217- self .sm_disagg_request_lock = threading .Lock ()
218- self .ctx_request_cv = threading .Condition (self .sm_disagg_request_lock )
219- self .gen_request_cv = threading .Condition (self .sm_disagg_request_lock )
217+ self .sm_disagg_lock = threading .Lock ()
218+ self .ctx_request_cv = threading .Condition (self .sm_disagg_lock )
219+ self .gen_request_cv = threading .Condition (self .sm_disagg_lock )
220220
221221 # kv cache events
222222 self .kv_cache_manager = self .resource_manager .resource_managers .get (
@@ -1536,7 +1536,7 @@ def _executor_loop_sm_disagg_ctx(self, stream):
15361536 self .executor_request_queue .
15371537 get_new_active_requests_queue_latency ())
15381538
1539- with self .sm_disagg_request_lock :
1539+ with self .sm_disagg_lock :
15401540 ctx_requests = get_context_requests (self .active_requests )
15411541 if self .is_shutdown and len (ctx_requests ) == 0 \
15421542 and self .executor_request_queue .get_waiting_queue_size () == 0 :
@@ -1559,7 +1559,8 @@ def _executor_loop_sm_disagg_ctx(self, stream):
15591559
15601560 if scheduled_batch .batch_size > 0 or (
15611561 self .enable_attention_dp and self .dist .tp_size > 1 ):
1562- self .resource_manager .prepare_resources (scheduled_batch )
1562+ with self .sm_disagg_lock :
1563+ self .resource_manager .prepare_resources (scheduled_batch )
15631564
15641565 with torch .cuda .stream (stream ):
15651566 batch_outputs = self ._forward_step (
@@ -1569,23 +1570,25 @@ def _executor_loop_sm_disagg_ctx(self, stream):
15691570 # To avoid long sync time in critical section below
15701571 sample_state .sampler_event .synchronize ()
15711572
1572- with self .sm_disagg_request_lock :
1573+ with self .sm_disagg_lock :
15731574 self ._update_request_states (scheduled_batch )
15741575 self ._update_requests (sample_state ,
15751576 self .resource_manager )
15761577 self ._handle_canceled_requests ()
15771578 finished_requests = self ._handle_responses ()
1578- self .ctx_request_cv .notify ()
15791579
1580- attn_metadata = getattr (self .ctx_model_engine ,
1581- 'attn_metadata' , None )
1582- kv_cache_dtype_byte_size = getattr (
1583- self .ctx_model_engine , 'kv_cache_dtype_byte_size' , None )
1584- self .resource_manager .update_resources (
1585- scheduled_batch , attn_metadata ,
1586- kv_cache_dtype_byte_size )
1587- if self .enable_kv_cache_events :
1588- self ._add_kv_cache_events ()
1580+ attn_metadata = getattr (self .ctx_model_engine ,
1581+ 'attn_metadata' , None )
1582+ kv_cache_dtype_byte_size = getattr (
1583+ self .ctx_model_engine , 'kv_cache_dtype_byte_size' ,
1584+ None )
1585+ self .resource_manager .update_resources (
1586+ scheduled_batch , attn_metadata ,
1587+ kv_cache_dtype_byte_size )
1588+ if self .enable_kv_cache_events :
1589+ self ._add_kv_cache_events ()
1590+
1591+ self .ctx_request_cv .notify ()
15891592
15901593 if self .enable_iter_perf_stats and sample_state is not None :
15911594 iter_stats .iter_counter = self .ctx_model_engine .iter_counter
@@ -1621,7 +1624,7 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream):
16211624 num_new_active_requests = 0 ,
16221625 new_active_requests_queue_latency_ms = 0 )
16231626
1624- with self .sm_disagg_request_lock :
1627+ with self .sm_disagg_lock :
16251628 self ._pad_attention_dp_dummy_request ()
16261629
16271630 gen_requests = get_generation_requests (self .active_requests )
@@ -1641,7 +1644,8 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream):
16411644 self ._pause_requests (scheduled_batch .paused_requests )
16421645
16431646 if scheduled_batch .batch_size > 0 :
1644- self .resource_manager .prepare_resources (scheduled_batch )
1647+ with self .sm_disagg_lock :
1648+ self .resource_manager .prepare_resources (scheduled_batch )
16451649
16461650 # The generation requests that just finished context phase
16471651 # needs to be in front of the batch due to the assumptions
@@ -1662,7 +1666,7 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream):
16621666 self .previous_batch .sample_state .sampler_event .synchronize (
16631667 )
16641668
1665- with self .sm_disagg_request_lock :
1669+ with self .sm_disagg_lock :
16661670 if self .previous_batch is not None :
16671671 self ._update_requests (
16681672 self .previous_batch .sample_state )
@@ -1672,7 +1676,7 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream):
16721676 scheduled_batch , batch_outputs )
16731677 assert sample_state is not None , "Sampling failed"
16741678
1675- with self .sm_disagg_request_lock :
1679+ with self .sm_disagg_lock :
16761680 self ._update_request_states (scheduled_batch )
16771681
16781682 if self .previous_batch is not None :
0 commit comments