@@ -1028,6 +1028,9 @@ def estimate(
10281028 if sharder .fused_params
10291029 else KV_CACHING_RATIO
10301030 )
1031+ use_virtual_table : bool = (
1032+ constraints .use_virtual_table if constraints else False
1033+ )
10311034
10321035 # hardcoded as 8 bytes
10331036 # input indices can be of int32, but in TBE they get converted to int64 anyway
@@ -1073,6 +1076,7 @@ def estimate(
10731076 multipass_prefetch_max_pass = mpp_conf .num_passes if mpp_conf else None ,
10741077 key_value_params = key_value_params ,
10751078 kv_cache_load_factor = kv_cache_load_factor ,
1079+ use_virtual_table = use_virtual_table ,
10761080 )
10771081 for shard , storage in zip (sharding_option .shards , shard_storages ):
10781082 shard .storage = storage
@@ -1143,6 +1147,7 @@ def calculate_shard_storages(
11431147 multipass_prefetch_max_pass : Optional [int ] = None ,
11441148 key_value_params : Optional [KeyValueParams ] = None ,
11451149 kv_cache_load_factor : float = KV_CACHING_RATIO ,
1150+ use_virtual_table : bool = False ,
11461151) -> List [Storage ]:
11471152 """
11481153 Calculates estimated storage sizes for each sharded tensor, comprised of input,
@@ -1223,11 +1228,17 @@ def calculate_shard_storages(
12231228 is_inference = is_inference ,
12241229 )
12251230
1226- if compute_kernel in {
1227- EmbeddingComputeKernel .KEY_VALUE .value ,
1228- EmbeddingComputeKernel .SSD_VIRTUAL_TABLE .value ,
1229- EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ,
1230- }:
1231+ if (
1232+ compute_kernel
1233+ in {
1234+ EmbeddingComputeKernel .KEY_VALUE .value ,
1235+ EmbeddingComputeKernel .SSD_VIRTUAL_TABLE .value ,
1236+ EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ,
1237+ }
1238+ or use_virtual_table
1239+ ):
1240+ # KVZCH does not have a dedicated inference compute kernel, so we rely on the
1241+ # use_virtual_table flag to set up ddr_specific_sizes
12311242 key_value_params = key_value_params or KeyValueParams (
12321243 max_l1_cache_size = 0 , l2_cache_size = 0
12331244 )
0 commit comments