Skip to content

Commit edc91ba

Browse files
authored
[None][fix] Improve type annotations on ResourceManager.get_resource_manager (#9013)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent 2e7769d commit edc91ba

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,18 +1290,20 @@ def _free_blocks(self, block_list: list):
12901290

12911291
class ResourceManager:
12921292

1293-
def __init__(self, resource_managers: dict[str, BaseResourceManager]):
1293+
def __init__(self, resource_managers: dict[ResourceManagerType,
1294+
BaseResourceManager]):
12941295
self.resource_managers = OrderedDict(resource_managers)
12951296

1296-
def __call__(self, name: str):
1297-
return self.resource_managers[name]
1297+
def __call__(self, type: ResourceManagerType):
1298+
return self.resource_managers[type]
12981299

1299-
def register_resource_manager(self, name: str,
1300+
def register_resource_manager(self, type: ResourceManagerType,
13001301
resource_manager: BaseResourceManager):
1301-
self.resource_managers[name] = resource_manager
1302+
self.resource_managers[type] = resource_manager
13021303

1303-
def get_resource_manager(self, name: str) -> BaseResourceManager:
1304-
return self.resource_managers.get(name)
1304+
def get_resource_manager(
1305+
self, type: ResourceManagerType) -> Optional[BaseResourceManager]:
1306+
return self.resource_managers.get(type)
13051307

13061308
@nvtx_range("prepare_resources")
13071309
def prepare_resources(self, scheduled_batch: ScheduledRequests):
@@ -1312,8 +1314,8 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
13121314
@nvtx_range("update_resources")
13131315
def update_resources(self,
13141316
scheduled_batch: ScheduledRequests,
1315-
attn_metadata: "AttentionMetadata" = None,
1316-
kv_cache_dtype_byte_size: float = None):
1317+
attn_metadata: Optional["AttentionMetadata"] = None,
1318+
kv_cache_dtype_byte_size: Optional[float] = None):
13171319
for _, resource_manager in self.resource_managers.items():
13181320
if hasattr(resource_manager, "update_resources"):
13191321
if isinstance(resource_manager, KVCacheManager):
@@ -1328,7 +1330,8 @@ def free_resources(self, request: LlmRequest):
13281330
if hasattr(resource_manager, "free_resources"):
13291331
resource_manager.free_resources(request)
13301332

1331-
def reorder_pipeline(self, resource_manager_list: list[str]):
1333+
def reorder_pipeline(self,
1334+
resource_manager_list: list[ResourceManagerType]):
13321335
assert set(resource_manager_list) == set(self.resource_managers.keys())
13331336
for resource_manager in resource_manager_list:
13341337
self.resource_managers.move_to_end(resource_manager)

0 commit comments

Comments
 (0)