@@ -350,6 +350,18 @@ def allocation_scope(current_stage: ExecutorMemoryType,
     validate_feature_combination(llm_args, model_engine, llm_args.sampler_type)
 
     if llm_args.sm_disagg_config is not None:
+        if llm_args.cache_transceiver_config is not None:
+            raise ValueError(
+                "SM-level disaggregation is not compatible with disaggregated serving."
+            )
+        if llm_args.parallel_config.world_size > 1:
+            raise NotImplementedError(
+                "SM-level disaggregation is not supported with parallelism.")
+        if scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
+            raise NotImplementedError(
+                "SM-level disaggregation is only supported with guaranteed no evict scheduler policy."
+            )
+
         with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX,
                               RestoreMode.PINNED):
             ctx_llm_args = copy.copy(llm_args)
@@ -366,23 +378,6 @@ def allocation_scope(current_stage: ExecutorMemoryType,
     else:
         ctx_model_engine = None
 
-    if llm_args.sm_disagg_config is not None:
-        with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX,
-                              RestoreMode.PINNED):
-            ctx_backend_config = copy.copy(pytorch_backend_config)
-            ctx_backend_config.use_cuda_graph = False
-            ctx_model_engine = PyTorchModelEngine(
-                model_path=checkpoint_dir,
-                llm_args=llm_args,
-                mapping=mapping,
-                attn_runtime_features=attn_runtime_features,
-                dist=dist,
-                spec_config=spec_config,
-                weight_sharing_model=model_engine.model,
-            )
-    else:
-        ctx_model_engine = None
-
     if has_draft_model_engine:
         with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_DRAFT,
                               RestoreMode.PINNED):
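
For context, here is a minimal, self-contained sketch of the fail-fast guard logic this commit adds ahead of the context-engine construction. `LlmArgsStub`, `ParallelConfig`, and `validate_sm_disagg` are simplified stand-ins introduced only for illustration; they are not the project's real `LlmArgs`, scheduler config, or validation helpers.

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Optional


class CapacitySchedulerPolicy(Enum):
    GUARANTEED_NO_EVICT = auto()
    MAX_UTILIZATION = auto()


@dataclass
class ParallelConfig:
    world_size: int = 1


@dataclass
class LlmArgsStub:
    # Stand-in for the real argument object; only the fields the guards read.
    sm_disagg_config: Optional[Any] = None
    cache_transceiver_config: Optional[Any] = None
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)


def validate_sm_disagg(llm_args: LlmArgsStub,
                       capacity_scheduler_policy: CapacitySchedulerPolicy) -> None:
    """Reject unsupported combinations before any engine is built."""
    if llm_args.sm_disagg_config is None:
        return
    if llm_args.cache_transceiver_config is not None:
        raise ValueError(
            "SM-level disaggregation is not compatible with disaggregated serving.")
    if llm_args.parallel_config.world_size > 1:
        raise NotImplementedError(
            "SM-level disaggregation is not supported with parallelism.")
    if capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
        raise NotImplementedError(
            "SM-level disaggregation is only supported with guaranteed no evict "
            "scheduler policy.")


if __name__ == "__main__":
    # Example: SM-level disaggregation requested together with world_size > 1.
    args = LlmArgsStub(sm_disagg_config=object(),
                       parallel_config=ParallelConfig(world_size=2))
    try:
        validate_sm_disagg(args, CapacitySchedulerPolicy.GUARANTEED_NO_EVICT)
    except NotImplementedError as exc:
        print(f"rejected: {exc}")
```

Running the guards before any `allocation_scope` is entered means an invalid configuration is rejected without allocating the context model engine first.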