@@ -1197,16 +1197,9 @@ def calculate_shard_storages(
     hbm_storage: int = tensor_storage.get("hbm", 0)
     ddr_storage: int = tensor_storage.get("ddr", 0)

-    table_cached: bool = False
-    if compute_kernel in {
-        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
-        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
-        EmbeddingComputeKernel.KEY_VALUE.value,
-        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
-        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
-    }:
+    table_cached = _is_table_cached(compute_kernel)
+    if table_cached:
         hbm_storage = round(ddr_storage * caching_ratio)
-        table_cached = True

     optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0]

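For kernels that cache in HBM, the table's HBM footprint is not its full size but the cached slice of its DDR footprint, `round(ddr_storage * caching_ratio)`. A quick numeric sketch of that override (the names mirror the diff; the values are made up):

    ddr_storage = 4 * 1024**3  # hypothetical: 4 GiB of the table resident in DDR
    caching_ratio = 0.25       # hypothetical: cache a quarter of the rows in HBM

    # For a cached kernel, HBM cost is just the cached slice:
    hbm_storage = round(ddr_storage * caching_ratio)
    assert hbm_storage == 1024**3  # 1 GiB of HBM charged for the cache
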
@@ -1304,6 +1297,20 @@ def calculate_shard_storages(
     ]


+def _is_table_cached(
+    compute_kernel: str,
+) -> bool:
+    if compute_kernel in {
+        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
+        EmbeddingComputeKernel.KEY_VALUE.value,
+        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
+        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
+    }:
+        return True
+    return False
+
+
 def _calculate_shard_io_sizes(
     sharding_type: str,
     batch_sizes: List[int],
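`_is_table_cached` centralizes the kernel check that the first hunk deletes inline. A sketch of the expected behavior, assuming torchrec's `EmbeddingComputeKernel` enum (the import path below is my assumption, and `FUSED` is a plain HBM kernel that should not count as cached):

    from torchrec.distributed.embedding_types import EmbeddingComputeKernel

    assert _is_table_cached(EmbeddingComputeKernel.FUSED_UVM_CACHING.value)
    assert _is_table_cached(EmbeddingComputeKernel.KEY_VALUE.value)
    assert not _is_table_cached(EmbeddingComputeKernel.FUSED.value)

The `if ...: return True / return False` body could also collapse to a single `return compute_kernel in {...}`.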
@@ -1565,27 +1572,20 @@ def _calculate_storage_specific_sizes(
     is_inference: bool = False,
     clf: Optional[float] = None,
 ) -> List[int]:
-    tensor_sizes: List[int] = [
-        (
-            math.ceil(storage * prod(size) / prod(shape))
-            if sharding_type != ShardingType.DATA_PARALLEL.value
-            else storage
-        )
-        for size in shard_sizes
-    ]
-    optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape)
-
-    optimizer_sizes: List[int] = [
-        math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes
-    ]
-
-    # If a table has turned on UVM caching (meaning clf is not None), there'll be
-    # 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
-    # cache aux state (note that this is not the cache content itself)
-    cache_aux_state_sizes: List[int] = (
-        [0] * len(shard_sizes)
-        if clf is None
-        else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+    tensor_sizes: List[int] = _calculate_tensor_sizes(
+        storage,
+        shape,
+        shard_sizes,
+        sharding_type,
+    )
+    optimizer_sizes = _calculate_optimizer_sizes(
+        tensor_sizes,
+        optimizer_class,
+        shape,
+    )
+    cache_aux_state_sizes: List[int] = _calculate_cache_aux_state_sizes(
+        shard_sizes,
+        clf,
     )

     return [
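The body of `_calculate_storage_specific_sizes` is now a straight composition of three helpers, one per storage term. The proportional-scaling rule moved into `_calculate_tensor_sizes` (next hunk) charges each shard the fraction of the full table's bytes that its element count represents, except under data parallelism, where every rank holds the whole table. A worked sketch of that rule under those assumptions (numbers are hypothetical):

    import math
    from math import prod

    storage = 512 * 1024**2             # full table: 512 MiB
    shape = (1_000_000, 128)            # full table shape
    shard_sizes = [[250_000, 128]] * 4  # row-wise sharded into 4 equal shards

    per_shard = [math.ceil(storage * prod(size) / prod(shape)) for size in shard_sizes]
    assert per_shard == [storage // 4] * 4  # each shard carries a quarter of the bytes
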
@@ -1600,6 +1600,45 @@ def _calculate_storage_specific_sizes(
     ]


+def _calculate_tensor_sizes(
+    storage: int, shape: torch.Size, shard_sizes: List[List[int]], sharding_type: str
+) -> List[int]:
+    return [
+        (
+            math.ceil(storage * prod(size) / prod(shape))
+            if sharding_type != ShardingType.DATA_PARALLEL.value
+            else storage
+        )
+        for size in shard_sizes
+    ]
+
+
+# If a table has turned on UVM caching (meaning clf is not None), there'll be
+# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
+# cache aux state (note that this is not the cache content itself)
+def _calculate_cache_aux_state_sizes(
+    shard_sizes: List[List[int]], clf: Optional[float]
+) -> List[int]:
+    if clf is None:
+        return [0] * len(shard_sizes)
+    return [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+
+
+def _calculate_optimizer_sizes(
+    tensor_sizes: List[int],
+    optimizer_class: Optional[Type[torch.optim.Optimizer]],
+    sharding_tensor_shape: torch.Size,
+) -> List[int]:
+    optimizer_multiplier: float = _get_optimizer_multipler(
+        optimizer_class,
+        sharding_tensor_shape,
+    )
+    optimizer_sizes: List[int] = [
+        math.ceil(tensor_size * optimizer_multiplier) for tensor_size in tensor_sizes
+    ]
+    return optimizer_sizes
+
+
 def _get_optimizer_multipler(
     optimizer_class: Optional[Type[torch.optim.Optimizer]],
     shape: torch.Size,
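`_calculate_cache_aux_state_sizes` keeps the original comment's accounting: when UVM caching is on (`clf`, the cache load factor, is not None), each shard pays 4 bytes per hashed row plus 16 bytes per cache slot, with `clf * rows` slots, i.e. `rows * (4 + clf * 16)`. A hypothetical check of that arithmetic (`num_rows` and `aux_bytes` are illustrative names):

    import math

    num_rows, clf = 1_000_000, 0.25
    aux_bytes = math.ceil(num_rows * (4 + clf * 16))
    # 4 MB of hash state + 16 B * 250_000 cache slots = 4 MB of slot state
    assert aux_bytes == 4_000_000 + 4_000_000

`_calculate_optimizer_sizes` likewise just scales each tensor size by the multiplier from the existing `_get_optimizer_multipler` (whose misspelled name is unchanged here), so the decomposition should be behavior-preserving.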