Skip to content

Commit fdd8534

Browse files
emlin authored and facebook-github-bot committed
Support ZCH vNext in torchrec sharding pass (#3283)
Summary: Pull Request resolved: #3283 fix torchrec_sharding_pass gaps for inference tw sharding Reviewed By: kausv, jingsh Differential Revision: D80183693 fbshipit-source-id: 62199ce112332138434772cfb4ab9780d2a8014c
1 parent 0a1b97a commit fdd8534

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

torchrec/distributed/planner/shard_estimators.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,9 @@ def estimate(
10281028
if sharder.fused_params
10291029
else KV_CACHING_RATIO
10301030
)
1031+
use_virtual_table: bool = (
1032+
constraints.use_virtual_table if constraints else False
1033+
)
10311034

10321035
# hardcoded as 8 bytes
10331036
# input indices can be of int32, but in TBE they get converted to int64 anyway
@@ -1073,6 +1076,7 @@ def estimate(
10731076
multipass_prefetch_max_pass=mpp_conf.num_passes if mpp_conf else None,
10741077
key_value_params=key_value_params,
10751078
kv_cache_load_factor=kv_cache_load_factor,
1079+
use_virtual_table=use_virtual_table,
10761080
)
10771081
for shard, storage in zip(sharding_option.shards, shard_storages):
10781082
shard.storage = storage
@@ -1143,6 +1147,7 @@ def calculate_shard_storages(
11431147
multipass_prefetch_max_pass: Optional[int] = None,
11441148
key_value_params: Optional[KeyValueParams] = None,
11451149
kv_cache_load_factor: float = KV_CACHING_RATIO,
1150+
use_virtual_table: bool = False,
11461151
) -> List[Storage]:
11471152
"""
11481153
Calculates estimated storage sizes for each sharded tensor, comprised of input,
@@ -1223,11 +1228,17 @@ def calculate_shard_storages(
12231228
is_inference=is_inference,
12241229
)
12251230

1226-
if compute_kernel in {
1227-
EmbeddingComputeKernel.KEY_VALUE.value,
1228-
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
1229-
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
1230-
}:
1231+
if (
1232+
compute_kernel
1233+
in {
1234+
EmbeddingComputeKernel.KEY_VALUE.value,
1235+
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
1236+
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
1237+
}
1238+
or use_virtual_table
1239+
):
1240+
# KVZCH does not have dedicated inference compute kernel, so we use use_virtual_table
1241+
# to set up ddr_specific_sizes
12311242
key_value_params = key_value_params or KeyValueParams(
12321243
max_l1_cache_size=0, l2_cache_size=0
12331244
)

torchrec/distributed/planner/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,7 @@ class ParameterConstraints:
722722
or a gpu device.
723723
key_value_params (Optional[KeyValueParams]): key value params for SSD TBE, either for
724724
SSD or PS.
725+
use_virtual_table (bool): is virtual table enabled for this table.
725726
"""
726727

727728
sharding_types: Optional[List[str]] = None
@@ -741,6 +742,7 @@ class ParameterConstraints:
741742
output_dtype: Optional[DataType] = None
742743
device_group: Optional[str] = None
743744
key_value_params: Optional[KeyValueParams] = None
745+
use_virtual_table: bool = False
744746

745747
def __hash__(self) -> int:
746748
hashable_list = [
@@ -759,6 +761,7 @@ def __hash__(self) -> int:
759761
self.output_dtype,
760762
self.device_group,
761763
self.key_value_params,
764+
self.use_virtual_table,
762765
]
763766

764767
return hash_sha256_to_int(hashable_list)

0 commit comments

Comments (0)