File tree Expand file tree Collapse file tree 5 files changed +0
-21
lines changed Expand file tree Collapse file tree 5 files changed +0
-21
lines changed Original file line number Diff line number Diff line change @@ -565,7 +565,6 @@ def _group_tables_per_rank(
565565 ),
566566 _prefetch_and_cached (table ),
567567 table .use_virtual_table if is_inference else None ,
568- table .enable_embedding_update ,
569568 )
570569 # micromanage the order of we traverse the groups to ensure backwards compatibility
571570 if grouping_key not in groups :
@@ -582,7 +581,6 @@ def _group_tables_per_rank(
582581 _ ,
583582 _ ,
584583 use_virtual_table ,
585- enable_embedding_update ,
586584 ) = grouping_key
587585 grouped_tables = groups [grouping_key ]
588586 # remove non-native fused params
@@ -604,7 +602,6 @@ def _group_tables_per_rank(
604602 compute_kernel = compute_kernel_type ,
605603 embedding_tables = grouped_tables ,
606604 fused_params = per_tbe_fused_params ,
607- enable_embedding_update = enable_embedding_update ,
608605 )
609606 )
610607 return grouped_embedding_configs
Original file line number Diff line number Diff line change @@ -251,7 +251,6 @@ class GroupedEmbeddingConfig:
251251 compute_kernel : EmbeddingComputeKernel
252252 embedding_tables : List [ShardedEmbeddingTable ]
253253 fused_params : Optional [Dict [str , Any ]] = None
254- enable_embedding_update : bool = False
255254
256255 def feature_hash_sizes (self ) -> List [int ]:
257256 feature_hash_sizes = []
Original file line number Diff line number Diff line change @@ -223,7 +223,6 @@ def _shard(
223223 total_num_buckets = info .embedding_config .total_num_buckets ,
224224 use_virtual_table = info .embedding_config .use_virtual_table ,
225225 virtual_table_eviction_policy = info .embedding_config .virtual_table_eviction_policy ,
226- enable_embedding_update = info .embedding_config .enable_embedding_update ,
227226 )
228227 )
229228 return tables_per_rank
@@ -279,20 +278,6 @@ def _get_feature_hash_sizes(self) -> List[int]:
279278 feature_hash_sizes .extend (group_config .feature_hash_sizes ())
280279 return feature_hash_sizes
281280
282- def _get_num_writable_features (self ) -> int :
283- return sum (
284- group_config .num_features ()
285- for group_config in self ._grouped_embedding_configs
286- if group_config .enable_embedding_update
287- )
288-
289- def _get_writable_feature_hash_sizes (self ) -> List [int ]:
290- feature_hash_sizes : List [int ] = []
291- for group_config in self ._grouped_embedding_configs :
292- if group_config .enable_embedding_update :
293- feature_hash_sizes .extend (group_config .feature_hash_sizes ())
294- return feature_hash_sizes
295-
296281
297282class RwSparseFeaturesDist (BaseSparseFeaturesDist [KeyedJaggedTensor ]):
298283 """
Original file line number Diff line number Diff line change @@ -370,7 +370,6 @@ class BaseEmbeddingConfig:
370370 total_num_buckets : Optional [int ] = None
371371 use_virtual_table : bool = False
372372 virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
373- enable_embedding_update : bool = False
374373
375374 def get_weight_init_max (self ) -> float :
376375 if self .weight_init_max is None :
Original file line number Diff line number Diff line change @@ -43,7 +43,6 @@ class StableEmbeddingBagConfig:
4343 total_num_buckets : Optional [int ] = None
4444 use_virtual_table : bool = False
4545 virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
46- enable_embedding_update : bool = False
4746 pooling : PoolingType = PoolingType .SUM
4847
4948
You can’t perform that action at this time.
0 commit comments