From 2358d7fb5ab2bee090e79e85d3589dec7c0165cc Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Mon, 10 Nov 2025 10:16:11 +0100 Subject: [PATCH] fix: type annotations on ResourceManager.get_resource_manager Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index b727dcb45ff..d16887deb5c 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1290,18 +1290,20 @@ def _free_blocks(self, block_list: list): class ResourceManager: - def __init__(self, resource_managers: dict[str, BaseResourceManager]): + def __init__(self, resource_managers: dict[ResourceManagerType, + BaseResourceManager]): self.resource_managers = OrderedDict(resource_managers) - def __call__(self, name: str): - return self.resource_managers[name] + def __call__(self, type: ResourceManagerType): + return self.resource_managers[type] - def register_resource_manager(self, name: str, + def register_resource_manager(self, type: ResourceManagerType, resource_manager: BaseResourceManager): - self.resource_managers[name] = resource_manager + self.resource_managers[type] = resource_manager - def get_resource_manager(self, name: str) -> BaseResourceManager: - return self.resource_managers.get(name) + def get_resource_manager( + self, type: ResourceManagerType) -> Optional[BaseResourceManager]: + return self.resource_managers.get(type) @nvtx_range("prepare_resources") def prepare_resources(self, scheduled_batch: ScheduledRequests): @@ -1312,8 +1314,8 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): @nvtx_range("update_resources") def update_resources(self, scheduled_batch: ScheduledRequests, - attn_metadata: "AttentionMetadata" = None, - kv_cache_dtype_byte_size: float = None): + attn_metadata: Optional["AttentionMetadata"] = None, + kv_cache_dtype_byte_size: Optional[float] = None): for _, resource_manager in self.resource_managers.items(): if hasattr(resource_manager, "update_resources"): if isinstance(resource_manager, KVCacheManager): @@ -1328,7 +1330,8 @@ def free_resources(self, request: LlmRequest): if hasattr(resource_manager, "free_resources"): resource_manager.free_resources(request) - def reorder_pipeline(self, resource_manager_list: list[str]): + def reorder_pipeline(self, + resource_manager_list: list[ResourceManagerType]): assert set(resource_manager_list) == set(self.resource_managers.keys()) for resource_manager in resource_manager_list: self.resource_managers.move_to_end(resource_manager)