Skip to content

Commit 06a0a79

Browse files
committed
Remove auto stack kwargs
1 parent 37345dd commit 06a0a79

File tree

3 files changed

+0
-44
lines changed

3 files changed

+0
-44
lines changed

examples/ml_perf/main.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,6 @@ def main(
117117
top_mlp_dims=model_cfg.top_mlp_dims,
118118
num_dcn_layers=model_cfg.num_dcn_layers,
119119
dcn_projection_dim=model_cfg.dcn_projection_dim,
120-
auto_stack_kwargs={
121-
"max_ids_per_partition": model_cfg.max_ids_per_partition,
122-
"max_unique_ids_per_partition": (
123-
model_cfg.max_unique_ids_per_partition
124-
),
125-
},
126120
seed=SEED,
127121
dtype="float32",
128122
name="dlrm_dcn_v2",

examples/ml_perf/model.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def __init__(
4949
top_mlp_dims: list[int],
5050
num_dcn_layers: int,
5151
dcn_projection_dim: int,
52-
auto_stack_kwargs: dict[str, Any],
5352
seed: int | keras.random.SeedGenerator | None = None,
5453
dtype: str | None = None,
5554
name: str | None = None,
@@ -116,7 +115,6 @@ def __init__(
116115
self.embedding_layer = keras_rs.layers.DistributedEmbedding(
117116
feature_configs=large_emb_feature_configs,
118117
table_stacking="auto",
119-
auto_stack_kwargs=auto_stack_kwargs,
120118
dtype=dtype,
121119
name="embedding_layer",
122120
)
@@ -173,7 +171,6 @@ def __init__(
173171
self.top_mlp_dims = top_mlp_dims
174172
self.num_dcn_layers = num_dcn_layers
175173
self.dcn_projection_dim = dcn_projection_dim
176-
self.auto_stack_kwargs = auto_stack_kwargs
177174

178175
def call(self, inputs: dict[str, Tensor]) -> Tensor:
179176
"""Forward pass of the model.
@@ -280,7 +277,6 @@ def get_config(self):
280277
"top_mlp_dims": self.top_mlp_dims,
281278
"num_dcn_layers": self.num_dcn_layers,
282279
"dcn_projection_dim": self.dcn_projection_dim,
283-
"auto_stack_kwargs": self.auto_stack_kwargs,
284280
"seed": self.seed,
285281
}
286282
)

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -193,39 +193,6 @@ def __call__(
193193
class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
194194
"""JAX implementation of the TPU embedding layer."""
195195

196-
def __init__(self, **kwargs: Any):
197-
# Pull out `auto_stack_kwargs` from `kwargs`.
198-
auto_stack_kwargs = kwargs.pop("auto_stack_kwargs", {})
199-
200-
auto_stack_max_ids_per_partition = auto_stack_kwargs.pop(
201-
"max_ids_per_partition", None
202-
)
203-
auto_stack_max_unique_ids_per_partition = auto_stack_kwargs.pop(
204-
"max_unique_ids_per_partition", None
205-
)
206-
207-
# For `max_ids_per_partition` and `max_unique_ids_per_partition`, JTE's
208-
# `auto_stack_tables` expects callables.
209-
def _get_max_ids_per_partition(name: str, batch_size: int) -> int:
210-
return auto_stack_max_ids_per_partition
211-
212-
def _get_max_unique_ids_per_partition(
213-
name: str, batch_size: int
214-
) -> int:
215-
return auto_stack_max_unique_ids_per_partition
216-
217-
if auto_stack_max_ids_per_partition is not None:
218-
auto_stack_kwargs["stack_to_max_ids_per_partition"] = (
219-
_get_max_ids_per_partition
220-
)
221-
if auto_stack_max_unique_ids_per_partition is not None:
222-
auto_stack_kwargs["stack_to_max_unique_ids_per_partition"] = (
223-
_get_max_unique_ids_per_partition
224-
)
225-
226-
self._auto_stack_kwargs = auto_stack_kwargs
227-
super().__init__(**kwargs)
228-
229196
def _create_sparsecore_distribution(
230197
self, sparsecore_axis_name: str = "sparsecore"
231198
) -> tuple[
@@ -438,7 +405,6 @@ def sparsecore_build(
438405
feature_specs,
439406
global_device_count,
440407
num_sc_per_device,
441-
# **self._auto_stack_kwargs,
442408
)
443409
else:
444410
raise ValueError(

0 commit comments

Comments
 (0)