From 5ca57a7b425e004d42624fdd961e058f1aefad23 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:03:25 +0530 Subject: [PATCH 001/279] Add DLRM-DCNv2 example for MLPerf --- .../configs/datasets/dummy_dataset.py | 140 ++++++++ .../configs/models/default_model.py | 19 ++ .../configs/training/default_training.py | 7 + examples/mlperf_dlrm_dcnv2/configs/v6e_16.py | 15 + examples/mlperf_dlrm_dcnv2/configs/v6e_8.py | 15 + examples/mlperf_dlrm_dcnv2/dataloader.py | 69 ++++ examples/mlperf_dlrm_dcnv2/main.py | 232 +++++++++++++ examples/mlperf_dlrm_dcnv2/model.py | 321 ++++++++++++++++++ 8 files changed, 818 insertions(+) create mode 100644 examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py create mode 100644 examples/mlperf_dlrm_dcnv2/configs/models/default_model.py create mode 100644 examples/mlperf_dlrm_dcnv2/configs/training/default_training.py create mode 100644 examples/mlperf_dlrm_dcnv2/configs/v6e_16.py create mode 100644 examples/mlperf_dlrm_dcnv2/configs/v6e_8.py create mode 100644 examples/mlperf_dlrm_dcnv2/dataloader.py create mode 100644 examples/mlperf_dlrm_dcnv2/main.py create mode 100644 examples/mlperf_dlrm_dcnv2/model.py diff --git a/examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py b/examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py new file mode 100644 index 00000000..469ad206 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py @@ -0,0 +1,140 @@ +from keras.utils import Config + +# === Dataset === +dataset_config = Config() +dataset_config.file_pattern = None +# Features +dataset_config.label = "clicked" +dataset_config.dense = [f"int-feature-{i}" for i in range(13)] +dataset_config.sparse = [ + { + "name": "categorical-feature-14", + "vocabulary_size": 40000000, + "multi_hot_size": 3, + }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "multi_hot_size": 2, + }, + { + "name": "categorical-feature-16", + "vocabulary_size": 17295, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "multi_hot_size": 2, + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "multi_hot_size": 6, + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "multi_hot_size": 7, + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "multi_hot_size": 3, + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "multi_hot_size": 8, + }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "multi_hot_size": 6, + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "multi_hot_size": 9, + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "multi_hot_size": 5, + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "multi_hot_size": 12, + }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "multi_hot_size": 100, + }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "multi_hot_size": 27, + }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "multi_hot_size": 10, + }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "multi_hot_size": 3, + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "multi_hot_size": 1, + }, +] diff --git a/examples/mlperf_dlrm_dcnv2/configs/models/default_model.py b/examples/mlperf_dlrm_dcnv2/configs/models/default_model.py new file mode 100644 index 00000000..1a07e9d9 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/models/default_model.py @@ -0,0 +1,19 @@ +from keras.utils import Config + +# === Model === +model_config = Config() +# Embedding +model_config.embedding_dim = 128 +model_config.allow_id_dropping = True +model_config.embedding_threshold = 21000 +model_config.max_ids_per_partition = 4096 +model_config.max_unique_ids_per_partition = 2048 +model_config.learning_rate = 0.005 + +# MLP +model_config.bottom_mlp_dims = [512, 256, 128] +model_config.top_mlp_dims = [1024, 1024, 512, 256, 1] + +# DCN +model_config.num_dcn_layers = 3 +model_config.dcn_projection_dim = 512 diff --git a/examples/mlperf_dlrm_dcnv2/configs/training/default_training.py b/examples/mlperf_dlrm_dcnv2/configs/training/default_training.py new file mode 100644 index 00000000..b758bc59 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/training/default_training.py @@ -0,0 +1,7 @@ +from keras.utils import Config + +# === Training Hyperparameters === +training_config = Config() +training_config.learning_rate = 0.005 +training_config.global_batch_size = 128 +training_config.num_epochs = 1 diff --git a/examples/mlperf_dlrm_dcnv2/configs/v6e_16.py b/examples/mlperf_dlrm_dcnv2/configs/v6e_16.py new file mode 100644 index 00000000..b246d773 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/v6e_16.py @@ -0,0 +1,15 @@ +from configs.datasets.dummy_dataset import dataset_config +from configs.models.default_model import model_config +from configs.training.default_training import training_config +from keras.utils import Config + +config = Config() + +config.experiment_name = "v6e_16" +config.model_dir = "./v6e_16" + +config.dataset = dataset_config +config.model = model_config +config.training = training_config + +config.freeze() diff --git a/examples/mlperf_dlrm_dcnv2/configs/v6e_8.py b/examples/mlperf_dlrm_dcnv2/configs/v6e_8.py new file mode 100644 index 00000000..552f25d0 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/v6e_8.py @@ -0,0 +1,15 @@ +from configs.datasets.dummy_dataset import dataset_config +from configs.models.default_model import model_config +from configs.training.default_training import training_config +from keras.utils import Config + +config = Config() + +config.experiment_name = "v6e_8" +config.model_dir = "./v6e_8" + +config.dataset = dataset_config +config.model = model_config +config.training = training_config + +config.freeze() diff --git a/examples/mlperf_dlrm_dcnv2/dataloader.py b/examples/mlperf_dlrm_dcnv2/dataloader.py new file mode 100644 index 00000000..467f390a --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/dataloader.py @@ -0,0 +1,69 @@ +import numpy as np +import tensorflow as tf + + +def _get_dummy_batch(batch_size, large_emb_features, small_emb_features): + """Returns a dummy batch of data in the final desired structure.""" + + # Labels + data = { + "clicked": np.random.randint(0, 2, size=(batch_size,), dtype=np.int64) + } + + # Dense features + dense_input_list = [ + np.random.uniform(0.0, 0.9, size=(batch_size, 1)).astype(np.float32) + for _ in range(13) + ] + data["dense_input"] = np.concatenate(dense_input_list, axis=-1) + + # Sparse features + large_emb_inputs = {} + for large_emb_feature in large_emb_features: + vocabulary_size = large_emb_feature["vocabulary_size"] + multi_hot_size = large_emb_feature["multi_hot_size"] + idx = large_emb_feature["name"].split("-")[-1] + + large_emb_inputs[f"cat_{idx}_id"] = np.random.randint( + low=0, + high=vocabulary_size, + size=(batch_size, multi_hot_size), + dtype=np.int64, + ) + + data["large_emb_inputs"] = large_emb_inputs + + # Dense lookup features + small_emb_inputs = {} + for small_emb_feature in small_emb_features: + vocabulary_size = small_emb_feature["vocabulary_size"] + multi_hot_size = small_emb_feature["multi_hot_size"] + idx = small_emb_feature["name"].split("-")[-1] + + # TODO: We don't need this custom renaming. Remove later, when we + # shift from dummy data to actual data. + small_emb_inputs[f"cat_{idx}_id"] = np.random.randint( + low=0, + high=vocabulary_size, + size=(batch_size, multi_hot_size), + dtype=np.int64, + ) + + if small_emb_inputs: + data["small_emb_inputs"] = small_emb_inputs + + return data + + +def create_dummy_dataset(batch_size, large_emb_features, small_emb_features): + """Creates a TF dataset from cached dummy data of the final batch size.""" + dummy_data = _get_dummy_batch( + batch_size, large_emb_features, small_emb_features + ) + + # Separate labels from features to create a `(features, labels)` tuple. + labels = dummy_data.pop("clicked") + features = dummy_data + + dataset = tf.data.Dataset.from_tensors((features, labels)).repeat(512) + return dataset diff --git a/examples/mlperf_dlrm_dcnv2/main.py b/examples/mlperf_dlrm_dcnv2/main.py new file mode 100644 index 00000000..115648cf --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/main.py @@ -0,0 +1,232 @@ +import argparse +import importlib +import os + +os.environ["KERAS_BACKEND"] = "jax" + +import keras + +import keras_rs + +from .dataloader import create_dummy_dataset +from .model import DLRMDCNV2 + +SEED = 1337 + + +def main( + file_pattern, + dense_features, + large_emb_features, + small_emb_features, + label, + embedding_dim, + allow_id_dropping, + max_ids_per_partition, + max_unique_ids_per_partition, + embedding_learning_rate, + bottom_mlp_dims, + top_mlp_dims, + num_dcn_layers, + dcn_projection_dim, + learning_rate, + global_batch_size, + num_epochs, +): + # Set DDP as Keras distribution strategy + devices = keras.distribution.list_devices(device_type="tpu") + distribution = keras.distribution.DataParallel(devices=devices) + keras.distribution.set_distribution(distribution) + num_processes = distribution._num_process() + + per_host_batch_size = global_batch_size // num_processes + + # === Distributed embeddings' configs for sparse features === + feature_configs = {} + for large_emb_feature in large_emb_features: + # Rename these features to something shorter; was facing some weird + # issues with the longer names. + feature_name = ( + large_emb_feature["name"] + .replace("-", "_") + .replace("egorical_feature", "") + ) + vocabulary_size = large_emb_feature["vocabulary_size"] + multi_hot_size = large_emb_feature["multi_hot_size"] + + table_config = keras_rs.layers.TableConfig( + name=f"{feature_name}_table", + vocabulary_size=vocabulary_size, + embedding_dim=embedding_dim, + # TODO(abheesht): Verify. + initializer=keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + seed=SEED, + ), + optimizer=keras.optimizers.Adagrad( + learning_rate=embedding_learning_rate + ), + combiner="sum", + placement="sparsecore", + # TODO: These two args are not getting passed down to + # `jax-tpu-embedding` properly, seems like. + max_ids_per_partition=max_ids_per_partition, + max_unique_ids_per_partition=max_unique_ids_per_partition, + ) + feature_configs[f"{feature_name}_id"] = keras_rs.layers.FeatureConfig( + name=feature_name.replace("id", ""), + table=table_config, + # TODO: Verify whether it should be `(bsz, 1)` or + # `(bsz, multi_hot_size)`. + input_shape=(per_host_batch_size, multi_hot_size), + output_shape=(per_host_batch_size, embedding_dim), + ) + + # === Instantiate model === + # We instantiate the model first, because we need to preprocess sparse + # inputs using the distributed embedding layer defined inside the model + # class. + print("===== Initialising model =====") + model = DLRMDCNV2( + large_emb_feature_configs=feature_configs, + small_emb_features=small_emb_features, + embedding_dim=embedding_dim, + bottom_mlp_dims=bottom_mlp_dims, + top_mlp_dims=top_mlp_dims, + num_dcn_layers=num_dcn_layers, + dcn_projection_dim=dcn_projection_dim, + seed=SEED, + dtype="float32", + name="dlrm_dcn_v2", + ) + model.compile( + loss=keras.losses.BinaryCrossentropy(), + optimizer=keras.optimizers.Adagrad(learning_rate=learning_rate), + metrics=[keras.metrics.BinaryAccuracy()], + ) + + # === Load dataset === + print("===== Loading dataset =====") + train_ds = create_dummy_dataset( + batch_size=per_host_batch_size, + large_emb_features=large_emb_features, + small_emb_features=small_emb_features, + ) + # For the multi-host case, the dataset has to be distributed manually. + # See note here: + # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. + if num_processes > 1: + train_ds = distribution.distribute_dataset(train_ds) + distribution.auto_shard_dataset = False + + def generator(dataset, training=False): + """Converts tf.data Dataset to a Python generator and preprocesses + sparse features. + """ + for features, labels in dataset: + preprocessed_large_embeddings = model.embedding_layer.preprocess( + features["large_emb_inputs"], training=training + ) + + x = { + "dense_input": features["dense_input"], + "large_emb_inputs": preprocessed_large_embeddings, + "small_emb_inputs": features["small_emb_inputs"], + } + y = labels + yield (x, y) + + train_generator = generator(train_ds, training=True) + for first_batch in train_generator: + model(first_batch[0]) + break + + # Train the model. + model.fit(train_generator, epochs=1) + + +if __name__ == "__main__": + keras.config.disable_traceback_filtering() + + print("====== Launching train script =======") + parser = argparse.ArgumentParser( + description=( + "Benchmark the DLRM-DCNv2 model on the Criteo dataset (MLPerf)" + ) + ) + parser.add_argument( + "--config_name", type=str, help="Name of the `.py` config file." + ) + args = parser.parse_args() + + print(f"===== Reading config from {args.config_name} ======") + config = getattr(importlib.import_module("configs"), args.config_name) + + # === Unpack args from config === + + # == Dataset config == + ds_cfg = config["dataset"] + # File path + file_pattern = ds_cfg["file_pattern"] + # Features + label = ds_cfg["label"] + dense_features = ds_cfg["dense"] + emb_features = ds_cfg["sparse"] + + # == Model config == + model_cfg = config["model"] + # Embedding + embedding_dim = model_cfg["embedding_dim"] + allow_id_dropping = model_cfg["allow_id_dropping"] + embedding_threshold = model_cfg["embedding_threshold"] + max_ids_per_partition = model_cfg["max_ids_per_partition"] + max_unique_ids_per_partition = model_cfg["max_unique_ids_per_partition"] + embedding_learning_rate = model_cfg["learning_rate"] + # MLP + bottom_mlp_dims = model_cfg["bottom_mlp_dims"] + top_mlp_dims = model_cfg["top_mlp_dims"] + # DCN + num_dcn_layers = model_cfg["num_dcn_layers"] + dcn_projection_dim = model_cfg["dcn_projection_dim"] + + # == Training config == + training_cfg = config["training"] + learning_rate = training_cfg["learning_rate"] + global_batch_size = training_cfg["global_batch_size"] + num_epochs = training_cfg["num_epochs"] + + # For features which have vocabulary_size < embedding_threshold, we can + # just do a normal dense lookup for those instead of have distributed + # embeddings. We could ideally pass `placement = default_device` to + # `keras_rs.layers.TableConfig` directly (and wouldn't have to do this + # separation of features), but doing it that way will necessarily require + # a separate optimiser for the embedding layer. + small_emb_features = [] + large_emb_features = [] + for emb_feature in emb_features: + if emb_feature["vocabulary_size"] < embedding_threshold: + small_emb_features.append(emb_feature) + else: + large_emb_features.append(emb_feature) + + main( + file_pattern, + dense_features, + large_emb_features, + small_emb_features, + label, + embedding_dim, + allow_id_dropping, + max_ids_per_partition, + max_unique_ids_per_partition, + embedding_learning_rate, + bottom_mlp_dims, + top_mlp_dims, + num_dcn_layers, + dcn_projection_dim, + learning_rate, + global_batch_size, + num_epochs, + ) diff --git a/examples/mlperf_dlrm_dcnv2/model.py b/examples/mlperf_dlrm_dcnv2/model.py new file mode 100644 index 00000000..b77579f9 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/model.py @@ -0,0 +1,321 @@ +from typing import Any, TypeAlias + +import keras +from keras import ops + +import keras_rs + +Tensor: TypeAlias = Any + + +def _clone_initializer( + initializer: keras.initializers.Initializer, + seed: int | keras.random.SeedGenerator, +): + """Clones the provided initializer with a new seed. + + This function creates a new instance of a Keras initializer from an + existing one, but with a different seed. This is useful for ensuring + different weights in a model are initialized with different seeds. + + Args: + initializer: a keras.initializers.Initializer instance. The initializer + to be cloned. + seed: int, or a keras.random.SeedGenerator() instance. The random seed. + + Returns: + A new `keras.initializers.Initializer` instance configured with the + provided seed. + """ + config = initializer.get_config() + config.pop("seed") + config = {**config, "seed": seed} + initializer_class: type[keras.initializers.Initializer] = ( + initializer.__class__ + ) + return initializer_class.from_config(config) + + +class DLRMDCNV2(keras.Model): + def __init__( + self, + large_emb_feature_configs: dict[str, keras_rs.layers.FeatureConfig], + small_emb_features: list, + embedding_dim: int, + bottom_mlp_dims: list[int], + top_mlp_dims: list[int], + num_dcn_layers: int, + dcn_projection_dim: int, + seed: int | keras.random.SeedGenerator | None = None, + dtype: str | None = None, + name: str | None = None, + **kwargs: Any, + ): + """DLRM-DCNv2 model. + + The model processes two types of input features: + 1. Dense Features: Continuous-valued features that are processed by + a multi-layer perceptron (the "bottom MLP"). + 2. Sparse Features: High-cardinality categorical features that are + first mapped into low-dimensional embedding vectors using the + `keras_rs.layers.DistributedEmbedding` layer. This layer is highly + optimized for large-scale recommendation models, especially on TPUs + with SparseCore, as it can shard large embedding tables across + multiple accelerator chips for improved performance. On other + hardware (GPUs, CPUs), it functions like a standard embedding layer. + + The output of the bottom MLP and the embedding vectors are then + concatenated and fed into a DCN block for learning feature interactions. + The output of the DCN block is then processed by another MLP + (the "top MLP") to produce a final prediction. + + Args: + large_emb_feature_configs: A dictionary with features names as keys + and `keras_rs.layers.FeatureConfig` objects as values. These + configs link features to their corresponding embedding tables + (`keras_rs.layers.TableConfig`), specifying parameters like + vocabulary size, embedding dimension, and hardware placement + strategy. + bottom_mlp_dims: A list of integers specifying the number of units + in each layer of the bottom MLP. + top_mlp_dims: A list of integers specifying the number of units in + each layer of the top MLP. The last value is the final output + dimension (e.g., 1 for binary classification). + num_dcn_layers: The number of feature-crossing layers in the DCNv2 + block. + dcn_projection_dim: The projection dimension used within each DCNv2 + cross-layer. + seed: The random seed. + dtype: Optional dtype. + name: The name of the layer. + """ + super().__init__(dtype=dtype, name=name, **kwargs) + self.seed = seed + + # === Layers ==== + + # Bottom MLP for encoding dense features + self.bottom_mlp = keras.Sequential( + self._get_mlp_layers( + dims=bottom_mlp_dims, + intermediate_activation="relu", + final_activation="relu", + ), + name="bottom_mlp", + ) + # Distributed embeddings for large embedding tables + self.embedding_layer = keras_rs.layers.DistributedEmbedding( + feature_configs=large_emb_feature_configs, + table_stacking="auto_stacking", + dtype=dtype, + name="embedding_layer", + ) + # Embedding layers for small embedding tables + self.small_embedding_layers = None + if small_emb_features: + self.small_embedding_layers = [ + keras.layers.Embedding( + input_dim=small_emb_feature["vocabulary_size"], + output_dim=embedding_dim, + embeddings_initializer="zeros", + name=f"small_embedding_layer_{i}", + ) + for i, small_emb_feature in enumerate(small_emb_features) + ] + # DCN for "interactions" + self.dcn_block = DCNBlock( + num_layers=num_dcn_layers, + projection_dim=dcn_projection_dim, + seed=seed, + dtype=dtype, + name="dcn_block", + ) + # Top MLP for predictions + self.top_mlp = keras.Sequential( + self._get_mlp_layers( + dims=top_mlp_dims, + intermediate_activation="relu", + final_activation="sigmoid", + ), + name="top_mlp", + ) + + # === Passed attributes === + self.large_emb_feature_configs = large_emb_feature_configs + self.small_emb_features = small_emb_features + self.embedding_dim = embedding_dim + self.bottom_mlp_dims = bottom_mlp_dims + self.top_mlp_dims = top_mlp_dims + self.num_dcn_layers = num_dcn_layers + self.dcn_projection_dim = dcn_projection_dim + + def call(self, inputs: dict[str, Tensor]) -> Tensor: + """Forward pass of the model. + + Args: + inputs: A dictionary containing `"dense_features"` and + `"preprocessed_large_emb_features"` as keys. + """ + # Inputs + dense_input = inputs["dense_input"] + large_emb_inputs = inputs["large_emb_inputs"] + + # Embed features. + dense_output = self.bottom_mlp(dense_input) + # jax.debug.print("dense_ouput {}", dense_output.shape) + large_embeddings = self.embedding_layer(large_emb_inputs) + small_embeddings = [] + if self.small_emb_features: + small_emb_inputs = inputs["small_emb_inputs"] + for small_emb_input, embedding_layer in zip( + small_emb_inputs.values(), self.small_embedding_layers + ): + embedding = embedding_layer(small_emb_input) + embedding = ops.sum(embedding, axis=-2) + small_embeddings.append(embedding) + + small_embeddings = ops.concatenate(small_embeddings, axis=-1) + + # Interaction + x = ops.concatenate( + [dense_output, small_embeddings, *large_embeddings.values()], + axis=-1, + ) + # jax.debug.print("x {}", x.shape) + x = self.dcn_block(x) + + # Predictions + outputs = self.top_mlp(x) + return outputs + + def _get_mlp_layers( + self, + dims: list[int], + intermediate_activation: str | keras.layers.Activation, + final_activation: str | keras.layers.Activation, + ) -> list[keras.layers.Layer]: + """Creates a list of Dense layers. + + Args: + dims: list. Output dimensions of the dense layers to be created. + intermediate_activation: string or `keras.layers.Activation`. The + activation to be used in all layers, save the last. + final_activation: str or `keras.layers.Activation`. The activation + to be used in the last layer. + + Returns: + A list of `keras.layers.Dense` layers. + """ + initializer = keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + seed=self.seed, + ) + + layers = [ + keras.layers.Dense( + units=dim, + activation=intermediate_activation, + kernel_initializer=_clone_initializer( + initializer, seed=self.seed + ), + bias_initializer=_clone_initializer( + initializer, seed=self.seed + ), + dtype=self.dtype, + ) + for dim in dims[:-1] + ] + layers += [ + keras.layers.Dense( + units=dims[-1], + activation=final_activation, + kernel_initializer=_clone_initializer( + initializer, seed=self.seed + ), + bias_initializer=_clone_initializer( + initializer, seed=self.seed + ), + dtype=self.dtype, + ) + ] + return layers + + def get_config(self): + """Returns the config of the model.""" + config = super().get_config() + config.update( + { + "large_emb_feature_configs": self.large_emb_feature_configs, + "small_emb_features": self.small_emb_features, + "embedding_dim": self.embedding_dim, + "bottom_mlp_dims": self.bottom_mlp_dims, + "top_mlp_dims": self.top_mlp_dims, + "num_dcn_layers": self.num_dcn_layers, + "dcn_projection_dim": self.dcn_projection_dim, + "seed": self.seed, + } + ) + return config + + +class DCNBlock(keras.layers.Layer): + def __init__( + self, + num_layers: int, + projection_dim: int, + seed: int | keras.random.SeedGenerator, + dtype: str | None = None, + name: str | None = None, + **kwargs, + ): + """ + A block of Deep & Cross Network V2 (DCNv2) layers. + + This layer implements the "cross network" part of the DCNv2 architecture + by stacking multiple `keras_rs.layers.FeatureCross` layers, which learn + feature interactions. + + Args: + num_layers: The number of `FeatureCross` layers to stack. + projection_dim: The dimensionality of the low-rank projection used + within each cross layer. + seed: The random seed for initializers. + dtype: Optional dtype. + name: The name of the layer. + """ + super().__init__(dtype=dtype, name=name, **kwargs) + + # Layers + self.layers = [ + keras_rs.layers.FeatureCross( + projection_dim=projection_dim, + kernel_initializer=keras.initializers.GlorotUniform(seed=seed), + bias_initializer="zeros", + dtype=dtype, + ) + for _ in range(num_layers) + ] + + # Passed attributes + self.num_layers = num_layers + self.projection_dim = projection_dim + self.seed = seed + + def call(self, x0): + xl = x0 + for layer in self.layers: + xl = layer(x0, xl) + return xl + + def get_config(self): + config = super().get_config() + config.update( + { + "num_layers": self.num_layers, + "projection_dim": self.projection_dim, + "seed": self.seed, + } + ) + return config From 090098df8d03d99cd540f539d0b3ad4fb892ea1a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:05:15 +0530 Subject: [PATCH 002/279] Fix table_stacking arg --- examples/mlperf_dlrm_dcnv2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mlperf_dlrm_dcnv2/model.py b/examples/mlperf_dlrm_dcnv2/model.py index b77579f9..76cf3d09 100644 --- a/examples/mlperf_dlrm_dcnv2/model.py +++ b/examples/mlperf_dlrm_dcnv2/model.py @@ -106,7 +106,7 @@ def __init__( # Distributed embeddings for large embedding tables self.embedding_layer = keras_rs.layers.DistributedEmbedding( feature_configs=large_emb_feature_configs, - table_stacking="auto_stacking", + table_stacking="auto", dtype=dtype, name="embedding_layer", ) From 1dd94228bf168142928797223699b2339c7bd401 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:08:17 +0530 Subject: [PATCH 003/279] Rename dir --- .../configs/datasets/dummy_dataset.py | 0 .../configs/models/default_model.py | 0 .../configs/training/default_training.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/v6e_16.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/v6e_8.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/dataloader.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/main.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/model.py | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/datasets/dummy_dataset.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/models/default_model.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/training/default_training.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/v6e_16.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/v6e_8.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/dataloader.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/main.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/model.py (100%) diff --git a/examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py rename to examples/ml_perf/configs/datasets/dummy_dataset.py diff --git a/examples/mlperf_dlrm_dcnv2/configs/models/default_model.py b/examples/ml_perf/configs/models/default_model.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/models/default_model.py rename to examples/ml_perf/configs/models/default_model.py diff --git a/examples/mlperf_dlrm_dcnv2/configs/training/default_training.py b/examples/ml_perf/configs/training/default_training.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/training/default_training.py rename to examples/ml_perf/configs/training/default_training.py diff --git a/examples/mlperf_dlrm_dcnv2/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/v6e_16.py rename to examples/ml_perf/configs/v6e_16.py diff --git a/examples/mlperf_dlrm_dcnv2/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/v6e_8.py rename to examples/ml_perf/configs/v6e_8.py diff --git a/examples/mlperf_dlrm_dcnv2/dataloader.py b/examples/ml_perf/dataloader.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/dataloader.py rename to examples/ml_perf/dataloader.py diff --git a/examples/mlperf_dlrm_dcnv2/main.py b/examples/ml_perf/main.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/main.py rename to examples/ml_perf/main.py diff --git a/examples/mlperf_dlrm_dcnv2/model.py b/examples/ml_perf/model.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/model.py rename to examples/ml_perf/model.py From 25233c9e87b03210e1e4b2e8ee80f4a3d26699dd Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:34:48 +0530 Subject: [PATCH 004/279] Add blank __init__ file to configs --- examples/ml_perf/configs/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 examples/ml_perf/configs/__init__.py diff --git a/examples/ml_perf/configs/__init__.py b/examples/ml_perf/configs/__init__.py new file mode 100644 index 00000000..e69de29b From 3d519d2b6f0e9c51c43981d2bdb5ac19b1a336f7 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:37:50 +0530 Subject: [PATCH 005/279] Fix imports --- examples/ml_perf/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 115648cf..d0b49c9e 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -162,7 +162,10 @@ def generator(dataset, training=False): args = parser.parse_args() print(f"===== Reading config from {args.config_name} ======") - config = getattr(importlib.import_module("configs"), args.config_name) + config = getattr( + importlib.import_module(".configs", package=__package__), + args.config_name + ) # === Unpack args from config === From eef6568dc05bc25f375605f4f57409e1ff418b2e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:42:48 +0530 Subject: [PATCH 006/279] Fix imports --- examples/ml_perf/main.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d0b49c9e..96412c81 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -163,9 +163,8 @@ def generator(dataset, training=False): print(f"===== Reading config from {args.config_name} ======") config = getattr( - importlib.import_module(".configs", package=__package__), - args.config_name - ) + importlib.import_module(f".configs.{args.config_name}", package=__package__) + ).config # === Unpack args from config === From fc77ad49986fe0cbebc8520526b61290519511fc Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:45:12 +0530 Subject: [PATCH 007/279] Fix imports --- examples/ml_perf/configs/datasets/__init__.py | 0 examples/ml_perf/configs/models/__init__.py | 0 examples/ml_perf/configs/training/__init__.py | 0 examples/ml_perf/configs/v6e_16.py | 6 +++--- examples/ml_perf/configs/v6e_8.py | 6 +++--- 5 files changed, 6 insertions(+), 6 deletions(-) create mode 100644 examples/ml_perf/configs/datasets/__init__.py create mode 100644 examples/ml_perf/configs/models/__init__.py create mode 100644 examples/ml_perf/configs/training/__init__.py diff --git a/examples/ml_perf/configs/datasets/__init__.py b/examples/ml_perf/configs/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/models/__init__.py b/examples/ml_perf/configs/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/training/__init__.py b/examples/ml_perf/configs/training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index b246d773..c9c50aff 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -1,6 +1,6 @@ -from configs.datasets.dummy_dataset import dataset_config -from configs.models.default_model import model_config -from configs.training.default_training import training_config +from .datasets.dummy_dataset import dataset_config +from .models.default_model import model_config +from .training.default_training import training_config from keras.utils import Config config = Config() diff --git a/examples/ml_perf/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py index 552f25d0..4e3904e2 100644 --- a/examples/ml_perf/configs/v6e_8.py +++ b/examples/ml_perf/configs/v6e_8.py @@ -1,6 +1,6 @@ -from configs.datasets.dummy_dataset import dataset_config -from configs.models.default_model import model_config -from configs.training.default_training import training_config +from .datasets.dummy_dataset import dataset_config +from .models.default_model import model_config +from .training.default_training import training_config from keras.utils import Config config = Config() From 0a31e0fb5f34608ce4a359dd16c9e5dd8910df03 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:47:45 +0530 Subject: [PATCH 008/279] Fix imports --- examples/ml_perf/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 96412c81..078e63f3 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -162,9 +162,11 @@ def generator(dataset, training=False): args = parser.parse_args() print(f"===== Reading config from {args.config_name} ======") - config = getattr( - importlib.import_module(f".configs.{args.config_name}", package=__package__) - ).config + config = ( + importlib.import_module( + f".configs.{args.config_name}", package=__package__ + ).config + ) # === Unpack args from config === From 53e9c2304588cdba2cf26bb915f3e93b7ada7fb2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:48:47 +0530 Subject: [PATCH 009/279] Fix num_processes --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 078e63f3..9abcb718 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -37,7 +37,7 @@ def main( devices = keras.distribution.list_devices(device_type="tpu") distribution = keras.distribution.DataParallel(devices=devices) keras.distribution.set_distribution(distribution) - num_processes = distribution._num_process() + num_processes = distribution._num_process per_host_batch_size = global_batch_size // num_processes From f22b5ffdaf343daa3e3d33973311ec2c6b4977fb Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 13:50:13 +0530 Subject: [PATCH 010/279] Add bash script --- examples/ml_perf/configs/v6e_16.py | 3 +- examples/ml_perf/configs/v6e_8.py | 3 +- examples/ml_perf/main.py | 8 +-- examples/ml_perf/model.py | 4 +- examples/ml_perf/run.sh | 106 +++++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 8 deletions(-) create mode 100644 examples/ml_perf/run.sh diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index c9c50aff..4b6df8df 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -1,7 +1,8 @@ +from keras.utils import Config + from .datasets.dummy_dataset import dataset_config from .models.default_model import model_config from .training.default_training import training_config -from keras.utils import Config config = Config() diff --git a/examples/ml_perf/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py index 4e3904e2..fcd81e39 100644 --- a/examples/ml_perf/configs/v6e_8.py +++ b/examples/ml_perf/configs/v6e_8.py @@ -1,7 +1,8 @@ +from keras.utils import Config + from .datasets.dummy_dataset import dataset_config from .models.default_model import model_config from .training.default_training import training_config -from keras.utils import Config config = Config() diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 9abcb718..5bea8c72 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -162,11 +162,9 @@ def generator(dataset, training=False): args = parser.parse_args() print(f"===== Reading config from {args.config_name} ======") - config = ( - importlib.import_module( - f".configs.{args.config_name}", package=__package__ - ).config - ) + config = importlib.import_module( + f".configs.{args.config_name}", package=__package__ + ).config # === Unpack args from config === diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 76cf3d09..a9f01f36 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -117,7 +117,9 @@ def __init__( keras.layers.Embedding( input_dim=small_emb_feature["vocabulary_size"], output_dim=embedding_dim, - embeddings_initializer="zeros", + embeddings_initializer=keras.initializers.LecunNormal( + seed=self.seed, + ), name=f"small_embedding_layer_{i}", ) for i, small_emb_feature in enumerate(small_emb_features) diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh new file mode 100644 index 00000000..f5d840a5 --- /dev/null +++ b/examples/ml_perf/run.sh @@ -0,0 +1,106 @@ +#!/bin/bash + +# ============================================================================== +# Script Configuration & Argument Handling +# ============================================================================== +# This script accepts up to three optional arguments: +# 1. Accelerator Type (default: v6e-8, options: v6e-8, v6e-16) +# 2. Zone (default: us-east5-a) +# 3. Project (default: tpu-prod-env-one-vm) + +ACCELERATOR_TYPE=${1:-"v6e-8"} +ZONE=${2:-"us-east5-a"} +PROJECT=${3:-"tpu-prod-env-one-vm"} + +# Validate the provided accelerator type +if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then + echo "Error: Invalid accelerator type '${ACCELERATOR_TYPE}'." >&2 + echo "Usage: $0 [v6e-8|v6e-16] [gcp_zone] [gcp_project]" >&2 + exit 1 +fi + +# ============================================================================== +# Environment Variables +# ============================================================================== +# TPU name is generated dynamically. Zone and Project are set from args or defaults. +export TPU_NAME="abheesht-mlperf-${ACCELERATOR_TYPE}" +export ZONE +export PROJECT + +echo ">>> Using Configuration:" +echo " Accelerator: ${ACCELERATOR_TYPE}" +echo " TPU Name: ${TPU_NAME}" +echo " Zone: ${ZONE}" +echo " Project: ${PROJECT}" +echo "--------------------------------------------------" + + +# ============================================================================== +# TPU VM Creation +# ============================================================================== +echo ">>> Creating TPU VM: ${TPU_NAME} with accelerator ${ACCELERATOR_TYPE}..." +gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \ + --zone=${ZONE} \ + --accelerator-type=${ACCELERATOR_TYPE} \ + --version=v2-alpha-tpuv6e \ + --project=${PROJECT} \ + --metadata=enable-oslogin=TRUE \ + --scopes=https://www.googleapis.com/auth/cloud-platform + + +# ============================================================================== +# Setup Python Virtual Environment on all workers +# ============================================================================== +echo ">>> Creating Python virtual environment..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="sudo apt-get update && sudo apt install -y python3.10-venv && python3 -m venv .keras-env" + + +# ============================================================================== +# Clone KerasRS and Install Dependencies +# ============================================================================== +echo ">>> Cloning KerasRS and installing dependencies..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && git clone https://github.com/abheesht17/keras-rs.git && cd keras-rs && git checkout ml-perf && pip install -e . && pip install tensorflow-datasets && pip uninstall -y tensorflow keras && pip install git+https://github.com/keras-team/keras.git && pip install jax-tpu-embedding tensorflow-cpu" + + +# ============================================================================== +# Install TPU-compatible JAX +# ============================================================================== +echo ">>> Re-installing JAX for TPU compatibility..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && pip uninstall -y jax jaxlib && pip install -U 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" + + +# ============================================================================== +# Verify Installation +# ============================================================================== +echo ">>> Verifying JAX installation..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && echo 'import jax; print(jax.devices())' > script.py && python script.py" + + +# ============================================================================== +# Run Training Script +# ============================================================================== +# The config path is now also set dynamically. +echo ">>> Running the main script with config for ${ACCELERATOR_TYPE}..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && cd keras-rs && python3 -m examples.ml_perf.main --config_name ${ACCELERATOR_TYPE}" + +echo ">>> Script finished." From 41f297753c6f45d807d654ac8857c9dc771da20f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 13:51:31 +0530 Subject: [PATCH 011/279] Add bash script --- examples/ml_perf/run.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index f5d840a5..d84bd7d1 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -22,7 +22,6 @@ fi # ============================================================================== # Environment Variables # ============================================================================== -# TPU name is generated dynamically. Zone and Project are set from args or defaults. export TPU_NAME="abheesht-mlperf-${ACCELERATOR_TYPE}" export ZONE export PROJECT @@ -95,7 +94,6 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ # ============================================================================== # Run Training Script # ============================================================================== -# The config path is now also set dynamically. echo ">>> Running the main script with config for ${ACCELERATOR_TYPE}..." gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ From 09ca14c381bfc8e1ba3c54faf639eb54137d8b5a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 14:02:12 +0530 Subject: [PATCH 012/279] Modify bash script to take in config name --- examples/ml_perf/run.sh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index d84bd7d1..b34953f2 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -7,10 +7,12 @@ # 1. Accelerator Type (default: v6e-8, options: v6e-8, v6e-16) # 2. Zone (default: us-east5-a) # 3. Project (default: tpu-prod-env-one-vm) +# 4. Config Name (default: derived from accelerator type, e.g., v6e_8) ACCELERATOR_TYPE=${1:-"v6e-8"} ZONE=${2:-"us-east5-a"} PROJECT=${3:-"tpu-prod-env-one-vm"} +USER_CONFIG_NAME=${4} # Capture the fourth argument # Validate the provided accelerator type if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then @@ -26,11 +28,19 @@ export TPU_NAME="abheesht-mlperf-${ACCELERATOR_TYPE}" export ZONE export PROJECT +# Use the user-provided config name if it exists, otherwise derive it. +if [[ -n "${USER_CONFIG_NAME}" ]]; then + export CONFIG_NAME=${USER_CONFIG_NAME} +else + export CONFIG_NAME=${ACCELERATOR_TYPE//-/_} +fi + echo ">>> Using Configuration:" echo " Accelerator: ${ACCELERATOR_TYPE}" echo " TPU Name: ${TPU_NAME}" echo " Zone: ${ZONE}" echo " Project: ${PROJECT}" +echo " Config Name: ${CONFIG_NAME}" echo "--------------------------------------------------" @@ -99,6 +109,6 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && cd keras-rs && python3 -m examples.ml_perf.main --config_name ${ACCELERATOR_TYPE}" + --command="source .keras-env/bin/activate && cd keras-rs && python3 -m examples.ml_perf.main --config_name ${CONFIG_NAME}" echo ">>> Script finished." From 2a1c759251adc0e19c58592da66a56d3447a3341 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 19:54:33 +0530 Subject: [PATCH 013/279] Add way to load real dataset --- .../ml_perf/configs/datasets/dummy_dataset.py | 26 ++ examples/ml_perf/dataloader.py | 237 +++++++++++++----- examples/ml_perf/main.py | 16 +- examples/ml_perf/run.sh | 3 +- 4 files changed, 222 insertions(+), 60 deletions(-) diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 469ad206..0cbda44f 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -11,130 +11,156 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "multi_hot_size": 3, + "new_name": "cat_14_id", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "multi_hot_size": 2, + "new_name": "cat_15_id", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "multi_hot_size": 1, + "new_name": "cat_16_id", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "multi_hot_size": 2, + "new_name": "cat_17_id", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "multi_hot_size": 6, + "new_name": "cat_18_id", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "multi_hot_size": 1, + "new_name": "cat_19_id", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "multi_hot_size": 1, + "new_name": "cat_20_id", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "multi_hot_size": 1, + "new_name": "cat_21_id", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "multi_hot_size": 1, + "new_name": "cat_22_id", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "multi_hot_size": 7, + "new_name": "cat_23_id", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "multi_hot_size": 3, + "new_name": "cat_24_id", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "multi_hot_size": 8, + "new_name": "cat_25_id", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "multi_hot_size": 1, + "new_name": "cat_26_id", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "multi_hot_size": 6, + "new_name": "cat_27_id", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "multi_hot_size": 9, + "new_name": "cat_28_id", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "multi_hot_size": 5, + "new_name": "cat_29_id", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "multi_hot_size": 1, + "new_name": "cat_30_id", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "multi_hot_size": 1, + "new_name": "cat_31_id", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "multi_hot_size": 1, + "new_name": "cat_32_id", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "multi_hot_size": 12, + "new_name": "cat_33_id", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "multi_hot_size": 100, + "new_name": "cat_34_id", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "multi_hot_size": 27, + "new_name": "cat_35_id", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "multi_hot_size": 10, + "new_name": "cat_36_id", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "multi_hot_size": 3, + "new_name": "cat_37_id", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "multi_hot_size": 1, + "new_name": "cat_38_id", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "multi_hot_size": 1, + "new_name": "cat_39_id", }, ] diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 467f390a..2b34f352 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -2,68 +2,191 @@ import tensorflow as tf -def _get_dummy_batch(batch_size, large_emb_features, small_emb_features): - """Returns a dummy batch of data in the final desired structure.""" - - # Labels - data = { - "clicked": np.random.randint(0, 2, size=(batch_size,), dtype=np.int64) - } - - # Dense features - dense_input_list = [ - np.random.uniform(0.0, 0.9, size=(batch_size, 1)).astype(np.float32) - for _ in range(13) - ] - data["dense_input"] = np.concatenate(dense_input_list, axis=-1) - - # Sparse features - large_emb_inputs = {} - for large_emb_feature in large_emb_features: - vocabulary_size = large_emb_feature["vocabulary_size"] - multi_hot_size = large_emb_feature["multi_hot_size"] - idx = large_emb_feature["name"].split("-")[-1] - - large_emb_inputs[f"cat_{idx}_id"] = np.random.randint( - low=0, - high=vocabulary_size, - size=(batch_size, multi_hot_size), - dtype=np.int64, - ) +class DataLoader: + def __init__( + self, + file_pattern, + batch_size, + dense_features, + large_emb_features, + small_emb_features, + label, + training=False, + ): + # Passed attributes. + self.file_pattern = file_pattern + self.batch_size = batch_size + self.dense_features = dense_features + self.large_emb_features = large_emb_features + self.small_emb_features = small_emb_features + self.label = label + self.training = training + + # Derived attributes. + self._return_dummy_dataset = file_pattern is None + + def _get_dummy_batch(self): + """Returns a dummy batch of data in the final desired structure.""" + + # Labels + data = { + "clicked": np.random.randint( + 0, 2, size=(self.batch_size,), dtype=np.int64 + ) + } + + # Dense features + dense_input_list = [ + np.random.uniform(0.0, 0.9, size=(self.batch_size, 1)).astype( + np.float32 + ) + for _ in range(13) + ] + data["dense_input"] = np.concatenate(dense_input_list, axis=-1) + + # Sparse features + large_emb_inputs = {} + for large_emb_feature in self.large_emb_features: + name = large_emb_feature["name"] + new_name = large_emb_feature.get("new_name", name) + vocabulary_size = large_emb_feature["vocabulary_size"] + multi_hot_size = large_emb_feature["multi_hot_size"] + + large_emb_inputs[new_name] = np.random.randint( + low=0, + high=vocabulary_size, + size=(self.batch_size, multi_hot_size), + dtype=np.int64, + ) + + data["large_emb_inputs"] = large_emb_inputs + + # Dense lookup features + small_emb_inputs = {} + for small_emb_feature in self.small_emb_features: + name = small_emb_feature["name"] + new_name = small_emb_feature.get("new_name", name) + vocabulary_size = small_emb_feature["vocabulary_size"] + multi_hot_size = small_emb_feature["multi_hot_size"] + + small_emb_inputs[new_name] = np.random.randint( + low=0, + high=vocabulary_size, + size=(self.batch_size, multi_hot_size), + dtype=np.int64, + ) + + if small_emb_inputs: + data["small_emb_inputs"] = small_emb_inputs + + return data + + def _create_dummy_dataset(self): + """Creates a TF dummy dataset (randomly initialised).""" + dummy_data = self._get_dummy_batch() + + # Separate labels from features to create a `(features, labels)` tuple. + labels = dummy_data.pop("clicked") + features = dummy_data + + dataset = tf.data.Dataset.from_tensors((features, labels)).repeat(512) + return dataset - data["large_emb_inputs"] = large_emb_inputs - - # Dense lookup features - small_emb_inputs = {} - for small_emb_feature in small_emb_features: - vocabulary_size = small_emb_feature["vocabulary_size"] - multi_hot_size = small_emb_feature["multi_hot_size"] - idx = small_emb_feature["name"].split("-")[-1] - - # TODO: We don't need this custom renaming. Remove later, when we - # shift from dummy data to actual data. - small_emb_inputs[f"cat_{idx}_id"] = np.random.randint( - low=0, - high=vocabulary_size, - size=(batch_size, multi_hot_size), - dtype=np.int64, + def _get_feature_spec(self): + feature_spec = { + self.label: tf.io.FixedLenFeature( + [self.batch_size], + dtype=tf.int64, + ) + } + + for dense_feat in self.dense_features: + feature_spec[dense_feat] = tf.io.FixedLenFeature( + [self.batch_size], + dtype=tf.float32, + ) + + for emb_feat in self.large_emb_features + self.small_emb_features: + name = emb_feat["name"] + feature_spec[name] = tf.io.FixedLenFeature( + [self.batch_size], + dtype=tf.string, + ) + + return feature_spec + + def _preprocess(self, example): + # Read example. + feature_spec = self.get_feature_spec() + example = tf.io.parse_single_example(example, feature_spec) + + # Dense features + dense_input = tf.stack( + [ + tf.reshape(example[dense_feature], [self.batch_size, 1]) + for dense_feature in self.dense_features + ], + axis=-1, ) - if small_emb_inputs: - data["small_emb_inputs"] = small_emb_inputs + def _get_emb_inputs(emb_features): + emb_inputs = {} + for emb_feature in emb_features: + name = emb_feature["name"] + new_name = emb_feature.get("new_name", name) + multi_hot_size = emb_feature["multi_hot_size"] + + raw_values = tf.io.decode_raw(example[name], tf.int64) + raw_values = tf.reshape( + raw_values, [self.batch_size, multi_hot_size] + ) + emb_inputs[new_name] = raw_values + return emb_inputs + + # Sparse features + large_emb_inputs = _get_emb_inputs(self.large_emb_features) + small_emb_inputs = _get_emb_inputs(self.small_emb_features) + + # Labels + labels = tf.reshape(example[self.label], [self.batch_size]) - return data + x = { + "dense_input": dense_input, + "large_emb_inputs": large_emb_inputs, + } + if small_emb_inputs: + x["small_emb_inputs"] = small_emb_inputs + return (x, labels) + + def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): + if self._return_dummy_dataset: + return self._create_dummy_dataset() + + dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) + + # Shard the dataset across hosts/workers. + # TODO: Do we need to do this if we are distributing the dataset + # manually using distribution.distribute_dataset(...)? + if num_processes > 1: + dataset = dataset.shard(num_processes, process_id) + + dataset = tf.data.TFRecordDataset( + dataset, + buffer_size=None, + num_parallel_reads=tf.data.AUTOTUNE, + ) + + # Process example. + dataset = dataset.map( + lambda x: self._preprocess(x), + num_parallel_calls=tf.data.AUTOTUNE, + ) -def create_dummy_dataset(batch_size, large_emb_features, small_emb_features): - """Creates a TF dataset from cached dummy data of the final batch size.""" - dummy_data = _get_dummy_batch( - batch_size, large_emb_features, small_emb_features - ) + # Shuffle dataset if in training mode. + if self.training and shuffle_buffer and shuffle_buffer > 0: + dataset = dataset.shuffle(shuffle_buffer) - # Separate labels from features to create a `(features, labels)` tuple. - labels = dummy_data.pop("clicked") - features = dummy_data + dataset = dataset.prefetch(tf.data.AUTOTUNE) - dataset = tf.data.Dataset.from_tensors((features, labels)).repeat(512) - return dataset + return dataset diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 5bea8c72..da7e35f8 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -8,7 +8,7 @@ import keras_rs -from .dataloader import create_dummy_dataset +from .dataloader import DataLoader from .model import DLRMDCNV2 SEED = 1337 @@ -20,6 +20,7 @@ def main( large_emb_features, small_emb_features, label, + shuffle_buffer, embedding_dim, allow_id_dropping, max_ids_per_partition, @@ -109,10 +110,18 @@ def main( # === Load dataset === print("===== Loading dataset =====") - train_ds = create_dummy_dataset( + train_ds = DataLoader( + file_pattern=file_pattern, batch_size=per_host_batch_size, + dense_features=dense_features, large_emb_features=large_emb_features, small_emb_features=small_emb_features, + label=label, + training=True, + ).create_dataset( + process_id=distribution._process_id, + num_processes=num_processes, + shuffle_buffer=shuffle_buffer, ) # For the multi-host case, the dataset has to be distributed manually. # See note here: @@ -172,6 +181,8 @@ def generator(dataset, training=False): ds_cfg = config["dataset"] # File path file_pattern = ds_cfg["file_pattern"] + # Shuffling + shuffle_buffer = ds_cfg["shuffle_buffer"] # Features label = ds_cfg["label"] dense_features = ds_cfg["dense"] @@ -219,6 +230,7 @@ def generator(dataset, training=False): large_emb_features, small_emb_features, label, + shuffle_buffer, embedding_dim, allow_id_dropping, max_ids_per_partition, diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index b34953f2..9c7a6784 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -1,4 +1,5 @@ #!/bin/bash +set -euo pipefail # ============================================================================== # Script Configuration & Argument Handling @@ -12,7 +13,7 @@ ACCELERATOR_TYPE=${1:-"v6e-8"} ZONE=${2:-"us-east5-a"} PROJECT=${3:-"tpu-prod-env-one-vm"} -USER_CONFIG_NAME=${4} # Capture the fourth argument +USER_CONFIG_NAME=${4} # Validate the provided accelerator type if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then From b50a7f5a25402030727f34ad4083816d9bc8d73e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 20:09:00 +0530 Subject: [PATCH 014/279] Add way to load real dataset (1) --- examples/ml_perf/main.py | 4 +-- examples/ml_perf/run.sh | 60 ++++++++++++++++++++++++++++------------ 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index da7e35f8..2ae6a078 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -182,7 +182,7 @@ def generator(dataset, training=False): # File path file_pattern = ds_cfg["file_pattern"] # Shuffling - shuffle_buffer = ds_cfg["shuffle_buffer"] + shuffle_buffer = ds_cfg.get("shuffle_buffer", None) # Features label = ds_cfg["label"] dense_features = ds_cfg["dense"] @@ -211,7 +211,7 @@ def generator(dataset, training=False): num_epochs = training_cfg["num_epochs"] # For features which have vocabulary_size < embedding_threshold, we can - # just do a normal dense lookup for those instead of have distributed + # just do a normal dense lookup for those instead of having distributed # embeddings. We could ideally pass `placement = default_device` to # `keras_rs.layers.TableConfig` directly (and wouldn't have to do this # separation of features), but doing it that way will necessarily require diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index 9c7a6784..aebf840a 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -4,7 +4,7 @@ set -euo pipefail # ============================================================================== # Script Configuration & Argument Handling # ============================================================================== -# This script accepts up to three optional arguments: +# This script accepts up to four optional arguments: # 1. Accelerator Type (default: v6e-8, options: v6e-8, v6e-16) # 2. Zone (default: us-east5-a) # 3. Project (default: tpu-prod-env-one-vm) @@ -13,12 +13,12 @@ set -euo pipefail ACCELERATOR_TYPE=${1:-"v6e-8"} ZONE=${2:-"us-east5-a"} PROJECT=${3:-"tpu-prod-env-one-vm"} -USER_CONFIG_NAME=${4} +USER_CONFIG_NAME=${4:-""} # Initialize with an empty string if not provided # Validate the provided accelerator type if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then echo "Error: Invalid accelerator type '${ACCELERATOR_TYPE}'." >&2 - echo "Usage: $0 [v6e-8|v6e-16] [gcp_zone] [gcp_project]" >&2 + echo "Usage: $0 [v6e-8|v6e-16] [gcp_zone] [gcp_project] [config_name]" >&2 exit 1 fi @@ -48,36 +48,62 @@ echo "--------------------------------------------------" # ============================================================================== # TPU VM Creation # ============================================================================== -echo ">>> Creating TPU VM: ${TPU_NAME} with accelerator ${ACCELERATOR_TYPE}..." -gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \ - --zone=${ZONE} \ - --accelerator-type=${ACCELERATOR_TYPE} \ - --version=v2-alpha-tpuv6e \ - --project=${PROJECT} \ - --metadata=enable-oslogin=TRUE \ - --scopes=https://www.googleapis.com/auth/cloud-platform +echo ">>> Checking for existing TPU VM: ${TPU_NAME}..." +if gcloud alpha compute tpus tpu-vm describe ${TPU_NAME} --zone=${ZONE} --project=${PROJECT} &> /dev/null; then + echo ">>> TPU VM '${TPU_NAME}' already exists. Skipping creation." +else + echo ">>> Creating TPU VM: ${TPU_NAME} with accelerator ${ACCELERATOR_TYPE}..." + gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \ + --zone=${ZONE} \ + --accelerator-type=${ACCELERATOR_TYPE} \ + --version=v2-alpha-tpuv6e \ + --project=${PROJECT} \ + --metadata=enable-oslogin=TRUE \ + --scopes=https://www.googleapis.com/auth/cloud-platform +fi # ============================================================================== -# Setup Python Virtual Environment on all workers +# Setup Python venv on all workers # ============================================================================== -echo ">>> Creating Python virtual environment..." +echo ">>> Checking for Python virtual environment..." gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="sudo apt-get update && sudo apt install -y python3.10-venv && python3 -m venv .keras-env" + --command="sudo apt-get update && sudo apt install -y python3.10-venv && if [ ! -d '.keras-env' ]; then echo '>>> Creating .keras-env...'; python3 -m venv .keras-env; else echo '>>> .keras-env already exists.'; fi" # ============================================================================== -# Clone KerasRS and Install Dependencies +# Clone/Update KerasRS and Install Dependencies # ============================================================================== -echo ">>> Cloning KerasRS and installing dependencies..." +echo ">>> Cloning or updating KerasRS and installing dependencies..." gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && git clone https://github.com/abheesht17/keras-rs.git && cd keras-rs && git checkout ml-perf && pip install -e . && pip install tensorflow-datasets && pip uninstall -y tensorflow keras && pip install git+https://github.com/keras-team/keras.git && pip install jax-tpu-embedding tensorflow-cpu" + --command=" + set -e # Ensure script exits on error + source .keras-env/bin/activate + + if [ ! -d 'keras-rs' ]; then + echo '>>> Cloning keras-rs repository...' + git clone https://github.com/abheesht17/keras-rs.git + cd keras-rs + git checkout ml-perf + else + echo '>>> keras-rs repository exists. Pulling latest changes...' + cd keras-rs + git checkout ml-perf # Ensure we are on the correct branch + git pull + fi + + echo '>>> Installing/updating dependencies...' + pip install -e . + pip uninstall -y tensorflow keras + pip install git+https://github.com/keras-team/keras.git + pip install jax-tpu-embedding tensorflow-cpu + " # ============================================================================== From fe8dc414520457a8290eb129cf204791c34a9c3e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 21:09:57 +0530 Subject: [PATCH 015/279] Add dataset path --- .../ml_perf/configs/v6e_8_full_dataset.py | 20 ++++++++ examples/ml_perf/run.sh | 50 +++++++++++++++---- 2 files changed, 60 insertions(+), 10 deletions(-) create mode 100644 examples/ml_perf/configs/v6e_8_full_dataset.py diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py new file mode 100644 index 00000000..71fd3a5a --- /dev/null +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -0,0 +1,20 @@ +from keras.utils import Config + +from .datasets.dummy_dataset import dataset_config +from .models.default_model import model_config +from .training.default_training import training_config + +config = Config() + +config.experiment_name = "v6e_8_full_dataset" +config.model_dir = "./v6e_8_full_dataset" + +config.dataset = dataset_config +config.dataset.file_pattern = ( + "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" + "train-00000-of-01024tfrecord" +) +config.model = model_config +config.training = training_config + +config.freeze() diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index aebf840a..122dd92c 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -5,20 +5,50 @@ set -euo pipefail # Script Configuration & Argument Handling # ============================================================================== # This script accepts up to four optional arguments: -# 1. Accelerator Type (default: v6e-8, options: v6e-8, v6e-16) -# 2. Zone (default: us-east5-a) -# 3. Project (default: tpu-prod-env-one-vm) -# 4. Config Name (default: derived from accelerator type, e.g., v6e_8) - -ACCELERATOR_TYPE=${1:-"v6e-8"} -ZONE=${2:-"us-east5-a"} -PROJECT=${3:-"tpu-prod-env-one-vm"} -USER_CONFIG_NAME=${4:-""} # Initialize with an empty string if not provided +# 1. --accelerator-type (default: v6e-8, options: v6e-8, v6e-16) +# 2. --zone (default: us-east5-a) +# 3. --project (default: tpu-prod-env-one-vm) +# 4. --config-name (default: derived from accelerator type, e.g., v6e_8) + +# Defaults +ACCELERATOR_TYPE="v6e-8" +ZONE="us-east5-a" +PROJECT="tpu-prod-env-one-vm" +USER_CONFIG_NAME="" + +# ============================================================================== +# Argument Parsing +# ============================================================================== + +show_help() { +cat << EOF +Usage: $0 [--accelerator-type ] [--zone ] [--project ] [--config-name ] +Options: + --accelerator-type The type of TPU accelerator (default: v6e-8). Options: v6e-8, v6e-16. + --zone The GCP zone for the TPU VM (default: us-east5-a). + --project The GCP project ID (default: tpu-prod-env-one-vm). + --config-name The specific configuration name to use for the training script. + (default: derived from accelerator type, e.g., v6e_8). + -h, --help Show this help message. +EOF +} + + +while [[ "$#" -gt 0 ]]; do + case $1 in + --accelerator-type) ACCELERATOR_TYPE="$2"; shift ;; + --zone) ZONE="$2"; shift ;; + --project) PROJECT="$2"; shift ;; + --config-name) USER_CONFIG_NAME="$2"; shift ;; + *) echo "Unknown parameter passed: $1"; show_help; exit 1 ;; + esac + shift +done # Validate the provided accelerator type if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then echo "Error: Invalid accelerator type '${ACCELERATOR_TYPE}'." >&2 - echo "Usage: $0 [v6e-8|v6e-16] [gcp_zone] [gcp_project] [config_name]" >&2 + show_help exit 1 fi From ca297d45b25af03993b31d95edf7bdda8b74ecb4 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 21:12:57 +0530 Subject: [PATCH 016/279] Dataloader fixes (1) --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 2b34f352..89556b5f 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -117,7 +117,7 @@ def _get_feature_spec(self): def _preprocess(self, example): # Read example. - feature_spec = self.get_feature_spec() + feature_spec = self._get_feature_spec() example = tf.io.parse_single_example(example, feature_spec) # Dense features From 35c3d61aa97a40b935c29c0f3dc4b15ef7b3ab82 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 21:16:59 +0530 Subject: [PATCH 017/279] Dataloader fixes (2) --- examples/ml_perf/configs/datasets/dummy_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 0cbda44f..418eedba 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -5,7 +5,7 @@ dataset_config.file_pattern = None # Features dataset_config.label = "clicked" -dataset_config.dense = [f"int-feature-{i}" for i in range(13)] +dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] dataset_config.sparse = [ { "name": "categorical-feature-14", From 9cc9b88fb7071b2fd081cf4bfdb4424e838ff665 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 21:21:52 +0530 Subject: [PATCH 018/279] Dataloader fixes (3) --- examples/ml_perf/configs/v6e_8_full_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py index 71fd3a5a..ef0f5347 100644 --- a/examples/ml_perf/configs/v6e_8_full_dataset.py +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -16,5 +16,6 @@ ) config.model = model_config config.training = training_config +config.training.batch_size = 4224 config.freeze() From a0431ba0a43182823bf1feb50b405fda01dc71d2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 19 Aug 2025 05:56:23 +0530 Subject: [PATCH 019/279] Feature naming edit --- examples/ml_perf/main.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2ae6a078..28545a42 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -47,11 +47,7 @@ def main( for large_emb_feature in large_emb_features: # Rename these features to something shorter; was facing some weird # issues with the longer names. - feature_name = ( - large_emb_feature["name"] - .replace("-", "_") - .replace("egorical_feature", "") - ) + feature_name = large_emb_feature["new_name"] vocabulary_size = large_emb_feature["vocabulary_size"] multi_hot_size = large_emb_feature["multi_hot_size"] From 2b9538f79ad8f3961b394022a36a7f57ab232e95 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 19 Aug 2025 05:56:55 +0530 Subject: [PATCH 020/279] Feature naming edit --- examples/ml_perf/main.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 28545a42..11d6c8d1 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -45,8 +45,6 @@ def main( # === Distributed embeddings' configs for sparse features === feature_configs = {} for large_emb_feature in large_emb_features: - # Rename these features to something shorter; was facing some weird - # issues with the longer names. feature_name = large_emb_feature["new_name"] vocabulary_size = large_emb_feature["vocabulary_size"] multi_hot_size = large_emb_feature["multi_hot_size"] From 187ccd50a962d31ffdf9ca9df5bbb44b96480739 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 19 Aug 2025 06:03:52 +0530 Subject: [PATCH 021/279] Feature naming edit --- examples/ml_perf/model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index a9f01f36..713e02c1 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -166,8 +166,9 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: dense_output = self.bottom_mlp(dense_input) # jax.debug.print("dense_ouput {}", dense_output.shape) large_embeddings = self.embedding_layer(large_emb_inputs) - small_embeddings = [] + small_embeddings = None if self.small_emb_features: + small_embeddings = [] small_emb_inputs = inputs["small_emb_inputs"] for small_emb_input, embedding_layer in zip( small_emb_inputs.values(), self.small_embedding_layers @@ -179,11 +180,10 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: small_embeddings = ops.concatenate(small_embeddings, axis=-1) # Interaction - x = ops.concatenate( - [dense_output, small_embeddings, *large_embeddings.values()], - axis=-1, - ) - # jax.debug.print("x {}", x.shape) + to_concatenate = [dense_output, *large_embeddings.values()] + if small_embeddings is not None: + to_concatenate += [small_embeddings] + x = ops.concatenate(to_concatenate, axis=-1) x = self.dcn_block(x) # Predictions From a98d431649a31102713ae46b6afed8ff373893af Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 19 Aug 2025 06:04:32 +0530 Subject: [PATCH 022/279] Feature naming edit --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 11d6c8d1..8a01a8c3 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -147,7 +147,7 @@ def generator(dataset, training=False): break # Train the model. - model.fit(train_generator, epochs=1) + model.fit(train_generator, epochs=num_epochs) if __name__ == "__main__": From d15957d74937eea5dd1d7dcfddc399c4f2e3391d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 09:25:00 +0530 Subject: [PATCH 023/279] Actual dataset loading fixes (1) --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 89556b5f..a085ab18 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -96,7 +96,7 @@ def _get_feature_spec(self): feature_spec = { self.label: tf.io.FixedLenFeature( [self.batch_size], - dtype=tf.int64, + dtype=tf.float32, ) } From 8870c8df77acff1498880e6c7f3213502eb96c9c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 09:42:18 +0530 Subject: [PATCH 024/279] Fix feature spec dtypes --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index a085ab18..0184ee56 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -110,7 +110,7 @@ def _get_feature_spec(self): name = emb_feat["name"] feature_spec[name] = tf.io.FixedLenFeature( [self.batch_size], - dtype=tf.string, + dtype=tf.int64, ) return feature_spec From e971f190e196fda777ad860521672f17c6d4091b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 10:24:09 +0530 Subject: [PATCH 025/279] Fix feature spec dtypes --- examples/ml_perf/dataloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 0184ee56..89556b5f 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -96,7 +96,7 @@ def _get_feature_spec(self): feature_spec = { self.label: tf.io.FixedLenFeature( [self.batch_size], - dtype=tf.float32, + dtype=tf.int64, ) } @@ -110,7 +110,7 @@ def _get_feature_spec(self): name = emb_feat["name"] feature_spec[name] = tf.io.FixedLenFeature( [self.batch_size], - dtype=tf.int64, + dtype=tf.string, ) return feature_spec From 02c1881a7d6cac530e0d17fe14846e9c272abde2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 11:13:45 +0530 Subject: [PATCH 026/279] Allow different batch sizes from file batch size --- .../ml_perf/configs/v6e_8_full_dataset.py | 4 +++- examples/ml_perf/dataloader.py | 24 ++++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py index ef0f5347..eca6c7ed 100644 --- a/examples/ml_perf/configs/v6e_8_full_dataset.py +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -14,8 +14,10 @@ "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" "train-00000-of-01024tfrecord" ) +# The path which we are reading from already has the batched dataset. +config.dataset.file_batch_size = 4224 config.model = model_config config.training = training_config -config.training.batch_size = 4224 +config.training.batch_size = 256 config.freeze() diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 89556b5f..bfa1a963 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -7,6 +7,7 @@ def __init__( self, file_pattern, batch_size, + file_batch_size, dense_features, large_emb_features, small_emb_features, @@ -16,6 +17,7 @@ def __init__( # Passed attributes. self.file_pattern = file_pattern self.batch_size = batch_size + self.file_batch_size = file_batch_size self.dense_features = dense_features self.large_emb_features = large_emb_features self.small_emb_features = small_emb_features @@ -95,21 +97,21 @@ def _create_dummy_dataset(self): def _get_feature_spec(self): feature_spec = { self.label: tf.io.FixedLenFeature( - [self.batch_size], + [self.file_batch_size], dtype=tf.int64, ) } for dense_feat in self.dense_features: feature_spec[dense_feat] = tf.io.FixedLenFeature( - [self.batch_size], + [self.file_batch_size], dtype=tf.float32, ) for emb_feat in self.large_emb_features + self.small_emb_features: name = emb_feat["name"] feature_spec[name] = tf.io.FixedLenFeature( - [self.batch_size], + [self.file_batch_size], dtype=tf.string, ) @@ -123,7 +125,7 @@ def _preprocess(self, example): # Dense features dense_input = tf.stack( [ - tf.reshape(example[dense_feature], [self.batch_size, 1]) + tf.reshape(example[dense_feature], [self.file_batch_size, 1]) for dense_feature in self.dense_features ], axis=-1, @@ -138,7 +140,7 @@ def _get_emb_inputs(emb_features): raw_values = tf.io.decode_raw(example[name], tf.int64) raw_values = tf.reshape( - raw_values, [self.batch_size, multi_hot_size] + raw_values, [self.file_batch_size, multi_hot_size] ) emb_inputs[new_name] = raw_values return emb_inputs @@ -148,7 +150,7 @@ def _get_emb_inputs(emb_features): small_emb_inputs = _get_emb_inputs(self.small_emb_features) # Labels - labels = tf.reshape(example[self.label], [self.batch_size]) + labels = tf.reshape(example[self.label], [self.file_batch_size]) x = { "dense_input": dense_input, @@ -179,14 +181,20 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Process example. dataset = dataset.map( - lambda x: self._preprocess(x), - num_parallel_calls=tf.data.AUTOTUNE, + self._preprocess, num_parallel_calls=tf.data.AUTOTUNE ) + dataset.unbatch() # Shuffle dataset if in training mode. if self.training and shuffle_buffer and shuffle_buffer > 0: dataset = dataset.shuffle(shuffle_buffer) + dataset = dataset.batch( + self.batch_size, + drop_remainder=True, + num_parallel_calls=tf.data.AUTOTUNE, + ) + dataset = dataset.prefetch(tf.data.AUTOTUNE) return dataset From b5db3043ffe9a7d9e36d1c6dd238dcc476de68fe Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 11:18:16 +0530 Subject: [PATCH 027/279] Allow different batch sizes from file batch size (fixes) --- examples/ml_perf/main.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 8a01a8c3..102fac26 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -32,6 +32,7 @@ def main( dcn_projection_dim, learning_rate, global_batch_size, + file_batch_size, num_epochs, ): # Set DDP as Keras distribution strategy @@ -107,6 +108,7 @@ def main( train_ds = DataLoader( file_pattern=file_pattern, batch_size=per_host_batch_size, + file_batch_size=file_batch_size, dense_features=dense_features, large_emb_features=large_emb_features, small_emb_features=small_emb_features, @@ -175,6 +177,8 @@ def generator(dataset, training=False): ds_cfg = config["dataset"] # File path file_pattern = ds_cfg["file_pattern"] + # File batch size + file_batch_size = ds_cfg.get("file_batch_size", None) # Shuffling shuffle_buffer = ds_cfg.get("shuffle_buffer", None) # Features @@ -236,5 +240,6 @@ def generator(dataset, training=False): dcn_projection_dim, learning_rate, global_batch_size, + file_batch_size, num_epochs, ) From e98bdf94177a4f04384ebc940fd6ee1015366d60 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 11:27:44 +0530 Subject: [PATCH 028/279] Fix feature naming --- .../ml_perf/configs/datasets/dummy_dataset.py | 52 +++++++++---------- examples/ml_perf/dataloader.py | 6 +-- examples/ml_perf/main.py | 2 +- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 418eedba..411b3b41 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -11,156 +11,156 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "multi_hot_size": 3, - "new_name": "cat_14_id", + "new_name": "cat_14", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "multi_hot_size": 2, - "new_name": "cat_15_id", + "new_name": "cat_15", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "multi_hot_size": 1, - "new_name": "cat_16_id", + "new_name": "cat_16", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "multi_hot_size": 2, - "new_name": "cat_17_id", + "new_name": "cat_17", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "multi_hot_size": 6, - "new_name": "cat_18_id", + "new_name": "cat_18", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "multi_hot_size": 1, - "new_name": "cat_19_id", + "new_name": "cat_19", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "multi_hot_size": 1, - "new_name": "cat_20_id", + "new_name": "cat_20", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "multi_hot_size": 1, - "new_name": "cat_21_id", + "new_name": "cat_21", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "multi_hot_size": 1, - "new_name": "cat_22_id", + "new_name": "cat_22", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "multi_hot_size": 7, - "new_name": "cat_23_id", + "new_name": "cat_23", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "multi_hot_size": 3, - "new_name": "cat_24_id", + "new_name": "cat_24", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "multi_hot_size": 8, - "new_name": "cat_25_id", + "new_name": "cat_25", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "multi_hot_size": 1, - "new_name": "cat_26_id", + "new_name": "cat_26", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "multi_hot_size": 6, - "new_name": "cat_27_id", + "new_name": "cat_27", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "multi_hot_size": 9, - "new_name": "cat_28_id", + "new_name": "cat_28", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "multi_hot_size": 5, - "new_name": "cat_29_id", + "new_name": "cat_29", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "multi_hot_size": 1, - "new_name": "cat_30_id", + "new_name": "cat_30", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "multi_hot_size": 1, - "new_name": "cat_31_id", + "new_name": "cat_31", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "multi_hot_size": 1, - "new_name": "cat_32_id", + "new_name": "cat_32", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "multi_hot_size": 12, - "new_name": "cat_33_id", + "new_name": "cat_33", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "multi_hot_size": 100, - "new_name": "cat_34_id", + "new_name": "cat_34", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "multi_hot_size": 27, - "new_name": "cat_35_id", + "new_name": "cat_35", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "multi_hot_size": 10, - "new_name": "cat_36_id", + "new_name": "cat_36", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "multi_hot_size": 3, - "new_name": "cat_37_id", + "new_name": "cat_37", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "multi_hot_size": 1, - "new_name": "cat_38_id", + "new_name": "cat_38", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "multi_hot_size": 1, - "new_name": "cat_39_id", + "new_name": "cat_39", }, ] diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index bfa1a963..fdf836a1 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -54,7 +54,7 @@ def _get_dummy_batch(self): vocabulary_size = large_emb_feature["vocabulary_size"] multi_hot_size = large_emb_feature["multi_hot_size"] - large_emb_inputs[new_name] = np.random.randint( + large_emb_inputs[f"{new_name}_id"] = np.random.randint( low=0, high=vocabulary_size, size=(self.batch_size, multi_hot_size), @@ -71,7 +71,7 @@ def _get_dummy_batch(self): vocabulary_size = small_emb_feature["vocabulary_size"] multi_hot_size = small_emb_feature["multi_hot_size"] - small_emb_inputs[new_name] = np.random.randint( + small_emb_inputs[f"{new_name}_id"] = np.random.randint( low=0, high=vocabulary_size, size=(self.batch_size, multi_hot_size), @@ -142,7 +142,7 @@ def _get_emb_inputs(emb_features): raw_values = tf.reshape( raw_values, [self.file_batch_size, multi_hot_size] ) - emb_inputs[new_name] = raw_values + emb_inputs[f"{new_name}_id"] = raw_values return emb_inputs # Sparse features diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 102fac26..2658f98c 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -72,7 +72,7 @@ def main( max_unique_ids_per_partition=max_unique_ids_per_partition, ) feature_configs[f"{feature_name}_id"] = keras_rs.layers.FeatureConfig( - name=feature_name.replace("id", ""), + name=feature_name, table=table_config, # TODO: Verify whether it should be `(bsz, 1)` or # `(bsz, multi_hot_size)`. From 73ca47710afb3966c6ff9795277b00d2c34dc920 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 12:55:16 +0530 Subject: [PATCH 029/279] Fix batching --- examples/ml_perf/configs/v6e_8_full_dataset.py | 4 ++++ examples/ml_perf/dataloader.py | 2 +- examples/ml_perf/main.py | 5 +++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py index eca6c7ed..8489b084 100644 --- a/examples/ml_perf/configs/v6e_8_full_dataset.py +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -14,6 +14,10 @@ "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" "train-00000-of-01024tfrecord" ) +config.dataset.val_file_pattern = ( + "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" + "train-00000-of-01024tfrecord" +) # The path which we are reading from already has the batched dataset. config.dataset.file_batch_size = 4224 config.model = model_config diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index fdf836a1..5a14257b 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -183,7 +183,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.map( self._preprocess, num_parallel_calls=tf.data.AUTOTUNE ) - dataset.unbatch() + dataset = dataset.unbatch() # Shuffle dataset if in training mode. if self.training and shuffle_buffer and shuffle_buffer > 0: diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2658f98c..ce12751b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -16,6 +16,7 @@ def main( file_pattern, + val_file_pattern, dense_features, large_emb_features, small_emb_features, @@ -124,6 +125,7 @@ def main( # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. if num_processes > 1: train_ds = distribution.distribute_dataset(train_ds) + # eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False def generator(dataset, training=False): @@ -144,6 +146,7 @@ def generator(dataset, training=False): yield (x, y) train_generator = generator(train_ds, training=True) + # eval_generator = generator(eval_ds, training=False) for first_batch in train_generator: model(first_batch[0]) break @@ -177,6 +180,7 @@ def generator(dataset, training=False): ds_cfg = config["dataset"] # File path file_pattern = ds_cfg["file_pattern"] + val_file_pattern = ds_cfg("val_file_pattern", None) # File batch size file_batch_size = ds_cfg.get("file_batch_size", None) # Shuffling @@ -224,6 +228,7 @@ def generator(dataset, training=False): main( file_pattern, + val_file_pattern, dense_features, large_emb_features, small_emb_features, From c2ad8a971a515fd62eb3d28f87c27c9f8c1bbb6a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 12:59:52 +0530 Subject: [PATCH 030/279] Fix batching --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index ce12751b..e7a5f735 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -180,7 +180,7 @@ def generator(dataset, training=False): ds_cfg = config["dataset"] # File path file_pattern = ds_cfg["file_pattern"] - val_file_pattern = ds_cfg("val_file_pattern", None) + val_file_pattern = ds_cfg.get("val_file_pattern", None) # File batch size file_batch_size = ds_cfg.get("file_batch_size", None) # Shuffling From a0801439f3a9b602613ec6186c9840ef9a6298a3 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 13:48:24 +0530 Subject: [PATCH 031/279] Fix dense features --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 5a14257b..e0eafc7f 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -123,7 +123,7 @@ def _preprocess(self, example): example = tf.io.parse_single_example(example, feature_spec) # Dense features - dense_input = tf.stack( + dense_input = tf.concatenate( [ tf.reshape(example[dense_feature], [self.file_batch_size, 1]) for dense_feature in self.dense_features From 28b7189f187d6dfbb7af73f34e7e5b2d2537337f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 13:51:32 +0530 Subject: [PATCH 032/279] Fix dense features concat --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index e0eafc7f..9f4df58d 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -123,7 +123,7 @@ def _preprocess(self, example): example = tf.io.parse_single_example(example, feature_spec) # Dense features - dense_input = tf.concatenate( + dense_input = tf.concat( [ tf.reshape(example[dense_feature], [self.file_batch_size, 1]) for dense_feature in self.dense_features From a47817d6eda032f385380092e90bc72a22418c1b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 14:21:01 +0530 Subject: [PATCH 033/279] Rename multi_hot_size to feature_list_length --- .../ml_perf/configs/datasets/dummy_dataset.py | 52 +++++++++---------- examples/ml_perf/dataloader.py | 12 ++--- examples/ml_perf/main.py | 6 +-- 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 411b3b41..510f8b62 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -10,157 +10,157 @@ { "name": "categorical-feature-14", "vocabulary_size": 40000000, - "multi_hot_size": 3, + "feature_list_length": 3, "new_name": "cat_14", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, - "multi_hot_size": 2, + "feature_list_length": 2, "new_name": "cat_15", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_16", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, - "multi_hot_size": 2, + "feature_list_length": 2, "new_name": "cat_17", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, - "multi_hot_size": 6, + "feature_list_length": 6, "new_name": "cat_18", }, { "name": "categorical-feature-19", "vocabulary_size": 3, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_19", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_20", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_21", }, { "name": "categorical-feature-22", "vocabulary_size": 63, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_22", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, - "multi_hot_size": 7, + "feature_list_length": 7, "new_name": "cat_23", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, - "multi_hot_size": 3, + "feature_list_length": 3, "new_name": "cat_24", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, - "multi_hot_size": 8, + "feature_list_length": 8, "new_name": "cat_25", }, { "name": "categorical-feature-26", "vocabulary_size": 10, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_26", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, - "multi_hot_size": 6, + "feature_list_length": 6, "new_name": "cat_27", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, - "multi_hot_size": 9, + "feature_list_length": 9, "new_name": "cat_28", }, { "name": "categorical-feature-29", "vocabulary_size": 155, - "multi_hot_size": 5, + "feature_list_length": 5, "new_name": "cat_29", }, { "name": "categorical-feature-30", "vocabulary_size": 4, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_30", }, { "name": "categorical-feature-31", "vocabulary_size": 976, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_31", }, { "name": "categorical-feature-32", "vocabulary_size": 14, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_32", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, - "multi_hot_size": 12, + "feature_list_length": 12, "new_name": "cat_33", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, - "multi_hot_size": 100, + "feature_list_length": 100, "new_name": "cat_34", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, - "multi_hot_size": 27, + "feature_list_length": 27, "new_name": "cat_35", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, - "multi_hot_size": 10, + "feature_list_length": 10, "new_name": "cat_36", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, - "multi_hot_size": 3, + "feature_list_length": 3, "new_name": "cat_37", }, { "name": "categorical-feature-38", "vocabulary_size": 108, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_38", }, { "name": "categorical-feature-39", "vocabulary_size": 36, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_39", }, ] diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 9f4df58d..5d65c49c 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -52,12 +52,12 @@ def _get_dummy_batch(self): name = large_emb_feature["name"] new_name = large_emb_feature.get("new_name", name) vocabulary_size = large_emb_feature["vocabulary_size"] - multi_hot_size = large_emb_feature["multi_hot_size"] + feature_list_length = large_emb_feature["feature_list_length"] large_emb_inputs[f"{new_name}_id"] = np.random.randint( low=0, high=vocabulary_size, - size=(self.batch_size, multi_hot_size), + size=(self.batch_size, feature_list_length), dtype=np.int64, ) @@ -69,12 +69,12 @@ def _get_dummy_batch(self): name = small_emb_feature["name"] new_name = small_emb_feature.get("new_name", name) vocabulary_size = small_emb_feature["vocabulary_size"] - multi_hot_size = small_emb_feature["multi_hot_size"] + feature_list_length = small_emb_feature["feature_list_length"] small_emb_inputs[f"{new_name}_id"] = np.random.randint( low=0, high=vocabulary_size, - size=(self.batch_size, multi_hot_size), + size=(self.batch_size, feature_list_length), dtype=np.int64, ) @@ -136,11 +136,11 @@ def _get_emb_inputs(emb_features): for emb_feature in emb_features: name = emb_feature["name"] new_name = emb_feature.get("new_name", name) - multi_hot_size = emb_feature["multi_hot_size"] + feature_list_length = emb_feature["feature_list_length"] raw_values = tf.io.decode_raw(example[name], tf.int64) raw_values = tf.reshape( - raw_values, [self.file_batch_size, multi_hot_size] + raw_values, [self.file_batch_size, feature_list_length] ) emb_inputs[f"{new_name}_id"] = raw_values return emb_inputs diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index e7a5f735..a6bce733 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -49,7 +49,7 @@ def main( for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] vocabulary_size = large_emb_feature["vocabulary_size"] - multi_hot_size = large_emb_feature["multi_hot_size"] + feature_list_length = large_emb_feature["feature_list_length"] table_config = keras_rs.layers.TableConfig( name=f"{feature_name}_table", @@ -76,8 +76,8 @@ def main( name=feature_name, table=table_config, # TODO: Verify whether it should be `(bsz, 1)` or - # `(bsz, multi_hot_size)`. - input_shape=(per_host_batch_size, multi_hot_size), + # `(bsz, feature_list_length)`. + input_shape=(per_host_batch_size, feature_list_length), output_shape=(per_host_batch_size, embedding_dim), ) From 9a33f091d31bdfc06acf8b6c60d4e5a85cb225d6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 14:24:35 +0530 Subject: [PATCH 034/279] Rename sparse to lookup --- examples/ml_perf/configs/datasets/dummy_dataset.py | 2 +- examples/ml_perf/dataloader.py | 6 +++--- examples/ml_perf/main.py | 12 ++++++------ examples/ml_perf/model.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 510f8b62..aac66c57 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -6,7 +6,7 @@ # Features dataset_config.label = "clicked" dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] -dataset_config.sparse = [ +dataset_config.lookup = [ { "name": "categorical-feature-14", "vocabulary_size": 40000000, diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 5d65c49c..ce2e7286 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -46,7 +46,7 @@ def _get_dummy_batch(self): ] data["dense_input"] = np.concatenate(dense_input_list, axis=-1) - # Sparse features + # Big embedding features large_emb_inputs = {} for large_emb_feature in self.large_emb_features: name = large_emb_feature["name"] @@ -63,7 +63,7 @@ def _get_dummy_batch(self): data["large_emb_inputs"] = large_emb_inputs - # Dense lookup features + # Small embedding features small_emb_inputs = {} for small_emb_feature in self.small_emb_features: name = small_emb_feature["name"] @@ -145,7 +145,7 @@ def _get_emb_inputs(emb_features): emb_inputs[f"{new_name}_id"] = raw_values return emb_inputs - # Sparse features + # Embedding/lookup features large_emb_inputs = _get_emb_inputs(self.large_emb_features) small_emb_inputs = _get_emb_inputs(self.small_emb_features) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index a6bce733..a0d2eac7 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -44,7 +44,7 @@ def main( per_host_batch_size = global_batch_size // num_processes - # === Distributed embeddings' configs for sparse features === + # === Distributed embeddings' configs for lookup features === feature_configs = {} for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] @@ -82,9 +82,9 @@ def main( ) # === Instantiate model === - # We instantiate the model first, because we need to preprocess sparse - # inputs using the distributed embedding layer defined inside the model - # class. + # We instantiate the model first, because we need to preprocess large + # embedding feature inputs using the distributed embedding layer defined + # inside the model class. print("===== Initialising model =====") model = DLRMDCNV2( large_emb_feature_configs=feature_configs, @@ -130,7 +130,7 @@ def main( def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses - sparse features. + large embedding features. """ for features, labels in dataset: preprocessed_large_embeddings = model.embedding_layer.preprocess( @@ -188,7 +188,7 @@ def generator(dataset, training=False): # Features label = ds_cfg["label"] dense_features = ds_cfg["dense"] - emb_features = ds_cfg["sparse"] + emb_features = ds_cfg["lookup"] # == Model config == model_cfg = config["model"] diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 713e02c1..4f84bbbf 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -56,7 +56,7 @@ def __init__( The model processes two types of input features: 1. Dense Features: Continuous-valued features that are processed by a multi-layer perceptron (the "bottom MLP"). - 2. Sparse Features: High-cardinality categorical features that are + 2. Lookup Features: High-cardinality categorical features that are first mapped into low-dimensional embedding vectors using the `keras_rs.layers.DistributedEmbedding` layer. This layer is highly optimized for large-scale recommendation models, especially on TPUs From a66e1c6fd15ecf12e7aa4f3db2a902bd4a2c54f1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 19:04:33 +0530 Subject: [PATCH 035/279] Debug --- examples/ml_perf/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index a0d2eac7..2d7504e2 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -80,6 +80,8 @@ def main( input_shape=(per_host_batch_size, feature_list_length), output_shape=(per_host_batch_size, embedding_dim), ) + + print("-->", os.environ['XLA_FLAGS']) # === Instantiate model === # We instantiate the model first, because we need to preprocess large From a56532a196beaf88fffda3c9f6279b88e449da34 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 19:13:51 +0530 Subject: [PATCH 036/279] Try out XLA flags --- examples/ml_perf/main.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2d7504e2..43e06320 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -45,6 +45,13 @@ def main( per_host_batch_size = global_batch_size // num_processes # === Distributed embeddings' configs for lookup features === + # Set XLA flags. + os.environ['XLA_FLAGS'] = ( + "--xla_sparse_core_max_ids_per_partition_per_sample=" + f"{max_ids_per_partition} " + "--xla_sparse_core_max_unique_ids_per_partition_per_sample=" + f"{max_unique_ids_per_partition}" + ) feature_configs = {} for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] @@ -80,8 +87,6 @@ def main( input_shape=(per_host_batch_size, feature_list_length), output_shape=(per_host_batch_size, embedding_dim), ) - - print("-->", os.environ['XLA_FLAGS']) # === Instantiate model === # We instantiate the model first, because we need to preprocess large From 0a9d00b8df5667ff37094d603420c9b04e616c72 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 19:20:00 +0530 Subject: [PATCH 037/279] Format --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 43e06320..c085c1f0 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -46,7 +46,7 @@ def main( # === Distributed embeddings' configs for lookup features === # Set XLA flags. - os.environ['XLA_FLAGS'] = ( + os.environ["XLA_FLAGS"] = ( "--xla_sparse_core_max_ids_per_partition_per_sample=" f"{max_ids_per_partition} " "--xla_sparse_core_max_unique_ids_per_partition_per_sample=" From 42c40223e7ebd01d88dcc477884b6c5a869b301d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 19:23:32 +0530 Subject: [PATCH 038/279] Format --- examples/ml_perf/main.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index c085c1f0..8dbb38cb 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -135,6 +135,10 @@ def main( # eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False + # Print one sample. + for element in train_ds.take(1): + print(">>> train sample", element[0]) + def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses large embedding features. From 6e76b59eb2f77f47763fbeb789aee74d3ee195a7 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 3 Sep 2025 19:57:13 +0530 Subject: [PATCH 039/279] Copy over Antonio's fixes --- .../embedding/jax/distributed_embedding.py | 204 +++--------------- .../layers/embedding/jax/embedding_utils.py | 90 ++++++-- 2 files changed, 102 insertions(+), 192 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 2562a8be..72f504af 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -15,7 +15,6 @@ table_stacking as jte_table_stacking, ) from jax_tpu_embedding.sparsecore.utils import utils as jte_utils -from keras.src import backend from keras_rs.src import types from keras_rs.src.layers.embedding import base_distributed_embedding @@ -247,23 +246,6 @@ def _create_sparsecore_distribution( ) return sparsecore_distribution, sparsecore_layout - def _create_cpu_distribution( - self, cpu_axis_name: str = "cpu" - ) -> tuple[ - keras.distribution.ModelParallel, keras.distribution.TensorLayout - ]: - """Share a variable across all CPU processes.""" - cpu_devices = jax.devices("cpu") - device_mesh = keras.distribution.DeviceMesh( - (len(cpu_devices),), [cpu_axis_name], cpu_devices - ) - replicated_layout = keras.distribution.TensorLayout([], device_mesh) - layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh) - cpu_distribution = keras.distribution.ModelParallel( - layout_map=layout_map - ) - return cpu_distribution, replicated_layout - def _add_sparsecore_weight( self, name: str, @@ -405,11 +387,6 @@ def sparsecore_build( self._sparsecore_layout = sparsecore_layout self._sparsecore_distribution = sparsecore_distribution - # Distribution for CPU operations. - cpu_distribution, cpu_layout = self._create_cpu_distribution() - self._cpu_distribution = cpu_distribution - self._cpu_layout = cpu_layout - mesh = sparsecore_distribution.device_mesh.backend_mesh global_device_count = mesh.devices.size num_sc_per_device = jte_utils.num_sparsecores_per_device( @@ -466,10 +443,6 @@ def sparsecore_build( # Collect all stacked tables. table_specs = embedding_utils.get_table_specs(feature_specs) table_stacks = embedding_utils.get_table_stacks(table_specs) - stacked_table_specs = { - stack_name: stack[0].stacked_table_spec - for stack_name, stack in table_stacks.items() - } # Create variables for all stacked tables and slot variables. with sparsecore_distribution.scope(): @@ -502,50 +475,6 @@ def sparsecore_build( ) self._iterations.overwrite_with_gradient = True - with cpu_distribution.scope(): - # Create variables to track static buffer size and max IDs for each - # table during preprocessing. These variables are shared across all - # processes on CPU. We don't add these via `add_weight` because we - # can't have them passed to the training function. - replicated_zeros_initializer = ShardedInitializer( - "zeros", cpu_layout - ) - - with backend.name_scope(self.name, caller=self): - self._preprocessing_buffer_size = { - table_name: backend.Variable( - initializer=replicated_zeros_initializer, - shape=(), - dtype=backend.standardize_dtype("int32"), - trainable=False, - name=table_name + ":preprocessing:buffer_size", - ) - for table_name in stacked_table_specs.keys() - } - self._preprocessing_max_unique_ids_per_partition = { - table_name: backend.Variable( - shape=(), - name=table_name - + ":preprocessing:max_unique_ids_per_partition", - initializer=replicated_zeros_initializer, - dtype=backend.standardize_dtype("int32"), - trainable=False, - ) - for table_name in stacked_table_specs.keys() - } - - self._preprocessing_max_ids_per_partition = { - table_name: backend.Variable( - shape=(), - name=table_name - + ":preprocessing:max_ids_per_partition", - initializer=replicated_zeros_initializer, - dtype=backend.standardize_dtype("int32"), - trainable=False, - ) - for table_name in stacked_table_specs.keys() - } - self._config = jte_embedding_lookup.EmbeddingLookupConfiguration( feature_specs, mesh=mesh, @@ -660,76 +589,35 @@ def _sparsecore_preprocess( mesh.devices.item(0) ) - # Get current buffer size/max_ids. - previous_max_ids_per_partition = keras.tree.map_structure( - lambda max_ids_per_partition: max_ids_per_partition.value.item(), - self._preprocessing_max_ids_per_partition, - ) - previous_max_unique_ids_per_partition = keras.tree.map_structure( - lambda max_unique_ids_per_partition: ( - max_unique_ids_per_partition.value.item() - ), - self._preprocessing_max_unique_ids_per_partition, - ) - previous_buffer_size = keras.tree.map_structure( - lambda buffer_size: buffer_size.value.item(), - self._preprocessing_buffer_size, - ) - preprocessed, stats = embedding_utils.stack_and_shard_samples( self._config.feature_specs, samples, local_device_count, global_device_count, num_sc_per_device, - static_buffer_size=previous_buffer_size, ) - # Extract max unique IDs and buffer sizes. - # We need to replicate this value across all local CPU devices. if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + prev_stats = embedding_utils.get_stacked_table_stats( + self._config.feature_specs + ) + + # Take the maximum with existing stats. + stats = keras.tree.map_structure(max, prev_stats, stats) + + # Flatten the stats so we can more efficiently transfer them + # between hosts. We use jax.tree because we will later need to + # unflatten. + flat_stats, stats_treedef = jax.tree.flatten(stats) + + # In the case of multiple local CPU devices per host, we need to + # replicate the stats to placate JAX collectives. num_local_cpu_devices = jax.local_device_count("cpu") - local_max_ids_per_partition = { - table_name: np.repeat( - # Maximum across all partitions and previous max. - np.maximum( - np.max(elems), - previous_max_ids_per_partition[table_name], - ), - num_local_cpu_devices, - ) - for table_name, elems in stats.max_ids_per_partition.items() - } - local_max_unique_ids_per_partition = { - name: np.repeat( - # Maximum across all partitions and previous max. - np.maximum( - np.max(elems), - previous_max_unique_ids_per_partition[name], - ), - num_local_cpu_devices, - ) - for name, elems in stats.max_unique_ids_per_partition.items() - } - local_buffer_size = { - table_name: np.repeat( - np.maximum( - np.max( - # Round values up to the next multiple of 8. - # Currently using this as a proxy for the actual - # required buffer size. - ((elems + 7) // 8) * 8 - ) - * global_device_count - * num_sc_per_device - * local_device_count - * num_sc_per_device, - previous_buffer_size[table_name], - ), - num_local_cpu_devices, - ) - for table_name, elems in stats.max_ids_per_partition.items() - } + tiled_stats = np.tile( + np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1) + ) # Aggregate variables across all processes/devices. max_across_cpus = jax.pmap( @@ -737,48 +625,24 @@ def _sparsecore_preprocess( x, "all_cpus" ), axis_name="all_cpus", - devices=self._cpu_layout.device_mesh.backend_mesh.devices, - ) - new_max_ids_per_partition = max_across_cpus( - local_max_ids_per_partition - ) - new_max_unique_ids_per_partition = max_across_cpus( - local_max_unique_ids_per_partition + backend="cpu", ) - new_buffer_size = max_across_cpus(local_buffer_size) - - # Assign new preprocessing parameters. - with self._cpu_distribution.scope(): - # For each process, all max ids/buffer sizes are replicated - # across all local devices. Take the value from the first - # device. - keras.tree.map_structure( - lambda var, values: var.assign(values[0]), - self._preprocessing_max_ids_per_partition, - new_max_ids_per_partition, - ) - keras.tree.map_structure( - lambda var, values: var.assign(values[0]), - self._preprocessing_max_unique_ids_per_partition, - new_max_unique_ids_per_partition, - ) - keras.tree.map_structure( - lambda var, values: var.assign(values[0]), - self._preprocessing_buffer_size, - new_buffer_size, - ) - # Update parameters in the underlying feature specs. - int_max_ids_per_partition = keras.tree.map_structure( - lambda varray: varray.item(), new_max_ids_per_partition - ) - int_max_unique_ids_per_partition = keras.tree.map_structure( - lambda varray: varray.item(), - new_max_unique_ids_per_partition, + flat_stats = max_across_cpus(tiled_stats)[0].tolist() + stats = jax.tree.unflatten(stats_treedef, flat_stats) + + # Update configuration and repeat preprocessing if stats changed. + if stats != prev_stats: + embedding_utils.update_stacked_table_stats( + self._config.feature_specs, stats ) - embedding_utils.update_stacked_table_specs( + + # Re-execute preprocessing with consistent input statistics. + preprocessed, _ = embedding_utils.stack_and_shard_samples( self._config.feature_specs, - int_max_ids_per_partition, - int_max_unique_ids_per_partition, + samples, + local_device_count, + global_device_count, + num_sc_per_device, ) return {"inputs": preprocessed} diff --git a/keras_rs/src/layers/embedding/jax/embedding_utils.py b/keras_rs/src/layers/embedding/jax/embedding_utils.py index 393c197c..38e69f7d 100644 --- a/keras_rs/src/layers/embedding/jax/embedding_utils.py +++ b/keras_rs/src/layers/embedding/jax/embedding_utils.py @@ -35,6 +35,12 @@ class ShardedCooMatrix(NamedTuple): values: ArrayLike +class InputStatsPerTable(NamedTuple): + max_ids_per_partition: int + max_unique_ids_per_partition: int + required_buffer_size_per_device: int + + def _round_up_to_multiple(value: int, multiple: int) -> int: return ((value + multiple - 1) // multiple) * multiple @@ -335,19 +341,47 @@ def get_table_stacks( return stacked_table_specs -def update_stacked_table_specs( +def get_stacked_table_stats( feature_specs: Nested[FeatureSpec], - max_ids_per_partition: Mapping[str, int], - max_unique_ids_per_partition: Mapping[str, int], +) -> dict[str, InputStatsPerTable]: + """Extracts the stacked-table input statistics from the feature specs. + + Args: + feature_specs: Feature specs from which to extracts the statistics. + + Returns: + A mapping of stacked table names to input statistics per table. + """ + stacked_table_specs: dict[str, StackedTableSpec] = {} + for feature_spec in jax.tree.flatten(feature_specs)[0]: + feature_spec = typing.cast(FeatureSpec, feature_spec) + stacked_table_spec = typing.cast( + StackedTableSpec, feature_spec.table_spec.stacked_table_spec + ) + stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec + + stats: dict[str, InputStatsPerTable] = {} + for stacked_table_spec in stacked_table_specs.values(): + buffer_size = stacked_table_spec.suggested_coo_buffer_size_per_device + buffer_size = buffer_size or 0 + stats[stacked_table_spec.stack_name] = InputStatsPerTable( + max_ids_per_partition=stacked_table_spec.max_ids_per_partition, + max_unique_ids_per_partition=stacked_table_spec.max_unique_ids_per_partition, + required_buffer_size_per_device=buffer_size, + ) + + return stats + + +def update_stacked_table_stats( + feature_specs: Nested[FeatureSpec], + stats: Mapping[str, InputStatsPerTable], ) -> None: - """Updates properties in the supplied feature specs. + """Updates stacked-table input properties in the supplied feature specs. Args: feature_specs: Feature specs to update in-place. - max_ids_per_partition: Mapping of table stack name to - new `max_ids_per_partition` for the stack. - max_unique_ids_per_partition: Mapping of table stack name to - new `max_unique_ids_per_partition` for the stack. + stats: Per-stacked-table input statistics. """ # Collect table specs and stacked table specs. table_specs: dict[str, TableSpec] = {} @@ -363,18 +397,17 @@ def update_stacked_table_specs( stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec # Replace fields in the stacked_table_specs. - stacked_table_specs = { - stack_name: dataclasses.replace( + stack_names = stacked_table_specs.keys() + for stack_name in stack_names: + stack_stats = stats[stack_name] + stacked_table_spec = stacked_table_specs[stack_name] + buffer_size = stack_stats.required_buffer_size_per_device or None + stacked_table_specs[stack_name] = dataclasses.replace( stacked_table_spec, - max_ids_per_partition=max_ids_per_partition[ - stacked_table_spec.stack_name - ], - max_unique_ids_per_partition=max_unique_ids_per_partition[ - stacked_table_spec.stack_name - ], + max_ids_per_partition=stack_stats.max_ids_per_partition, + max_unique_ids_per_partition=stack_stats.max_unique_ids_per_partition, + suggested_coo_buffer_size_per_device=buffer_size, ) - for stack_name, stacked_table_spec in stacked_table_specs.items() - } # Insert new stacked tables into tables. for table_spec in table_specs.values(): @@ -534,7 +567,7 @@ def stack_and_shard_samples( global_device_count: int, num_sc_per_device: int, static_buffer_size: int | Mapping[str, int] | None = None, -) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]: +) -> tuple[dict[str, ShardedCooMatrix], dict[str, InputStatsPerTable]]: """Prepares input samples for use in embedding lookups. Args: @@ -544,8 +577,8 @@ def stack_and_shard_samples( global_device_count: Number of global JAX devices. num_sc_per_device: Number of sparsecores per device. static_buffer_size: The static buffer size to use for the samples. - Defaults to None, in which case an upper-bound for the buffer size - will be automatically determined. + Defaults to None, in which case an upper-bound for the buffer size + will be automatically determined. Returns: The preprocessed inputs, and statistics useful for updating FeatureSpecs @@ -579,6 +612,7 @@ def collect_tokens_and_weights( ) out: dict[str, ShardedCooMatrix] = {} + out_stats: dict[str, InputStatsPerTable] = {} tables_names = preprocessed_inputs.lhs_row_pointers.keys() for table_name in tables_names: shard_ends = preprocessed_inputs.lhs_row_pointers[table_name] @@ -592,5 +626,17 @@ def collect_tokens_and_weights( row_ids=preprocessed_inputs.lhs_sample_ids[table_name], values=preprocessed_inputs.lhs_gains[table_name], ) + out_stats[table_name] = InputStatsPerTable( + max_ids_per_partition=np.max( + stats.max_ids_per_partition[table_name] + ), + max_unique_ids_per_partition=np.max( + stats.max_unique_ids_per_partition[table_name] + ), + required_buffer_size_per_device=np.max( + stats.required_buffer_size_per_sc[table_name] + ) + * num_sc_per_device, + ) - return out, stats + return out, out_stats From ed0372761c16965de3d20e4609ed9c04b1781764 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 10 Sep 2025 11:09:37 +0530 Subject: [PATCH 040/279] Change dataloader to global bsz --- examples/ml_perf/main.py | 2 +- examples/ml_perf/run.sh | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 8dbb38cb..3f661a8c 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -115,7 +115,7 @@ def main( print("===== Loading dataset =====") train_ds = DataLoader( file_pattern=file_pattern, - batch_size=per_host_batch_size, + batch_size=global_batch_size, file_batch_size=file_batch_size, dense_features=dense_features, large_emb_features=large_emb_features, diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index 122dd92c..7a774221 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -101,7 +101,7 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="sudo apt-get update && sudo apt install -y python3.10-venv && if [ ! -d '.keras-env' ]; then echo '>>> Creating .keras-env...'; python3 -m venv .keras-env; else echo '>>> .keras-env already exists.'; fi" + --command="sudo apt-get update && sudo apt install -y python3.12-venv && if [ ! -d '.keras-env' ]; then echo '>>> Creating .keras-env...'; python3.12 -m venv .keras-env; else echo '>>> .keras-env already exists.'; fi" # ============================================================================== @@ -155,7 +155,7 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && echo 'import jax; print(jax.devices())' > script.py && python script.py" + --command="source .keras-env/bin/activate && echo 'import jax; print(jax.devices())' > script.py && python3.12 script.py" # ============================================================================== @@ -166,6 +166,6 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && cd keras-rs && python3 -m examples.ml_perf.main --config_name ${CONFIG_NAME}" + --command="source .keras-env/bin/activate && cd keras-rs && python3.12 -m examples.ml_perf.main --config_name ${CONFIG_NAME}" echo ">>> Script finished." From 69023030f0fa2e92028a72238991e191885e0071 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 15 Oct 2025 17:28:22 +0530 Subject: [PATCH 041/279] Revert all non-example changes --- .../layers/embedding/jax/embedding_utils.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/embedding_utils.py b/keras_rs/src/layers/embedding/jax/embedding_utils.py index 1205fbb6..8a2d1cd3 100644 --- a/keras_rs/src/layers/embedding/jax/embedding_utils.py +++ b/keras_rs/src/layers/embedding/jax/embedding_utils.py @@ -34,12 +34,6 @@ class ShardedCooMatrix(NamedTuple): values: ArrayLike -class InputStatsPerTable(NamedTuple): - max_ids_per_partition: int - max_unique_ids_per_partition: int - required_buffer_size_per_device: int - - def _round_up_to_multiple(value: int, multiple: int) -> int: return ((value + multiple - 1) // multiple) * multiple @@ -479,7 +473,7 @@ def stack_and_shard_samples( global_device_count: int, num_sc_per_device: int, static_buffer_size: int | Mapping[str, int] | None = None, -) -> tuple[dict[str, ShardedCooMatrix], dict[str, InputStatsPerTable]]: +) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]: """Prepares input samples for use in embedding lookups. Args: @@ -524,7 +518,6 @@ def collect_tokens_and_weights( ) out: dict[str, ShardedCooMatrix] = {} - out_stats: dict[str, InputStatsPerTable] = {} tables_names = preprocessed_inputs.lhs_row_pointers.keys() for table_name in tables_names: shard_ends = preprocessed_inputs.lhs_row_pointers[table_name] @@ -538,17 +531,5 @@ def collect_tokens_and_weights( row_ids=preprocessed_inputs.lhs_sample_ids[table_name], values=preprocessed_inputs.lhs_gains[table_name], ) - out_stats[table_name] = InputStatsPerTable( - max_ids_per_partition=np.max( - stats.max_ids_per_partition[table_name] - ), - max_unique_ids_per_partition=np.max( - stats.max_unique_ids_per_partition[table_name] - ), - required_buffer_size_per_device=np.max( - stats.required_buffer_size_per_sc[table_name] - ) - * num_sc_per_device, - ) - return out, out_stats + return out, stats From 51183257570a532645d3493003e8a4584517139c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 18 Oct 2025 23:02:55 +0530 Subject: [PATCH 042/279] Address some comments + move to step-based trainer --- examples/ml_perf/configs/training/default_training.py | 6 ++++-- examples/ml_perf/configs/v6e_8.py | 1 + examples/ml_perf/dataloader.py | 2 +- examples/ml_perf/main.py | 9 +++++---- examples/ml_perf/run.sh | 10 ++++------ 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/examples/ml_perf/configs/training/default_training.py b/examples/ml_perf/configs/training/default_training.py index b758bc59..985d1c3c 100644 --- a/examples/ml_perf/configs/training/default_training.py +++ b/examples/ml_perf/configs/training/default_training.py @@ -2,6 +2,8 @@ # === Training Hyperparameters === training_config = Config() -training_config.learning_rate = 0.005 +training_config.learning_rate = 0.0034 training_config.global_batch_size = 128 -training_config.num_epochs = 1 +# Set num_steps in the main config file instead of num_epochs, because we are +#using a Python generator. +# training_config.num_epochs = 1 diff --git a/examples/ml_perf/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py index fcd81e39..ce77b9c4 100644 --- a/examples/ml_perf/configs/v6e_8.py +++ b/examples/ml_perf/configs/v6e_8.py @@ -12,5 +12,6 @@ config.dataset = dataset_config config.model = model_config config.training = training_config +config.training.num_steps = 2 config.freeze() diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index ce2e7286..5a501dda 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -91,7 +91,7 @@ def _create_dummy_dataset(self): labels = dummy_data.pop("clicked") features = dummy_data - dataset = tf.data.Dataset.from_tensors((features, labels)).repeat(512) + dataset = tf.data.Dataset.from_tensors((features, labels)).repeat() return dataset def _get_feature_spec(self): diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 3f661a8c..90303f93 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -12,6 +12,7 @@ from .model import DLRMDCNV2 SEED = 1337 +keras.config.set_random_seed(SEED) def main( @@ -34,7 +35,7 @@ def main( learning_rate, global_batch_size, file_batch_size, - num_epochs, + num_steps, ): # Set DDP as Keras distribution strategy devices = keras.distribution.list_devices(device_type="tpu") @@ -163,7 +164,7 @@ def generator(dataset, training=False): break # Train the model. - model.fit(train_generator, epochs=num_epochs) + model.fit(train_generator, steps_per_epoch=num_steps) if __name__ == "__main__": @@ -221,7 +222,7 @@ def generator(dataset, training=False): training_cfg = config["training"] learning_rate = training_cfg["learning_rate"] global_batch_size = training_cfg["global_batch_size"] - num_epochs = training_cfg["num_epochs"] + num_steps = training_cfg["num_steps"] # For features which have vocabulary_size < embedding_threshold, we can # just do a normal dense lookup for those instead of having distributed @@ -257,5 +258,5 @@ def generator(dataset, training=False): learning_rate, global_batch_size, file_batch_size, - num_epochs, + num_steps, ) diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index 7a774221..60d9a639 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -55,7 +55,7 @@ fi # ============================================================================== # Environment Variables # ============================================================================== -export TPU_NAME="abheesht-mlperf-${ACCELERATOR_TYPE}" +export TPU_NAME="${USER}-mlperf-${ACCELERATOR_TYPE}" export ZONE export PROJECT @@ -130,9 +130,7 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ echo '>>> Installing/updating dependencies...' pip install -e . - pip uninstall -y tensorflow keras - pip install git+https://github.com/keras-team/keras.git - pip install jax-tpu-embedding tensorflow-cpu + pip install -U jax-tpu-embedding tensorflow-cpu keras " @@ -144,7 +142,7 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && pip uninstall -y jax jaxlib && pip install -U 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" + --command="source .keras-env/bin/activate && pip install -U 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" # ============================================================================== @@ -155,7 +153,7 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && echo 'import jax; print(jax.devices())' > script.py && python3.12 script.py" + --command="source .keras-env/bin/activate && python3.12 -c 'import jax; print(jax.devices())'" # ============================================================================== From 59c9d086213647732a23eba390c8969169ef181c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 18 Oct 2025 23:11:02 +0530 Subject: [PATCH 043/279] Fix --- examples/ml_perf/configs/training/default_training.py | 2 +- examples/ml_perf/main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/configs/training/default_training.py b/examples/ml_perf/configs/training/default_training.py index 985d1c3c..d348e13b 100644 --- a/examples/ml_perf/configs/training/default_training.py +++ b/examples/ml_perf/configs/training/default_training.py @@ -5,5 +5,5 @@ training_config.learning_rate = 0.0034 training_config.global_batch_size = 128 # Set num_steps in the main config file instead of num_epochs, because we are -#using a Python generator. +# using a Python generator. # training_config.num_epochs = 1 diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 90303f93..58604c01 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -12,7 +12,7 @@ from .model import DLRMDCNV2 SEED = 1337 -keras.config.set_random_seed(SEED) +keras.utils.set_random_seed(SEED) def main( From af9ba920b580bdd5a4524da06cd116df64e3c6ec Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 18 Oct 2025 23:37:05 +0530 Subject: [PATCH 044/279] Refactor configs to one file --- examples/ml_perf/configs/datasets/__init__.py | 0 .../ml_perf/configs/datasets/dummy_dataset.py | 166 -------------- examples/ml_perf/configs/models/__init__.py | 0 .../ml_perf/configs/models/default_model.py | 19 -- examples/ml_perf/configs/training/__init__.py | 0 .../configs/training/default_training.py | 9 - examples/ml_perf/configs/v6e_16.py | 197 ++++++++++++++++- examples/ml_perf/configs/v6e_8.py | 198 ++++++++++++++++- .../ml_perf/configs/v6e_8_full_dataset.py | 207 +++++++++++++++++- 9 files changed, 584 insertions(+), 212 deletions(-) delete mode 100644 examples/ml_perf/configs/datasets/__init__.py delete mode 100644 examples/ml_perf/configs/datasets/dummy_dataset.py delete mode 100644 examples/ml_perf/configs/models/__init__.py delete mode 100644 examples/ml_perf/configs/models/default_model.py delete mode 100644 examples/ml_perf/configs/training/__init__.py delete mode 100644 examples/ml_perf/configs/training/default_training.py diff --git a/examples/ml_perf/configs/datasets/__init__.py b/examples/ml_perf/configs/datasets/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py deleted file mode 100644 index aac66c57..00000000 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ /dev/null @@ -1,166 +0,0 @@ -from keras.utils import Config - -# === Dataset === -dataset_config = Config() -dataset_config.file_pattern = None -# Features -dataset_config.label = "clicked" -dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] -dataset_config.lookup = [ - { - "name": "categorical-feature-14", - "vocabulary_size": 40000000, - "feature_list_length": 3, - "new_name": "cat_14", - }, - { - "name": "categorical-feature-15", - "vocabulary_size": 39060, - "feature_list_length": 2, - "new_name": "cat_15", - }, - { - "name": "categorical-feature-16", - "vocabulary_size": 17295, - "feature_list_length": 1, - "new_name": "cat_16", - }, - { - "name": "categorical-feature-17", - "vocabulary_size": 7424, - "feature_list_length": 2, - "new_name": "cat_17", - }, - { - "name": "categorical-feature-18", - "vocabulary_size": 20265, - "feature_list_length": 6, - "new_name": "cat_18", - }, - { - "name": "categorical-feature-19", - "vocabulary_size": 3, - "feature_list_length": 1, - "new_name": "cat_19", - }, - { - "name": "categorical-feature-20", - "vocabulary_size": 7122, - "feature_list_length": 1, - "new_name": "cat_20", - }, - { - "name": "categorical-feature-21", - "vocabulary_size": 1543, - "feature_list_length": 1, - "new_name": "cat_21", - }, - { - "name": "categorical-feature-22", - "vocabulary_size": 63, - "feature_list_length": 1, - "new_name": "cat_22", - }, - { - "name": "categorical-feature-23", - "vocabulary_size": 40000000, - "feature_list_length": 7, - "new_name": "cat_23", - }, - { - "name": "categorical-feature-24", - "vocabulary_size": 3067956, - "feature_list_length": 3, - "new_name": "cat_24", - }, - { - "name": "categorical-feature-25", - "vocabulary_size": 405282, - "feature_list_length": 8, - "new_name": "cat_25", - }, - { - "name": "categorical-feature-26", - "vocabulary_size": 10, - "feature_list_length": 1, - "new_name": "cat_26", - }, - { - "name": "categorical-feature-27", - "vocabulary_size": 2209, - "feature_list_length": 6, - "new_name": "cat_27", - }, - { - "name": "categorical-feature-28", - "vocabulary_size": 11938, - "feature_list_length": 9, - "new_name": "cat_28", - }, - { - "name": "categorical-feature-29", - "vocabulary_size": 155, - "feature_list_length": 5, - "new_name": "cat_29", - }, - { - "name": "categorical-feature-30", - "vocabulary_size": 4, - "feature_list_length": 1, - "new_name": "cat_30", - }, - { - "name": "categorical-feature-31", - "vocabulary_size": 976, - "feature_list_length": 1, - "new_name": "cat_31", - }, - { - "name": "categorical-feature-32", - "vocabulary_size": 14, - "feature_list_length": 1, - "new_name": "cat_32", - }, - { - "name": "categorical-feature-33", - "vocabulary_size": 40000000, - "feature_list_length": 12, - "new_name": "cat_33", - }, - { - "name": "categorical-feature-34", - "vocabulary_size": 40000000, - "feature_list_length": 100, - "new_name": "cat_34", - }, - { - "name": "categorical-feature-35", - "vocabulary_size": 40000000, - "feature_list_length": 27, - "new_name": "cat_35", - }, - { - "name": "categorical-feature-36", - "vocabulary_size": 590152, - "feature_list_length": 10, - "new_name": "cat_36", - }, - { - "name": "categorical-feature-37", - "vocabulary_size": 12973, - "feature_list_length": 3, - "new_name": "cat_37", - }, - { - "name": "categorical-feature-38", - "vocabulary_size": 108, - "feature_list_length": 1, - "new_name": "cat_38", - }, - { - "name": "categorical-feature-39", - "vocabulary_size": 36, - "feature_list_length": 1, - "new_name": "cat_39", - }, -] diff --git a/examples/ml_perf/configs/models/__init__.py b/examples/ml_perf/configs/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/ml_perf/configs/models/default_model.py b/examples/ml_perf/configs/models/default_model.py deleted file mode 100644 index 1a07e9d9..00000000 --- a/examples/ml_perf/configs/models/default_model.py +++ /dev/null @@ -1,19 +0,0 @@ -from keras.utils import Config - -# === Model === -model_config = Config() -# Embedding -model_config.embedding_dim = 128 -model_config.allow_id_dropping = True -model_config.embedding_threshold = 21000 -model_config.max_ids_per_partition = 4096 -model_config.max_unique_ids_per_partition = 2048 -model_config.learning_rate = 0.005 - -# MLP -model_config.bottom_mlp_dims = [512, 256, 128] -model_config.top_mlp_dims = [1024, 1024, 512, 256, 1] - -# DCN -model_config.num_dcn_layers = 3 -model_config.dcn_projection_dim = 512 diff --git a/examples/ml_perf/configs/training/__init__.py b/examples/ml_perf/configs/training/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/ml_perf/configs/training/default_training.py b/examples/ml_perf/configs/training/default_training.py deleted file mode 100644 index d348e13b..00000000 --- a/examples/ml_perf/configs/training/default_training.py +++ /dev/null @@ -1,9 +0,0 @@ -from keras.utils import Config - -# === Training Hyperparameters === -training_config = Config() -training_config.learning_rate = 0.0034 -training_config.global_batch_size = 128 -# Set num_steps in the main config file instead of num_epochs, because we are -# using a Python generator. -# training_config.num_epochs = 1 diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 4b6df8df..134076e8 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -1,14 +1,203 @@ from keras.utils import Config -from .datasets.dummy_dataset import dataset_config -from .models.default_model import model_config -from .training.default_training import training_config - config = Config() +# === Experiment metadata === config.experiment_name = "v6e_16" config.model_dir = "./v6e_16" +# === Dataset === +dataset_config = Config() +dataset_config.file_pattern = None +# Features +dataset_config.label = "clicked" +dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] +dataset_config.lookup = [ + { + "name": "categorical-feature-14", + "vocabulary_size": 40000000, + "feature_list_length": 3, + "new_name": "cat_14", + }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "feature_list_length": 2, + "new_name": "cat_15", + }, + { + "name": "categorical-feature-16", + "vocabulary_size": 17295, + "feature_list_length": 1, + "new_name": "cat_16", + }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "feature_list_length": 2, + "new_name": "cat_17", + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "feature_list_length": 6, + "new_name": "cat_18", + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "feature_list_length": 1, + "new_name": "cat_19", + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "feature_list_length": 1, + "new_name": "cat_20", + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "feature_list_length": 1, + "new_name": "cat_21", + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "feature_list_length": 1, + "new_name": "cat_22", + }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "feature_list_length": 7, + "new_name": "cat_23", + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "feature_list_length": 3, + "new_name": "cat_24", + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "feature_list_length": 8, + "new_name": "cat_25", + }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "feature_list_length": 1, + "new_name": "cat_26", + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "feature_list_length": 6, + "new_name": "cat_27", + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "feature_list_length": 9, + "new_name": "cat_28", + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "feature_list_length": 5, + "new_name": "cat_29", + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "feature_list_length": 1, + "new_name": "cat_30", + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "feature_list_length": 1, + "new_name": "cat_31", + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "feature_list_length": 1, + "new_name": "cat_32", + }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "feature_list_length": 12, + "new_name": "cat_33", + }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "feature_list_length": 100, + "new_name": "cat_34", + }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "feature_list_length": 27, + "new_name": "cat_35", + }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "feature_list_length": 10, + "new_name": "cat_36", + }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "feature_list_length": 3, + "new_name": "cat_37", + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "feature_list_length": 1, + "new_name": "cat_38", + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "feature_list_length": 1, + "new_name": "cat_39", + }, +] + +# === Model === +model_config = Config() +# Embedding +model_config.embedding_dim = 128 +model_config.allow_id_dropping = True +model_config.embedding_threshold = 21000 +model_config.max_ids_per_partition = 8192 +model_config.max_unique_ids_per_partition = 4096 +model_config.learning_rate = 0.0034 + +# MLP +model_config.bottom_mlp_dims = [512, 256, 128] +model_config.top_mlp_dims = [1024, 1024, 512, 256, 1] + +# DCN +model_config.num_dcn_layers = 3 +model_config.dcn_projection_dim = 512 + +# === Training === +training_config = Config() +training_config.learning_rate = 0.0034 +training_config.global_batch_size = 128 +# Set `num_steps` in the main config file instead of num_epochs, because we are +# using a Python generator. +training_config.num_steps = 2 + +# === Assign all configs to the root config === config.dataset = dataset_config config.model = model_config config.training = training_config diff --git a/examples/ml_perf/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py index ce77b9c4..4f30552c 100644 --- a/examples/ml_perf/configs/v6e_8.py +++ b/examples/ml_perf/configs/v6e_8.py @@ -1,17 +1,205 @@ from keras.utils import Config -from .datasets.dummy_dataset import dataset_config -from .models.default_model import model_config -from .training.default_training import training_config - config = Config() +# === Experiment metadata === config.experiment_name = "v6e_8" config.model_dir = "./v6e_8" +# === Dataset === +dataset_config = Config() +dataset_config.file_pattern = None +# Features +dataset_config.label = "clicked" +dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] +dataset_config.lookup = [ + { + "name": "categorical-feature-14", + "vocabulary_size": 40000000, + "feature_list_length": 3, + "new_name": "cat_14", + }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "feature_list_length": 2, + "new_name": "cat_15", + }, + { + "name": "categorical-feature-16", + "vocabulary_size": 17295, + "feature_list_length": 1, + "new_name": "cat_16", + }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "feature_list_length": 2, + "new_name": "cat_17", + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "feature_list_length": 6, + "new_name": "cat_18", + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "feature_list_length": 1, + "new_name": "cat_19", + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "feature_list_length": 1, + "new_name": "cat_20", + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "feature_list_length": 1, + "new_name": "cat_21", + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "feature_list_length": 1, + "new_name": "cat_22", + }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "feature_list_length": 7, + "new_name": "cat_23", + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "feature_list_length": 3, + "new_name": "cat_24", + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "feature_list_length": 8, + "new_name": "cat_25", + }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "feature_list_length": 1, + "new_name": "cat_26", + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "feature_list_length": 6, + "new_name": "cat_27", + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "feature_list_length": 9, + "new_name": "cat_28", + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "feature_list_length": 5, + "new_name": "cat_29", + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "feature_list_length": 1, + "new_name": "cat_30", + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "feature_list_length": 1, + "new_name": "cat_31", + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "feature_list_length": 1, + "new_name": "cat_32", + }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "feature_list_length": 12, + "new_name": "cat_33", + }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "feature_list_length": 100, + "new_name": "cat_34", + }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "feature_list_length": 27, + "new_name": "cat_35", + }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "feature_list_length": 10, + "new_name": "cat_36", + }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "feature_list_length": 3, + "new_name": "cat_37", + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "feature_list_length": 1, + "new_name": "cat_38", + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "feature_list_length": 1, + "new_name": "cat_39", + }, +] + +# === Model === +model_config = Config() +# Embedding +model_config.embedding_dim = 128 +model_config.allow_id_dropping = True +model_config.embedding_threshold = 21000 +model_config.max_ids_per_partition = 8192 +model_config.max_unique_ids_per_partition = 4096 +model_config.learning_rate = 0.0034 + +# MLP +model_config.bottom_mlp_dims = [512, 256, 128] +model_config.top_mlp_dims = [1024, 1024, 512, 256, 1] + +# DCN +model_config.num_dcn_layers = 3 +model_config.dcn_projection_dim = 512 + +# === Training === +training_config = Config() +training_config.learning_rate = 0.0034 +training_config.global_batch_size = 128 +# Set `num_steps` in the main config file instead of num_epochs, because we are +# using a Python generator. +training_config.num_steps = 2 + +# === Assign all configs to the root config === config.dataset = dataset_config config.model = model_config config.training = training_config -config.training.num_steps = 2 config.freeze() diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py index 8489b084..30a59f97 100644 --- a/examples/ml_perf/configs/v6e_8_full_dataset.py +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -1,27 +1,216 @@ from keras.utils import Config -from .datasets.dummy_dataset import dataset_config -from .models.default_model import model_config -from .training.default_training import training_config - config = Config() +# === Experiment metadata === config.experiment_name = "v6e_8_full_dataset" config.model_dir = "./v6e_8_full_dataset" -config.dataset = dataset_config -config.dataset.file_pattern = ( +# === Dataset === +dataset_config = Config() +dataset_config.file_pattern = ( "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" "train-00000-of-01024tfrecord" ) -config.dataset.val_file_pattern = ( +dataset_config.val_file_pattern = ( "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" "train-00000-of-01024tfrecord" ) # The path which we are reading from already has the batched dataset. -config.dataset.file_batch_size = 4224 +dataset_config.file_batch_size = 4224 + +# Features +dataset_config.label = "clicked" +dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] +dataset_config.lookup = [ + { + "name": "categorical-feature-14", + "vocabulary_size": 40000000, + "feature_list_length": 3, + "new_name": "cat_14", + }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "feature_list_length": 2, + "new_name": "cat_15", + }, + { + "name": "categorical-feature-16", + "vocabulary_size": 17295, + "feature_list_length": 1, + "new_name": "cat_16", + }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "feature_list_length": 2, + "new_name": "cat_17", + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "feature_list_length": 6, + "new_name": "cat_18", + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "feature_list_length": 1, + "new_name": "cat_19", + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "feature_list_length": 1, + "new_name": "cat_20", + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "feature_list_length": 1, + "new_name": "cat_21", + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "feature_list_length": 1, + "new_name": "cat_22", + }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "feature_list_length": 7, + "new_name": "cat_23", + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "feature_list_length": 3, + "new_name": "cat_24", + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "feature_list_length": 8, + "new_name": "cat_25", + }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "feature_list_length": 1, + "new_name": "cat_26", + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "feature_list_length": 6, + "new_name": "cat_27", + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "feature_list_length": 9, + "new_name": "cat_28", + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "feature_list_length": 5, + "new_name": "cat_29", + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "feature_list_length": 1, + "new_name": "cat_30", + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "feature_list_length": 1, + "new_name": "cat_31", + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "feature_list_length": 1, + "new_name": "cat_32", + }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "feature_list_length": 12, + "new_name": "cat_33", + }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "feature_list_length": 100, + "new_name": "cat_34", + }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "feature_list_length": 27, + "new_name": "cat_35", + }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "feature_list_length": 10, + "new_name": "cat_36", + }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "feature_list_length": 3, + "new_name": "cat_37", + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "feature_list_length": 1, + "new_name": "cat_38", + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "feature_list_length": 1, + "new_name": "cat_39", + }, +] + +# === Model === +model_config = Config() +# Embedding +model_config.embedding_dim = 128 +model_config.allow_id_dropping = True +model_config.embedding_threshold = 21000 +model_config.max_ids_per_partition = 8192 +model_config.max_unique_ids_per_partition = 4096 +model_config.learning_rate = 0.0034 + +# MLP +model_config.bottom_mlp_dims = [512, 256, 128] +model_config.top_mlp_dims = [1024, 1024, 512, 256, 1] + +# DCN +model_config.num_dcn_layers = 3 +model_config.dcn_projection_dim = 512 + +# === Training === +training_config = Config() +training_config.learning_rate = 0.0034 +training_config.global_batch_size = 128 +training_config.batch_size = 256 +# Set `num_steps` instead of `num_epochs`, because we are using a Python +# generator. +training_config.num_steps = 2 + +# === Assign all configs to the root config === +config.dataset = dataset_config config.model = model_config config.training = training_config -config.training.batch_size = 256 config.freeze() From 8b6e300c34d8f458d7a28b43c41f251d03d5ba48 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 19 Oct 2025 08:28:00 +0530 Subject: [PATCH 045/279] Follow the original example in specifying input/output shape --- examples/ml_perf/configs/v6e_8.py | 1 + examples/ml_perf/main.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py index 4f30552c..56b7a4ca 100644 --- a/examples/ml_perf/configs/v6e_8.py +++ b/examples/ml_perf/configs/v6e_8.py @@ -196,6 +196,7 @@ # Set `num_steps` in the main config file instead of num_epochs, because we are # using a Python generator. training_config.num_steps = 2 +training_config.eval_freq = 1 # === Assign all configs to the root config === config.dataset = dataset_config diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 58604c01..d15e35a1 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -84,9 +84,9 @@ def main( name=feature_name, table=table_config, # TODO: Verify whether it should be `(bsz, 1)` or - # `(bsz, feature_list_length)`. - input_shape=(per_host_batch_size, feature_list_length), - output_shape=(per_host_batch_size, embedding_dim), + # `(bsz, feature_list_length)`. The original example uses 1. + input_shape=(global_batch_size, 1), + output_shape=(global_batch_size, embedding_dim), ) # === Instantiate model === From 760639ae4c5c520d28ad1dfcd71809686fe77683 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 19 Oct 2025 13:14:24 +0530 Subject: [PATCH 046/279] Add debugging statements --- examples/ml_perf/run.sh | 16 ++++++++++++++++ .../embedding/jax/distributed_embedding.py | 7 +++++++ 2 files changed, 23 insertions(+) diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index 60d9a639..30548de9 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -144,6 +144,22 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --worker=all \ --command="source .keras-env/bin/activate && pip install -U 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" +# ============================================================================== +# Kill Previous Training Processes +# ============================================================================== +# echo ">>> Listing matching processes..." +# gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ +# --project ${PROJECT} \ +# --zone ${ZONE} \ +# --worker=all \ +# --command="ps aux | grep '[e]xamples.ml_perf.main' || true" + +# echo ">>> Terminating any existing training processes..." +# gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ +# --project ${PROJECT} \ +# --zone ${ZONE} \ +# --worker=all \ +# --command="pkill -9 -f 'python3.12 -m examples.ml_perf.[m]ain.*' || true" # ============================================================================== # Verify Installation diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 1f5916fa..892ddb0c 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -580,12 +580,17 @@ def _sparsecore_preprocess( ) layout = self._sparsecore_layout + print(f"-->{layout=}") mesh = layout.device_mesh.backend_mesh + print(f"-->{mesh=}") global_device_count = mesh.devices.size + print(f"-->{global_device_count=}") local_device_count = mesh.local_mesh.devices.size + print(f"{local_device_count=}") num_sc_per_device = jte_utils.num_sparsecores_per_device( mesh.devices.item(0) ) + print(f"-->{num_sc_per_device=}") preprocessed, stats = embedding_utils.stack_and_shard_samples( self._config.feature_specs, @@ -594,6 +599,7 @@ def _sparsecore_preprocess( global_device_count, num_sc_per_device, ) + print(f"-->{stats=}") if training: # Synchronize input statistics across all devices and update the @@ -601,6 +607,7 @@ def _sparsecore_preprocess( # Aggregate stats across all processes/devices via pmax. num_local_cpu_devices = jax.local_device_count("cpu") + print(f"-->{num_local_cpu_devices=}") def pmax_aggregate(x: Any) -> Any: if not hasattr(x, "ndim"): From deaf35aea0cdf5728c045d9054f7da62dc4ec379 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 19 Oct 2025 15:22:30 +0530 Subject: [PATCH 047/279] Add debugging statements --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 892ddb0c..a16c5f38 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -591,6 +591,7 @@ def _sparsecore_preprocess( mesh.devices.item(0) ) print(f"-->{num_sc_per_device=}") + print(f"-->{jax.process_count()=}") preprocessed, stats = embedding_utils.stack_and_shard_samples( self._config.feature_specs, @@ -612,6 +613,7 @@ def _sparsecore_preprocess( def pmax_aggregate(x: Any) -> Any: if not hasattr(x, "ndim"): x = np.array(x) + jax.debug.print("--> x.shape={}", x.shape) tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim))) return jax.pmap( lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call] From 3d0b640a56cf2aec3b30a267b1dd156bfd0c8570 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 19 Oct 2025 15:27:16 +0530 Subject: [PATCH 048/279] Add debugging statements --- keras_rs/.DS_Store | Bin 0 -> 6148 bytes .../embedding/jax/distributed_embedding.py | 1 + 2 files changed, 1 insertion(+) create mode 100644 keras_rs/.DS_Store diff --git a/keras_rs/.DS_Store b/keras_rs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5172429f264de2441865cb4700216d4256da9242 GIT binary patch literal 6148 zcmeH~J!%6%427R!7lt%jx}3%b$PET#pTHLgIFQEJ;E>dF^gR7ES*H$5cmnB-G%I%Z zD|S`@Z2$T80!#olbXV*=%*>dt@PRwdU#I)^a=X5>;#J@&VrHyNnC;iLL0pQvfVyTmjO&;ssLc!1UOG})p;=82 zR;?Ceh}WZ?+UmMqI#RP8R>OzYoz15hnq@nzF`-!xQ4j$Um=RcIKKc27r2jVm&svm< zfC&6E0=7P!4tu^-ovjbA=k?dB`g+i*aXG_}p8zI)6mRKa+;6_1_R^8c3Qa!(fk8n8 H{*=HsM+*^= literal 0 HcmV?d00001 diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index a16c5f38..2987e47b 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -615,6 +615,7 @@ def pmax_aggregate(x: Any) -> Any: x = np.array(x) jax.debug.print("--> x.shape={}", x.shape) tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim))) + jax.debug.print("--> tiled_x.shape={}", tiled_x.shape) return jax.pmap( lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call] axis_name="all_cpus", From 5782407995075865268a6e349924d6edd5689ad9 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 19 Oct 2025 15:27:44 +0530 Subject: [PATCH 049/279] Add debugging statements --- keras_rs/.DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 keras_rs/.DS_Store diff --git a/keras_rs/.DS_Store b/keras_rs/.DS_Store deleted file mode 100644 index 5172429f264de2441865cb4700216d4256da9242..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~J!%6%427R!7lt%jx}3%b$PET#pTHLgIFQEJ;E>dF^gR7ES*H$5cmnB-G%I%Z zD|S`@Z2$T80!#olbXV*=%*>dt@PRwdU#I)^a=X5>;#J@&VrHyNnC;iLL0pQvfVyTmjO&;ssLc!1UOG})p;=82 zR;?Ceh}WZ?+UmMqI#RP8R>OzYoz15hnq@nzF`-!xQ4j$Um=RcIKKc27r2jVm&svm< zfC&6E0=7P!4tu^-ovjbA=k?dB`g+i*aXG_}p8zI)6mRKa+;6_1_R^8c3Qa!(fk8n8 H{*=HsM+*^= From 94ffb7e0973f901822aad68b6fc92f573553918f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 20 Oct 2025 23:50:10 +0530 Subject: [PATCH 050/279] Temp comment out stats update code --- .../embedding/jax/distributed_embedding.py | 106 +++++++++--------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 2987e47b..1697a497 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -602,59 +602,59 @@ def _sparsecore_preprocess( ) print(f"-->{stats=}") - if training: - # Synchronize input statistics across all devices and update the - # underlying stacked tables specs in the feature specs. - - # Aggregate stats across all processes/devices via pmax. - num_local_cpu_devices = jax.local_device_count("cpu") - print(f"-->{num_local_cpu_devices=}") - - def pmax_aggregate(x: Any) -> Any: - if not hasattr(x, "ndim"): - x = np.array(x) - jax.debug.print("--> x.shape={}", x.shape) - tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim))) - jax.debug.print("--> tiled_x.shape={}", tiled_x.shape) - return jax.pmap( - lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call] - axis_name="all_cpus", - backend="cpu", - )(tiled_x)[0] - - full_stats = jax.tree.map(pmax_aggregate, stats) - - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(full_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max(full_stats.max_unique_ids_per_partition[stack_name]) - > spec.max_unique_ids_per_partition - or ( - np.max(full_stats.required_buffer_size_per_sc[stack_name]) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, full_stats, num_sc_per_device - ) - - # Re-execute preprocessing with consistent input statistics. - preprocessed, _ = embedding_utils.stack_and_shard_samples( - self._config.feature_specs, - samples, - local_device_count, - global_device_count, - num_sc_per_device, - ) + # if training: + # # Synchronize input statistics across all devices and update the + # # underlying stacked tables specs in the feature specs. + + # # Aggregate stats across all processes/devices via pmax. + # num_local_cpu_devices = jax.local_device_count("cpu") + # print(f"-->{num_local_cpu_devices=}") + + # def pmax_aggregate(x: Any) -> Any: + # if not hasattr(x, "ndim"): + # x = np.array(x) + # jax.debug.print("--> x.shape={}", x.shape) + # tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim))) + # jax.debug.print("--> tiled_x.shape={}", tiled_x.shape) + # return jax.pmap( + # lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call] + # axis_name="all_cpus", + # backend="cpu", + # )(tiled_x)[0] + + # full_stats = jax.tree.map(pmax_aggregate, stats) + + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(full_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max(full_stats.max_unique_ids_per_partition[stack_name]) + # > spec.max_unique_ids_per_partition + # or ( + # np.max(full_stats.required_buffer_size_per_sc[stack_name]) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, full_stats, num_sc_per_device + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From c578ff8cd81e8fc49fa1647803c1e961b74c2a80 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 00:24:23 +0530 Subject: [PATCH 051/279] Change input size to dist emb --- examples/ml_perf/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d15e35a1..bf567e1b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -85,8 +85,8 @@ def main( table=table_config, # TODO: Verify whether it should be `(bsz, 1)` or # `(bsz, feature_list_length)`. The original example uses 1. - input_shape=(global_batch_size, 1), - output_shape=(global_batch_size, embedding_dim), + input_shape=(per_host_batch_size, 1), + output_shape=(per_host_batch_size, embedding_dim), ) # === Instantiate model === From b3425f906775ab766706a3a73e9bc1850888ab0f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 02:15:01 +0530 Subject: [PATCH 052/279] Bsz --- examples/ml_perf/main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index bf567e1b..a0ab9e42 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -85,8 +85,8 @@ def main( table=table_config, # TODO: Verify whether it should be `(bsz, 1)` or # `(bsz, feature_list_length)`. The original example uses 1. - input_shape=(per_host_batch_size, 1), - output_shape=(per_host_batch_size, embedding_dim), + input_shape=(global_batch_size, 1), + output_shape=(global_batch_size, embedding_dim), ) # === Instantiate model === @@ -154,6 +154,7 @@ def generator(dataset, training=False): "large_emb_inputs": preprocessed_large_embeddings, "small_emb_inputs": features["small_emb_inputs"], } + print(x["large_emb_inputs"]) y = labels yield (x, y) From f2a849add26240b194d808e2c124bc8089bc3640 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 02:18:20 +0530 Subject: [PATCH 053/279] Bsz --- examples/ml_perf/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index a0ab9e42..e60e3953 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -154,18 +154,18 @@ def generator(dataset, training=False): "large_emb_inputs": preprocessed_large_embeddings, "small_emb_inputs": features["small_emb_inputs"], } - print(x["large_emb_inputs"]) y = labels yield (x, y) train_generator = generator(train_ds, training=True) # eval_generator = generator(eval_ds, training=False) for first_batch in train_generator: - model(first_batch[0]) + print(first_batch[0]) + # model(first_batch[0]) break # Train the model. - model.fit(train_generator, steps_per_epoch=num_steps) + # model.fit(train_generator, steps_per_epoch=num_steps) if __name__ == "__main__": From d08479e1483b7197f3d724e05319c1fc97bbf1c2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 02:21:04 +0530 Subject: [PATCH 054/279] Bsz --- examples/ml_perf/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index e60e3953..e7ce8bcb 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -137,8 +137,8 @@ def main( distribution.auto_shard_dataset = False # Print one sample. - for element in train_ds.take(1): - print(">>> train sample", element[0]) + # for element in train_ds.take(1): + # print(">>> train sample", element[0]) def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses From 604c07bb6eebf14e4ff6ffa4731ed99dc2aaea7a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 11:13:50 +0530 Subject: [PATCH 055/279] Restore stats stuff --- .../embedding/jax/distributed_embedding.py | 106 +++++++++--------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 1697a497..2987e47b 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -602,59 +602,59 @@ def _sparsecore_preprocess( ) print(f"-->{stats=}") - # if training: - # # Synchronize input statistics across all devices and update the - # # underlying stacked tables specs in the feature specs. - - # # Aggregate stats across all processes/devices via pmax. - # num_local_cpu_devices = jax.local_device_count("cpu") - # print(f"-->{num_local_cpu_devices=}") - - # def pmax_aggregate(x: Any) -> Any: - # if not hasattr(x, "ndim"): - # x = np.array(x) - # jax.debug.print("--> x.shape={}", x.shape) - # tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim))) - # jax.debug.print("--> tiled_x.shape={}", tiled_x.shape) - # return jax.pmap( - # lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call] - # axis_name="all_cpus", - # backend="cpu", - # )(tiled_x)[0] - - # full_stats = jax.tree.map(pmax_aggregate, stats) - - # # Check if stats changed enough to warrant action. - # stacked_table_specs = embedding.get_stacked_table_specs( - # self._config.feature_specs - # ) - # changed = any( - # np.max(full_stats.max_ids_per_partition[stack_name]) - # > spec.max_ids_per_partition - # or np.max(full_stats.max_unique_ids_per_partition[stack_name]) - # > spec.max_unique_ids_per_partition - # or ( - # np.max(full_stats.required_buffer_size_per_sc[stack_name]) - # * num_sc_per_device - # ) - # > (spec.suggested_coo_buffer_size_per_device or 0) - # for stack_name, spec in stacked_table_specs.items() - # ) - - # # Update configuration and repeat preprocessing if stats changed. - # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, full_stats, num_sc_per_device - # ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + + # Aggregate stats across all processes/devices via pmax. + num_local_cpu_devices = jax.local_device_count("cpu") + print(f"-->{num_local_cpu_devices=}") + + def pmax_aggregate(x: Any) -> Any: + if not hasattr(x, "ndim"): + x = np.array(x) + jax.debug.print("--> x.shape={}", x.shape) + tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim))) + jax.debug.print("--> tiled_x.shape={}", tiled_x.shape) + return jax.pmap( + lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call] + axis_name="all_cpus", + backend="cpu", + )(tiled_x)[0] + + full_stats = jax.tree.map(pmax_aggregate, stats) + + # Check if stats changed enough to warrant action. + stacked_table_specs = embedding.get_stacked_table_specs( + self._config.feature_specs + ) + changed = any( + np.max(full_stats.max_ids_per_partition[stack_name]) + > spec.max_ids_per_partition + or np.max(full_stats.max_unique_ids_per_partition[stack_name]) + > spec.max_unique_ids_per_partition + or ( + np.max(full_stats.required_buffer_size_per_sc[stack_name]) + * num_sc_per_device + ) + > (spec.suggested_coo_buffer_size_per_device or 0) + for stack_name, spec in stacked_table_specs.items() + ) + + # Update configuration and repeat preprocessing if stats changed. + if changed: + embedding.update_preprocessing_parameters( + self._config.feature_specs, full_stats, num_sc_per_device + ) + + # Re-execute preprocessing with consistent input statistics. + preprocessed, _ = embedding_utils.stack_and_shard_samples( + self._config.feature_specs, + samples, + local_device_count, + global_device_count, + num_sc_per_device, + ) return {"inputs": preprocessed} From 1d5e983ed2bdb8055f49d1f195fa9df46c1fca79 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 13:14:39 +0530 Subject: [PATCH 056/279] Comment out stats update for now --- .../embedding/jax/distributed_embedding.py | 106 +++++++++--------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 2987e47b..1697a497 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -602,59 +602,59 @@ def _sparsecore_preprocess( ) print(f"-->{stats=}") - if training: - # Synchronize input statistics across all devices and update the - # underlying stacked tables specs in the feature specs. - - # Aggregate stats across all processes/devices via pmax. - num_local_cpu_devices = jax.local_device_count("cpu") - print(f"-->{num_local_cpu_devices=}") - - def pmax_aggregate(x: Any) -> Any: - if not hasattr(x, "ndim"): - x = np.array(x) - jax.debug.print("--> x.shape={}", x.shape) - tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim))) - jax.debug.print("--> tiled_x.shape={}", tiled_x.shape) - return jax.pmap( - lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call] - axis_name="all_cpus", - backend="cpu", - )(tiled_x)[0] - - full_stats = jax.tree.map(pmax_aggregate, stats) - - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(full_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max(full_stats.max_unique_ids_per_partition[stack_name]) - > spec.max_unique_ids_per_partition - or ( - np.max(full_stats.required_buffer_size_per_sc[stack_name]) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, full_stats, num_sc_per_device - ) - - # Re-execute preprocessing with consistent input statistics. - preprocessed, _ = embedding_utils.stack_and_shard_samples( - self._config.feature_specs, - samples, - local_device_count, - global_device_count, - num_sc_per_device, - ) + # if training: + # # Synchronize input statistics across all devices and update the + # # underlying stacked tables specs in the feature specs. + + # # Aggregate stats across all processes/devices via pmax. + # num_local_cpu_devices = jax.local_device_count("cpu") + # print(f"-->{num_local_cpu_devices=}") + + # def pmax_aggregate(x: Any) -> Any: + # if not hasattr(x, "ndim"): + # x = np.array(x) + # jax.debug.print("--> x.shape={}", x.shape) + # tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim))) + # jax.debug.print("--> tiled_x.shape={}", tiled_x.shape) + # return jax.pmap( + # lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call] + # axis_name="all_cpus", + # backend="cpu", + # )(tiled_x)[0] + + # full_stats = jax.tree.map(pmax_aggregate, stats) + + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(full_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max(full_stats.max_unique_ids_per_partition[stack_name]) + # > spec.max_unique_ids_per_partition + # or ( + # np.max(full_stats.required_buffer_size_per_sc[stack_name]) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, full_stats, num_sc_per_device + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 832605e65fe4ead4a9ced18cbcfb8814938357b1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 14:28:34 +0530 Subject: [PATCH 057/279] Try alternate way of sharding dataset --- examples/ml_perf/dataloader.py | 4 +++- examples/ml_perf/main.py | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 5a501dda..5d56995b 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -163,7 +163,9 @@ def _get_emb_inputs(emb_features): def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): if self._return_dummy_dataset: - return self._create_dummy_dataset() + dataset = self._create_dummy_dataset() + dataset = dataset.shard(num_processes, process_id) + return dataset dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index e7ce8bcb..d0148763 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -131,10 +131,10 @@ def main( # For the multi-host case, the dataset has to be distributed manually. # See note here: # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. - if num_processes > 1: - train_ds = distribution.distribute_dataset(train_ds) - # eval_ds = distribution.distribute_dataset(eval_ds) - distribution.auto_shard_dataset = False + # if num_processes > 1: + # train_ds = distribution.distribute_dataset(train_ds) + # # eval_ds = distribution.distribute_dataset(eval_ds) + # distribution.auto_shard_dataset = False # Print one sample. # for element in train_ds.take(1): From e621af726a3628647b8974cc7745446ae2165520 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 14:31:20 +0530 Subject: [PATCH 058/279] Try alternate way of sharding dataset --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d0148763..69512125 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -165,7 +165,7 @@ def generator(dataset, training=False): break # Train the model. - # model.fit(train_generator, steps_per_epoch=num_steps) + model.fit(train_generator, steps_per_epoch=num_steps) if __name__ == "__main__": From 0b11227e0c6500826949744168bc6ff2b3b5607d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 14:33:49 +0530 Subject: [PATCH 059/279] Try alternate way of sharding dataset --- examples/ml_perf/main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 69512125..34ff55a0 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -131,10 +131,10 @@ def main( # For the multi-host case, the dataset has to be distributed manually. # See note here: # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. - # if num_processes > 1: - # train_ds = distribution.distribute_dataset(train_ds) - # # eval_ds = distribution.distribute_dataset(eval_ds) - # distribution.auto_shard_dataset = False + if num_processes > 1: + # train_ds = distribution.distribute_dataset(train_ds) + # eval_ds = distribution.distribute_dataset(eval_ds) + distribution.auto_shard_dataset = False # Print one sample. # for element in train_ds.take(1): From 136a57f66f176f0aaf2b1e3a2aa33a588d1b7530 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 16:42:57 +0530 Subject: [PATCH 060/279] Debug --- examples/ml_perf/dataloader.py | 22 +++++++------- examples/ml_perf/main.py | 52 +++++++++++++++++----------------- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 5d56995b..092fc310 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -26,6 +26,7 @@ def __init__( # Derived attributes. self._return_dummy_dataset = file_pattern is None + self._per_host_batch_size = self.batch_size // jax.process_count() def _get_dummy_batch(self): """Returns a dummy batch of data in the final desired structure.""" @@ -91,7 +92,8 @@ def _create_dummy_dataset(self): labels = dummy_data.pop("clicked") features = dummy_data - dataset = tf.data.Dataset.from_tensors((features, labels)).repeat() + dataset = tf.data.Dataset.from_tensors((features, labels)) + dataset = dataset.repeat() return dataset def _get_feature_spec(self): @@ -163,17 +165,17 @@ def _get_emb_inputs(emb_features): def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): if self._return_dummy_dataset: - dataset = self._create_dummy_dataset() - dataset = dataset.shard(num_processes, process_id) - return dataset + return self._create_dummy_dataset() dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - # Shard the dataset across hosts/workers. - # TODO: Do we need to do this if we are distributing the dataset - # manually using distribution.distribute_dataset(...)? - if num_processes > 1: - dataset = dataset.shard(num_processes, process_id) + # # Shard the dataset across hosts/workers. + # # TODO: Do we need to do this if we are distributing the dataset + # # manually using distribution.distribute_dataset(...)? + # This is not needed, because distribute_dataset shards the dataset + # across hosts. + # if num_processes > 1: + # dataset = dataset.shard(num_processes, process_id) dataset = tf.data.TFRecordDataset( dataset, @@ -192,7 +194,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.shuffle(shuffle_buffer) dataset = dataset.batch( - self.batch_size, + self._per_host_batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE, ) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 34ff55a0..13d228be 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -132,40 +132,40 @@ def main( # See note here: # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. if num_processes > 1: - # train_ds = distribution.distribute_dataset(train_ds) + train_ds = distribution.distribute_dataset(train_ds) # eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False # Print one sample. - # for element in train_ds.take(1): - # print(">>> train sample", element[0]) + for element in train_ds.take(1): + print(">>> train sample", element[0]) - def generator(dataset, training=False): - """Converts tf.data Dataset to a Python generator and preprocesses - large embedding features. - """ - for features, labels in dataset: - preprocessed_large_embeddings = model.embedding_layer.preprocess( - features["large_emb_inputs"], training=training - ) + # def generator(dataset, training=False): + # """Converts tf.data Dataset to a Python generator and preprocesses + # large embedding features. + # """ + # for features, labels in dataset: + # preprocessed_large_embeddings = model.embedding_layer.preprocess( + # features["large_emb_inputs"], training=training + # ) - x = { - "dense_input": features["dense_input"], - "large_emb_inputs": preprocessed_large_embeddings, - "small_emb_inputs": features["small_emb_inputs"], - } - y = labels - yield (x, y) + # x = { + # "dense_input": features["dense_input"], + # "large_emb_inputs": preprocessed_large_embeddings, + # "small_emb_inputs": features["small_emb_inputs"], + # } + # y = labels + # yield (x, y) - train_generator = generator(train_ds, training=True) - # eval_generator = generator(eval_ds, training=False) - for first_batch in train_generator: - print(first_batch[0]) - # model(first_batch[0]) - break + # train_generator = generator(train_ds, training=True) + # # eval_generator = generator(eval_ds, training=False) + # for first_batch in train_generator: + # print(first_batch[0]) + # # model(first_batch[0]) + # break - # Train the model. - model.fit(train_generator, steps_per_epoch=num_steps) + # # Train the model. + # model.fit(train_generator, steps_per_epoch=num_steps) if __name__ == "__main__": From 2ae73ced6b05ec55ce2748a60890dab7dc8641db Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 16:45:12 +0530 Subject: [PATCH 061/279] Debug --- examples/ml_perf/dataloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 092fc310..ddfefad7 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -26,7 +26,7 @@ def __init__( # Derived attributes. self._return_dummy_dataset = file_pattern is None - self._per_host_batch_size = self.batch_size // jax.process_count() + # self._per_host_batch_size = self.batch_size // jax.process_count() def _get_dummy_batch(self): """Returns a dummy batch of data in the final desired structure.""" @@ -194,7 +194,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.shuffle(shuffle_buffer) dataset = dataset.batch( - self._per_host_batch_size, + self.batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE, ) From 17cb17b3bc5cb4dede1f957d5feb629a7ea79e0c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 16:49:35 +0530 Subject: [PATCH 062/279] Debug --- examples/ml_perf/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 13d228be..cb7df7b5 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -138,7 +138,8 @@ def main( # Print one sample. for element in train_ds.take(1): - print(">>> train sample", element[0]) + print(">>> train sample", element[0]["large_emb_inputs"]['cat_32_id']) + print(">>> train sample", element[0]["small_emb_inputs"]) # def generator(dataset, training=False): # """Converts tf.data Dataset to a Python generator and preprocesses From 37c21e391b74735833946d8dc6b74294e6ecb304 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 16:53:33 +0530 Subject: [PATCH 063/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index cb7df7b5..1075586b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -138,7 +138,7 @@ def main( # Print one sample. for element in train_ds.take(1): - print(">>> train sample", element[0]["large_emb_inputs"]['cat_32_id']) + print(">>> train sample", element[0]["large_emb_inputs"]['cat_33_id']) print(">>> train sample", element[0]["small_emb_inputs"]) # def generator(dataset, training=False): From ebcb9d0385bcb0f49923d6df3bbcb66dd0d2a382 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 17:03:07 +0530 Subject: [PATCH 064/279] Debug --- examples/ml_perf/main.py | 43 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 1075586b..dfdf214f 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -139,31 +139,32 @@ def main( # Print one sample. for element in train_ds.take(1): print(">>> train sample", element[0]["large_emb_inputs"]['cat_33_id']) - print(">>> train sample", element[0]["small_emb_inputs"]) + print(">>> train sample", element[0]["small_emb_inputs"]['cat_39_id']) - # def generator(dataset, training=False): - # """Converts tf.data Dataset to a Python generator and preprocesses - # large embedding features. - # """ - # for features, labels in dataset: - # preprocessed_large_embeddings = model.embedding_layer.preprocess( - # features["large_emb_inputs"], training=training - # ) + def generator(dataset, training=False): + """Converts tf.data Dataset to a Python generator and preprocesses + large embedding features. + """ + for features, labels in dataset: + preprocessed_large_embeddings = model.embedding_layer.preprocess( + features["large_emb_inputs"], training=training + ) - # x = { - # "dense_input": features["dense_input"], - # "large_emb_inputs": preprocessed_large_embeddings, - # "small_emb_inputs": features["small_emb_inputs"], - # } - # y = labels - # yield (x, y) + x = { + "dense_input": features["dense_input"], + "large_emb_inputs": preprocessed_large_embeddings, + "small_emb_inputs": features["small_emb_inputs"], + } + y = labels + yield (x, y) - # train_generator = generator(train_ds, training=True) + train_generator = generator(train_ds, training=True) # # eval_generator = generator(eval_ds, training=False) - # for first_batch in train_generator: - # print(first_batch[0]) - # # model(first_batch[0]) - # break + for first_batch in train_generator: + print(first_batch[0]["dense_input"]) + print(first_batch[0]["small_emb_inputs"]['cat_39_id']) + # model(first_batch[0]) + break # # Train the model. # model.fit(train_generator, steps_per_epoch=num_steps) From e8b86a5992a6ad603c3a31c0fb193eee2042b36a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 17:06:05 +0530 Subject: [PATCH 065/279] Debug --- examples/ml_perf/main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index dfdf214f..4b31e210 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -161,8 +161,10 @@ def generator(dataset, training=False): train_generator = generator(train_ds, training=True) # # eval_generator = generator(eval_ds, training=False) for first_batch in train_generator: - print(first_batch[0]["dense_input"]) - print(first_batch[0]["small_emb_inputs"]['cat_39_id']) + print("------>dense", first_batch[0]["dense_input"]) + print("-------> small", first_batch[0]["small_emb_inputs"]['cat_39_id']) + print("-------> large", first_batch[0]["large_emb_inputs"]) + # model(first_batch[0]) break From 726dab1c3e4032703545ee6cbef9e2d52f2da218 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 17:22:56 +0530 Subject: [PATCH 066/279] Debug --- keras_rs/src/layers/embedding/jax/embedding_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_rs/src/layers/embedding/jax/embedding_utils.py b/keras_rs/src/layers/embedding/jax/embedding_utils.py index 80e342dd..e57f6f61 100644 --- a/keras_rs/src/layers/embedding/jax/embedding_utils.py +++ b/keras_rs/src/layers/embedding/jax/embedding_utils.py @@ -211,6 +211,7 @@ def collect_tokens_and_weights( jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples) + print("BATCH SIZE PER SC", feature_tokens.shape[0] // num_sc_per_device) preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input( feature_tokens, feature_weights, From e3af0e9e9a2ee052a946d90387aabd46a7c919d3 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 17:28:40 +0530 Subject: [PATCH 067/279] Debug --- keras_rs/src/layers/embedding/jax/embedding_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/embedding/jax/embedding_utils.py b/keras_rs/src/layers/embedding/jax/embedding_utils.py index e57f6f61..65bc43ec 100644 --- a/keras_rs/src/layers/embedding/jax/embedding_utils.py +++ b/keras_rs/src/layers/embedding/jax/embedding_utils.py @@ -211,7 +211,7 @@ def collect_tokens_and_weights( jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples) - print("BATCH SIZE PER SC", feature_tokens.shape[0] // num_sc_per_device) + print("BATCH SIZE PER SC", feature_tokens[0].shape[0] // num_sc_per_device) preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input( feature_tokens, feature_weights, From c79fc2cfca28ae06c3a9bf91115a27eacc0a1bc8 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 18:12:15 +0530 Subject: [PATCH 068/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 4b31e210..d64a79d4 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -165,7 +165,7 @@ def generator(dataset, training=False): print("-------> small", first_batch[0]["small_emb_inputs"]['cat_39_id']) print("-------> large", first_batch[0]["large_emb_inputs"]) - # model(first_batch[0]) + model(first_batch[0]) break # # Train the model. From 627dec873171080e76ecc2d241eec0d56fb79d8a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 18:16:23 +0530 Subject: [PATCH 069/279] Debug --- examples/ml_perf/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d64a79d4..fe6a8cea 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -165,11 +165,11 @@ def generator(dataset, training=False): print("-------> small", first_batch[0]["small_emb_inputs"]['cat_39_id']) print("-------> large", first_batch[0]["large_emb_inputs"]) - model(first_batch[0]) + # model(first_batch[0]) break # # Train the model. - # model.fit(train_generator, steps_per_epoch=num_steps) + model.fit(train_generator, steps_per_epoch=num_steps) if __name__ == "__main__": From f4d255a1dbfc18f88239ca113ef10d4fe81c7cb2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 21 Oct 2025 18:22:55 +0530 Subject: [PATCH 070/279] Debug --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 134076e8..0a38413a 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -195,7 +195,7 @@ training_config.global_batch_size = 128 # Set `num_steps` in the main config file instead of num_epochs, because we are # using a Python generator. -training_config.num_steps = 2 +training_config.num_steps = 20 # === Assign all configs to the root config === config.dataset = dataset_config From e3bde1577b43031a81a062b9884e87587cefe63e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 00:20:11 +0530 Subject: [PATCH 071/279] Fix data loading and change bsz to 16K --- examples/ml_perf/configs/v6e_16.py | 2 +- examples/ml_perf/dataloader.py | 16 ++++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 0a38413a..3616008f 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -192,7 +192,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 128 +training_config.global_batch_size = 16896 # Set `num_steps` in the main config file instead of num_epochs, because we are # using a Python generator. training_config.num_steps = 20 diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index ddfefad7..1f4f170a 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -1,6 +1,7 @@ import numpy as np import tensorflow as tf +SEED = 1337 class DataLoader: def __init__( @@ -167,16 +168,10 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): if self._return_dummy_dataset: return self._create_dummy_dataset() + # Important to specify shuffle = False here to ensure all processes have + # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - # # Shard the dataset across hosts/workers. - # # TODO: Do we need to do this if we are distributing the dataset - # # manually using distribution.distribute_dataset(...)? - # This is not needed, because distribute_dataset shards the dataset - # across hosts. - # if num_processes > 1: - # dataset = dataset.shard(num_processes, process_id) - dataset = tf.data.TFRecordDataset( dataset, buffer_size=None, @@ -189,9 +184,10 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): ) dataset = dataset.unbatch() - # Shuffle dataset if in training mode. + # Shuffle dataset if in training mode. Pass a seed so that all processes + # have the same shuffle. if self.training and shuffle_buffer and shuffle_buffer > 0: - dataset = dataset.shuffle(shuffle_buffer) + dataset = dataset.shuffle(shuffle_buffer, seed=SEED) dataset = dataset.batch( self.batch_size, From 453e1693081c9ed09a159ecf523577fe935b6c14 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 00:24:02 +0530 Subject: [PATCH 072/279] Reduce bsz --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 3616008f..0a38413a 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -192,7 +192,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 128 # Set `num_steps` in the main config file instead of num_epochs, because we are # using a Python generator. training_config.num_steps = 20 From 9bf0574d18822641b3cac03fa6f3dc9b4f4a089d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 00:28:12 +0530 Subject: [PATCH 073/279] Increase bsz again --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 0a38413a..3616008f 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -192,7 +192,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 128 +training_config.global_batch_size = 16896 # Set `num_steps` in the main config file instead of num_epochs, because we are # using a Python generator. training_config.num_steps = 20 From c6be1fb0c8d737524e54e80bda413655c31c3a46 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 00:37:18 +0530 Subject: [PATCH 074/279] Increase bsz again --- examples/ml_perf/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index fe6a8cea..6eb60d17 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -13,6 +13,7 @@ SEED = 1337 keras.utils.set_random_seed(SEED) +keras.config.disable_traceback_filtering() def main( From 85779028a99f887896975ff453ac11efe1cd1d32 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 00:40:30 +0530 Subject: [PATCH 075/279] Disable pbar --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 6eb60d17..1cf8f7f7 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -170,7 +170,7 @@ def generator(dataset, training=False): break # # Train the model. - model.fit(train_generator, steps_per_epoch=num_steps) + model.fit(train_generator, steps_per_epoch=num_steps, verbose=0) if __name__ == "__main__": From 293b1ee029a41f1c97f4657552d288567d69ff4d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 00:47:59 +0530 Subject: [PATCH 076/279] Add basic metric logger --- examples/ml_perf/main.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 1cf8f7f7..fe97811d 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -15,6 +15,10 @@ keras.utils.set_random_seed(SEED) keras.config.disable_traceback_filtering() +class MetricLogger(keras.callbacks.Callback): + def on_train_batch_end(self, batch, logs=None): + keys = list(logs.keys()) + print("...Training: end of batch {}; got log keys: {}".format(batch, keys)) def main( file_pattern, @@ -170,7 +174,12 @@ def generator(dataset, training=False): break # # Train the model. - model.fit(train_generator, steps_per_epoch=num_steps, verbose=0) + model.fit( + train_generator, + steps_per_epoch=num_steps, + callbacks=[MetricLogger()], + verbose=0, + ) if __name__ == "__main__": From d84be134421f0448b83f2846d4334efb29d5f2a7 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 00:51:59 +0530 Subject: [PATCH 077/279] Add basic metric logger --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index fe97811d..5a89fd19 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -18,7 +18,7 @@ class MetricLogger(keras.callbacks.Callback): def on_train_batch_end(self, batch, logs=None): keys = list(logs.keys()) - print("...Training: end of batch {}; got log keys: {}".format(batch, keys)) + print("--->", logs["loss"]) def main( file_pattern, From bd68100e9b7cc734e28f5f15e51dda9b1d6ad397 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 00:57:05 +0530 Subject: [PATCH 078/279] Reduce bsz --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 3616008f..c6bd6adc 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -192,7 +192,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 1024 # Set `num_steps` in the main config file instead of num_epochs, because we are # using a Python generator. training_config.num_steps = 20 From b552276a0621ea8243c6fdf8e5c17119df3f6f7e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 01:04:22 +0530 Subject: [PATCH 079/279] Reduce bsz --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index c6bd6adc..89e38153 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -192,7 +192,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 1024 +training_config.global_batch_size = 4096 # Set `num_steps` in the main config file instead of num_epochs, because we are # using a Python generator. training_config.num_steps = 20 From 898c8745388ee0cb6e7defa9bfa2e84eaec23840 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 07:49:51 +0530 Subject: [PATCH 080/279] Reduce bsz --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 89e38153..35688a36 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -192,7 +192,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 4096 +training_config.global_batch_size = 8192 # Set `num_steps` in the main config file instead of num_epochs, because we are # using a Python generator. training_config.num_steps = 20 From 901bed0ba1c25e4069b09161eeb7c6d11b11d0ec Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 07:53:07 +0530 Subject: [PATCH 081/279] Reduce bsz --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 35688a36..7837b7d9 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -192,7 +192,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 8192 +training_config.global_batch_size = 16384 # Set `num_steps` in the main config file instead of num_epochs, because we are # using a Python generator. training_config.num_steps = 20 From 82fe81dbc8901af090acee91a801cfb81c512958 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 09:37:50 +0530 Subject: [PATCH 082/279] Add cfg for v6e-16 full dataset --- .../ml_perf/configs/v6e_16_full_dataset.py | 215 ++++++++++++++++++ examples/ml_perf/main.py | 2 - 2 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 examples/ml_perf/configs/v6e_16_full_dataset.py diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py new file mode 100644 index 00000000..9ca075e6 --- /dev/null +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -0,0 +1,215 @@ +from keras.utils import Config + +config = Config() + +# === Experiment metadata === +config.experiment_name = "v6e_16_full_dataset" +config.model_dir = "./v6e_16_full_dataset" + +# === Dataset === +dataset_config = Config() +dataset_config.file_pattern = ( + "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" + "train-0000[0-3]-of-01024tfrecord" +) +dataset_config.val_file_pattern = ( + "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" + "train-0000[0-3]-of-01024tfrecord" +) +# The path which we are reading from already has the batched dataset. +dataset_config.file_batch_size = 4224 + +# Features +dataset_config.label = "clicked" +dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] +dataset_config.lookup = [ + { + "name": "categorical-feature-14", + "vocabulary_size": 40000000, + "feature_list_length": 3, + "new_name": "cat_14", + }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "feature_list_length": 2, + "new_name": "cat_15", + }, + { + "name": "categorical-feature-16", + "vocabulary_size": 17295, + "feature_list_length": 1, + "new_name": "cat_16", + }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "feature_list_length": 2, + "new_name": "cat_17", + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "feature_list_length": 6, + "new_name": "cat_18", + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "feature_list_length": 1, + "new_name": "cat_19", + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "feature_list_length": 1, + "new_name": "cat_20", + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "feature_list_length": 1, + "new_name": "cat_21", + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "feature_list_length": 1, + "new_name": "cat_22", + }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "feature_list_length": 7, + "new_name": "cat_23", + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "feature_list_length": 3, + "new_name": "cat_24", + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "feature_list_length": 8, + "new_name": "cat_25", + }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "feature_list_length": 1, + "new_name": "cat_26", + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "feature_list_length": 6, + "new_name": "cat_27", + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "feature_list_length": 9, + "new_name": "cat_28", + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "feature_list_length": 5, + "new_name": "cat_29", + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "feature_list_length": 1, + "new_name": "cat_30", + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "feature_list_length": 1, + "new_name": "cat_31", + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "feature_list_length": 1, + "new_name": "cat_32", + }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "feature_list_length": 12, + "new_name": "cat_33", + }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "feature_list_length": 100, + "new_name": "cat_34", + }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "feature_list_length": 27, + "new_name": "cat_35", + }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "feature_list_length": 10, + "new_name": "cat_36", + }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "feature_list_length": 3, + "new_name": "cat_37", + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "feature_list_length": 1, + "new_name": "cat_38", + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "feature_list_length": 1, + "new_name": "cat_39", + }, +] + +# === Model === +model_config = Config() +# Embedding +model_config.embedding_dim = 128 +model_config.allow_id_dropping = True +model_config.embedding_threshold = 21000 +model_config.max_ids_per_partition = 8192 +model_config.max_unique_ids_per_partition = 4096 +model_config.learning_rate = 0.0034 + +# MLP +model_config.bottom_mlp_dims = [512, 256, 128] +model_config.top_mlp_dims = [1024, 1024, 512, 256, 1] + +# DCN +model_config.num_dcn_layers = 3 +model_config.dcn_projection_dim = 512 + +# === Training === +training_config = Config() +training_config.learning_rate = 0.0034 +training_config.global_batch_size = 128 +# Set `num_steps` instead of `num_epochs`, because we are using a Python +# generator. +training_config.num_steps = 10 + +# === Assign all configs to the root config === +config.dataset = dataset_config +config.model = model_config +config.training = training_config + +config.freeze() diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 5a89fd19..b55450c8 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -48,8 +48,6 @@ def main( keras.distribution.set_distribution(distribution) num_processes = distribution._num_process - per_host_batch_size = global_batch_size // num_processes - # === Distributed embeddings' configs for lookup features === # Set XLA flags. os.environ["XLA_FLAGS"] = ( From 107c79d9c82bc07d998a5e582e5e9ba13e8e5822 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 10:24:25 +0530 Subject: [PATCH 083/279] Debug --- examples/ml_perf/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 1f4f170a..2ebc084b 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -171,6 +171,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) + print(dataset) dataset = tf.data.TFRecordDataset( dataset, From 73414ebc0968edda9e70e6810182e4e8fa7c0669 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 10:24:38 +0530 Subject: [PATCH 084/279] Debug --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 2ebc084b..92bb4577 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -171,7 +171,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - print(dataset) + print("------------------>", dataset) dataset = tf.data.TFRecordDataset( dataset, From 2ed102c4478e3bc47b08b286d30b1c4f330154e6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 11:04:49 +0530 Subject: [PATCH 085/279] Debug --- examples/ml_perf/main.py | 14 +++++++------- .../layers/embedding/jax/distributed_embedding.py | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index b55450c8..a3576f8e 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -140,9 +140,9 @@ def main( distribution.auto_shard_dataset = False # Print one sample. - for element in train_ds.take(1): - print(">>> train sample", element[0]["large_emb_inputs"]['cat_33_id']) - print(">>> train sample", element[0]["small_emb_inputs"]['cat_39_id']) + # for element in train_ds.take(1): + # print(">>> train sample", element[0]["large_emb_inputs"]['cat_33_id']) + # print(">>> train sample", element[0]["small_emb_inputs"]['cat_39_id']) def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses @@ -163,10 +163,10 @@ def generator(dataset, training=False): train_generator = generator(train_ds, training=True) # # eval_generator = generator(eval_ds, training=False) - for first_batch in train_generator: - print("------>dense", first_batch[0]["dense_input"]) - print("-------> small", first_batch[0]["small_emb_inputs"]['cat_39_id']) - print("-------> large", first_batch[0]["large_emb_inputs"]) + # for first_batch in train_generator: + # print("------>dense", first_batch[0]["dense_input"]) + # print("-------> small", first_batch[0]["small_emb_inputs"]['cat_39_id']) + # print("-------> large", first_batch[0]["large_emb_inputs"]) # model(first_batch[0]) break diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 1697a497..ba2b92a2 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -580,18 +580,18 @@ def _sparsecore_preprocess( ) layout = self._sparsecore_layout - print(f"-->{layout=}") + # print(f"-->{layout=}") mesh = layout.device_mesh.backend_mesh - print(f"-->{mesh=}") + # print(f"-->{mesh=}") global_device_count = mesh.devices.size - print(f"-->{global_device_count=}") + # print(f"-->{global_device_count=}") local_device_count = mesh.local_mesh.devices.size - print(f"{local_device_count=}") + # print(f"{local_device_count=}") num_sc_per_device = jte_utils.num_sparsecores_per_device( mesh.devices.item(0) ) - print(f"-->{num_sc_per_device=}") - print(f"-->{jax.process_count()=}") + # print(f"-->{num_sc_per_device=}") + # print(f"-->{jax.process_count()=}") preprocessed, stats = embedding_utils.stack_and_shard_samples( self._config.feature_specs, @@ -600,7 +600,7 @@ def _sparsecore_preprocess( global_device_count, num_sc_per_device, ) - print(f"-->{stats=}") + # print(f"-->{stats=}") # if training: # # Synchronize input statistics across all devices and update the From dc189657555cf32f3968798e9028b59ae93967ec Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 11:06:50 +0530 Subject: [PATCH 086/279] Debug --- examples/ml_perf/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index a3576f8e..d0da3841 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -168,8 +168,8 @@ def generator(dataset, training=False): # print("-------> small", first_batch[0]["small_emb_inputs"]['cat_39_id']) # print("-------> large", first_batch[0]["large_emb_inputs"]) - # model(first_batch[0]) - break + # # model(first_batch[0]) + # break # # Train the model. model.fit( From 83b9fe074a1a5cc6e35184fcaa5f41ab17d13b03 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 11:09:34 +0530 Subject: [PATCH 087/279] Debug --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 92bb4577..bc72facf 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -171,7 +171,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - print("------------------>", dataset) + print("------------------>", [d for d in dataset]) dataset = tf.data.TFRecordDataset( dataset, From 3492237e2129509e26344b1d2e154b5593e0cfb5 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 11:39:44 +0530 Subject: [PATCH 088/279] Comment out code --- .../ml_perf/configs/v6e_16_full_dataset.py | 2 +- .../ml_perf/configs/v6e_8_full_dataset.py | 1 - .../embedding/distributed_embedding_test.py | 1 - .../embedding/jax/distributed_embedding.py | 94 +++++++++---------- .../layers/embedding/jax/embedding_utils.py | 1 - 5 files changed, 48 insertions(+), 51 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 9ca075e6..a1991972 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -202,7 +202,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 128 +training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py index 30a59f97..e7867653 100644 --- a/examples/ml_perf/configs/v6e_8_full_dataset.py +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -203,7 +203,6 @@ training_config = Config() training_config.learning_rate = 0.0034 training_config.global_batch_size = 128 -training_config.batch_size = 256 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 2 diff --git a/keras_rs/src/layers/embedding/distributed_embedding_test.py b/keras_rs/src/layers/embedding/distributed_embedding_test.py index cb4df82f..f5cd3192 100644 --- a/keras_rs/src/layers/embedding/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/distributed_embedding_test.py @@ -86,7 +86,6 @@ def setUp(self): self._strategy = tf.distribute.TPUStrategy( resolver, experimental_device_assignment=device_assignment ) - print("### num_replicas", self._strategy.num_replicas_in_sync) self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver) elif keras.backend.backend() == "jax" and self.on_tpu: self._strategy = JaxDummyStrategy() diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index bd69a0be..4bee8fc0 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -596,53 +596,53 @@ def _sparsecore_preprocess( num_sc_per_device, ) - if training: - # Synchronize input statistics across all devices and update the - # underlying stacked tables specs in the feature specs. - - # Aggregate stats across all processes/devices via pmax. - all_stats = multihost_utils.process_allgather(stats) - aggregated_stats = jax.tree.map( - lambda x: jnp.max(x, axis=0), all_stats - ) - - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(aggregated_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max( - aggregated_stats.max_unique_ids_per_partition[stack_name] - ) - > spec.max_unique_ids_per_partition - or ( - np.max( - aggregated_stats.required_buffer_size_per_sc[stack_name] - ) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # Re-execute preprocessing with consistent input statistics. - preprocessed, _ = embedding_utils.stack_and_shard_samples( - self._config.feature_specs, - samples, - local_device_count, - global_device_count, - num_sc_per_device, - ) + # if training: + # # Synchronize input statistics across all devices and update the + # # underlying stacked tables specs in the feature specs. + + # # Aggregate stats across all processes/devices via pmax. + # all_stats = multihost_utils.process_allgather(stats) + # aggregated_stats = jax.tree.map( + # lambda x: jnp.max(x, axis=0), all_stats + # ) + + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(aggregated_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max( + # aggregated_stats.max_unique_ids_per_partition[stack_name] + # ) + # > spec.max_unique_ids_per_partition + # or ( + # np.max( + # aggregated_stats.required_buffer_size_per_sc[stack_name] + # ) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} diff --git a/keras_rs/src/layers/embedding/jax/embedding_utils.py b/keras_rs/src/layers/embedding/jax/embedding_utils.py index 65bc43ec..80e342dd 100644 --- a/keras_rs/src/layers/embedding/jax/embedding_utils.py +++ b/keras_rs/src/layers/embedding/jax/embedding_utils.py @@ -211,7 +211,6 @@ def collect_tokens_and_weights( jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples) - print("BATCH SIZE PER SC", feature_tokens[0].shape[0] // num_sc_per_device) preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input( feature_tokens, feature_weights, From 4cf06fdf3a48bed8b37976d368a64c771ad761fe Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 13:10:07 +0530 Subject: [PATCH 089/279] Add logging statements --- examples/ml_perf/dataloader.py | 18 +++++++-- examples/ml_perf/main.py | 67 +++++++++++++++++++--------------- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index bc72facf..406510d4 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf +from absl import logging + SEED = 1337 class DataLoader: @@ -15,6 +17,9 @@ def __init__( label, training=False, ): + passed_args = locals() + logging.debug("Initialising DataLoader with: %s", passed_args) + # Passed attributes. self.file_pattern = file_pattern self.batch_size = batch_size @@ -27,11 +32,13 @@ def __init__( # Derived attributes. self._return_dummy_dataset = file_pattern is None - # self._per_host_batch_size = self.batch_size // jax.process_count() + if self._return_dummy_dataset: + logging.warning( + "`file_pattern` is `None`. Will use the dummy dataset." + ) def _get_dummy_batch(self): """Returns a dummy batch of data in the final desired structure.""" - # Labels data = { "clicked": np.random.randint( @@ -87,6 +94,7 @@ def _get_dummy_batch(self): def _create_dummy_dataset(self): """Creates a TF dummy dataset (randomly initialised).""" + logging.info("=== Creating dummy dataset ===") dummy_data = self._get_dummy_batch() # Separate labels from features to create a `(features, labels)` tuple. @@ -165,13 +173,17 @@ def _get_emb_inputs(emb_features): return (x, labels) def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): + passed_args = locals() + logging.debug("Called `create_dataset` with:%s", passed_args) + if self._return_dummy_dataset: return self._create_dummy_dataset() + logging.info("=== Loading the real dataset from files ===") # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - print("------------------>", [d for d in dataset]) + logging.info("List of input files: %s", [f for f in dataset]) dataset = tf.data.TFRecordDataset( dataset, diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d0da3841..0e95f6ae 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -2,24 +2,29 @@ import importlib import os +from absl import logging + os.environ["KERAS_BACKEND"] = "jax" import keras - import keras_rs from .dataloader import DataLoader from .model import DLRMDCNV2 +# Set random seed. SEED = 1337 keras.utils.set_random_seed(SEED) +# Disable traceback filtering in case the script errors out. keras.config.disable_traceback_filtering() + class MetricLogger(keras.callbacks.Callback): def on_train_batch_end(self, batch, logs=None): keys = list(logs.keys()) print("--->", logs["loss"]) + def main( file_pattern, val_file_pattern, @@ -29,7 +34,6 @@ def main( label, shuffle_buffer, embedding_dim, - allow_id_dropping, max_ids_per_partition, max_unique_ids_per_partition, embedding_learning_rate, @@ -42,20 +46,21 @@ def main( file_batch_size, num_steps, ): + passed_args = locals() + logging.debug("`main()` called with: %s", passed_args) + # Set DDP as Keras distribution strategy devices = keras.distribution.list_devices(device_type="tpu") distribution = keras.distribution.DataParallel(devices=devices) keras.distribution.set_distribution(distribution) num_processes = distribution._num_process + logging.info("Initialized distribution strategy.") + logging.info("Found %d devices.", len(devices)) + logging.info("Running with %d processes.", num_processes) + if distribution._process_id is not None: + logging.info("Current Process ID: %d", distribution._process_id) # === Distributed embeddings' configs for lookup features === - # Set XLA flags. - os.environ["XLA_FLAGS"] = ( - "--xla_sparse_core_max_ids_per_partition_per_sample=" - f"{max_ids_per_partition} " - "--xla_sparse_core_max_unique_ids_per_partition_per_sample=" - f"{max_unique_ids_per_partition}" - ) feature_configs = {} for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] @@ -96,7 +101,7 @@ def main( # We instantiate the model first, because we need to preprocess large # embedding feature inputs using the distributed embedding layer defined # inside the model class. - print("===== Initialising model =====") + logging.info("===== Initialising model =====") model = DLRMDCNV2( large_emb_feature_configs=feature_configs, small_emb_features=small_emb_features, @@ -114,9 +119,10 @@ def main( optimizer=keras.optimizers.Adagrad(learning_rate=learning_rate), metrics=[keras.metrics.BinaryAccuracy()], ) + logging.info("Initialised model:\n%s", model) # === Load dataset === - print("===== Loading dataset =====") + logging.info("===== Loading dataset =====") train_ds = DataLoader( file_pattern=file_pattern, batch_size=global_batch_size, @@ -139,11 +145,6 @@ def main( # eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False - # Print one sample. - # for element in train_ds.take(1): - # print(">>> train sample", element[0]["large_emb_inputs"]['cat_33_id']) - # print(">>> train sample", element[0]["small_emb_inputs"]['cat_39_id']) - def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses large embedding features. @@ -161,29 +162,31 @@ def generator(dataset, training=False): y = labels yield (x, y) + logging.info("=== Preprocessing large embedding tables ===") train_generator = generator(train_ds, training=True) - # # eval_generator = generator(eval_ds, training=False) - # for first_batch in train_generator: - # print("------>dense", first_batch[0]["dense_input"]) - # print("-------> small", first_batch[0]["small_emb_inputs"]['cat_39_id']) - # print("-------> large", first_batch[0]["large_emb_inputs"]) - - # # model(first_batch[0]) - # break + # eval_generator = generator(eval_ds, training=False) + logging.debug("Inspecting one batch of data...") + for first_batch in train_generator: + logging.debug("Dense inputs:%s", first_batch[0]["dense_input"]) + logging.debug("Small embedding inputs:%s", first_batch[0]["small_emb_inputs"]['cat_39_id']) + logging.debug("Large embedding inputs:%s", first_batch[0]["large_emb_inputs"]) + break + logging.info("=== Successfully preprocessed one batch of data ===") - # # Train the model. + # === Training === + logging.info("===== Training =====") model.fit( train_generator, steps_per_epoch=num_steps, callbacks=[MetricLogger()], - verbose=0, ) + logging.info("Training finished.") if __name__ == "__main__": keras.config.disable_traceback_filtering() - print("====== Launching train script =======") + logging.info("====== Launching train script =======") parser = argparse.ArgumentParser( description=( "Benchmark the DLRM-DCNv2 model on the Criteo dataset (MLPerf)" @@ -194,10 +197,11 @@ def generator(dataset, training=False): ) args = parser.parse_args() - print(f"===== Reading config from {args.config_name} ======") + logging.info("===== Reading config from %s ======", args.config_name) config = importlib.import_module( f".configs.{args.config_name}", package=__package__ ).config + logging.info("Config:\n%s", config) # === Unpack args from config === @@ -219,7 +223,6 @@ def generator(dataset, training=False): model_cfg = config["model"] # Embedding embedding_dim = model_cfg["embedding_dim"] - allow_id_dropping = model_cfg["allow_id_dropping"] embedding_threshold = model_cfg["embedding_threshold"] max_ids_per_partition = model_cfg["max_ids_per_partition"] max_unique_ids_per_partition = model_cfg["max_unique_ids_per_partition"] @@ -251,6 +254,9 @@ def generator(dataset, training=False): else: large_emb_features.append(emb_feature) + logging.debug("Large Embedding Features: %s", large_emb_features) + logging.debug("Small Embedding Features: %s", small_emb_features) + main( file_pattern, val_file_pattern, @@ -260,7 +266,6 @@ def generator(dataset, training=False): label, shuffle_buffer, embedding_dim, - allow_id_dropping, max_ids_per_partition, max_unique_ids_per_partition, embedding_learning_rate, @@ -273,3 +278,5 @@ def generator(dataset, training=False): file_batch_size, num_steps, ) + + logging.info("Train script finished") From 19fa3f57ffadecebb185728f075a26d1a3052c32 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 13:12:31 +0530 Subject: [PATCH 090/279] Format --- examples/ml_perf/dataloader.py | 2 +- examples/ml_perf/main.py | 11 ++++++++--- .../src/layers/embedding/jax/distributed_embedding.py | 1 - 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 406510d4..012e5da7 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -1,10 +1,10 @@ import numpy as np import tensorflow as tf - from absl import logging SEED = 1337 + class DataLoader: def __init__( self, diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 0e95f6ae..6bbe6af4 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -7,6 +7,7 @@ os.environ["KERAS_BACKEND"] = "jax" import keras + import keras_rs from .dataloader import DataLoader @@ -21,7 +22,6 @@ class MetricLogger(keras.callbacks.Callback): def on_train_batch_end(self, batch, logs=None): - keys = list(logs.keys()) print("--->", logs["loss"]) @@ -168,8 +168,13 @@ def generator(dataset, training=False): logging.debug("Inspecting one batch of data...") for first_batch in train_generator: logging.debug("Dense inputs:%s", first_batch[0]["dense_input"]) - logging.debug("Small embedding inputs:%s", first_batch[0]["small_emb_inputs"]['cat_39_id']) - logging.debug("Large embedding inputs:%s", first_batch[0]["large_emb_inputs"]) + logging.debug( + "Small embedding inputs:%s", + first_batch[0]["small_emb_inputs"]["cat_39_id"], + ) + logging.debug( + "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] + ) break logging.info("=== Successfully preprocessed one batch of data ===") diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 4bee8fc0..a6290da8 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -9,7 +9,6 @@ import numpy as np from jax import numpy as jnp from jax.experimental import layout as jax_layout -from jax.experimental import multihost_utils from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec from jax_tpu_embedding.sparsecore.lib.nn import ( From f89a6a914cb4118613b0d2adc63ea45517dd39f4 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 13:51:05 +0530 Subject: [PATCH 091/279] Use Py logger instead of absl --- examples/ml_perf/dataloader.py | 3 ++- examples/ml_perf/main.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 012e5da7..37552772 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -1,6 +1,7 @@ +import logging + import numpy as np import tensorflow as tf -from absl import logging SEED = 1337 diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 6bbe6af4..cbfb8428 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -1,9 +1,8 @@ import argparse import importlib +import logging import os -from absl import logging - os.environ["KERAS_BACKEND"] = "jax" import keras From 6fee9dce7b502d7ee0854afd354510b12a95ebf5 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 18:13:39 +0530 Subject: [PATCH 092/279] Clean up config passing --- examples/ml_perf/dataloader.py | 12 +- examples/ml_perf/main.py | 204 ++++++++++++--------------------- 2 files changed, 82 insertions(+), 134 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 37552772..c0e53483 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -5,6 +5,8 @@ SEED = 1337 +logger = logging.getLogger(__name__) + class DataLoader: def __init__( @@ -19,7 +21,7 @@ def __init__( training=False, ): passed_args = locals() - logging.debug("Initialising DataLoader with: %s", passed_args) + logger.debug("Initialising `DataLoader` with: %s", passed_args) # Passed attributes. self.file_pattern = file_pattern @@ -95,7 +97,7 @@ def _get_dummy_batch(self): def _create_dummy_dataset(self): """Creates a TF dummy dataset (randomly initialised).""" - logging.info("=== Creating dummy dataset ===") + logger.info("=== Creating dummy dataset ===") dummy_data = self._get_dummy_batch() # Separate labels from features to create a `(features, labels)` tuple. @@ -175,16 +177,16 @@ def _get_emb_inputs(emb_features): def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): passed_args = locals() - logging.debug("Called `create_dataset` with:%s", passed_args) + logger.debug("Called `create_dataset` with:%s", passed_args) if self._return_dummy_dataset: return self._create_dummy_dataset() - logging.info("=== Loading the real dataset from files ===") + logger.info("=== Loading the real dataset from files ===") # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - logging.info("List of input files: %s", [f for f in dataset]) + logger.info("List of input files: %s", [f for f in dataset]) dataset = tf.data.TFRecordDataset( dataset, diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index cbfb8428..7becd4be 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -14,8 +14,10 @@ # Set random seed. SEED = 1337 + +logger = logging.getLogger(__name__) + keras.utils.set_random_seed(SEED) -# Disable traceback filtering in case the script errors out. keras.config.disable_traceback_filtering() @@ -25,41 +27,42 @@ def on_train_batch_end(self, batch, logs=None): def main( - file_pattern, - val_file_pattern, - dense_features, - large_emb_features, - small_emb_features, - label, - shuffle_buffer, - embedding_dim, - max_ids_per_partition, - max_unique_ids_per_partition, - embedding_learning_rate, - bottom_mlp_dims, - top_mlp_dims, - num_dcn_layers, - dcn_projection_dim, - learning_rate, - global_batch_size, - file_batch_size, - num_steps, + ds_cfg, + model_cfg, + training_cfg, ): passed_args = locals() - logging.debug("`main()` called with: %s", passed_args) + logger.debug("Called `main()` with: %s", passed_args) # Set DDP as Keras distribution strategy devices = keras.distribution.list_devices(device_type="tpu") distribution = keras.distribution.DataParallel(devices=devices) keras.distribution.set_distribution(distribution) num_processes = distribution._num_process - logging.info("Initialized distribution strategy.") - logging.info("Found %d devices.", len(devices)) - logging.info("Running with %d processes.", num_processes) + logger.info("Initialized distribution strategy.") + logger.info("Found %d devices.", len(devices)) + logger.info("Running with %d processes.", num_processes) if distribution._process_id is not None: - logging.info("Current Process ID: %d", distribution._process_id) + logger.info("Current Process ID: %d", distribution._process_id) # === Distributed embeddings' configs for lookup features === + + # For features which have vocabulary_size < embedding_threshold, we can + # just do a normal dense lookup for those instead of having distributed + # embeddings. We could ideally pass `placement = default_device` to + # `keras_rs.layers.TableConfig` directly (and wouldn't have to do this + # separation of features), but doing it that way will necessarily require + # a separate optimiser for the embedding layer. + small_emb_features = [] + large_emb_features = [] + for emb_feature in ds_cfg.lookup: + if emb_feature["vocabulary_size"] < model_cfg.embedding_threshold: + small_emb_features.append(emb_feature) + else: + large_emb_features.append(emb_feature) + logger.debug("Large Embedding Features: %s", large_emb_features) + logger.debug("Small Embedding Features: %s", small_emb_features) + feature_configs = {} for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] @@ -69,7 +72,7 @@ def main( table_config = keras_rs.layers.TableConfig( name=f"{feature_name}_table", vocabulary_size=vocabulary_size, - embedding_dim=embedding_dim, + embedding_dim=model_cfg.embedding_dim, # TODO(abheesht): Verify. initializer=keras.initializers.VarianceScaling( scale=1.0, @@ -78,63 +81,68 @@ def main( seed=SEED, ), optimizer=keras.optimizers.Adagrad( - learning_rate=embedding_learning_rate + learning_rate=model_cfg.learning_rate ), combiner="sum", placement="sparsecore", # TODO: These two args are not getting passed down to # `jax-tpu-embedding` properly, seems like. - max_ids_per_partition=max_ids_per_partition, - max_unique_ids_per_partition=max_unique_ids_per_partition, + max_ids_per_partition=model_cfg.max_ids_per_partition, + max_unique_ids_per_partition=model_cfg.max_unique_ids_per_partition, ) feature_configs[f"{feature_name}_id"] = keras_rs.layers.FeatureConfig( name=feature_name, table=table_config, # TODO: Verify whether it should be `(bsz, 1)` or # `(bsz, feature_list_length)`. The original example uses 1. - input_shape=(global_batch_size, 1), - output_shape=(global_batch_size, embedding_dim), + input_shape=(training_cfg.global_batch_size, 1), + output_shape=( + training_cfg.global_batch_size, + model_cfg.embedding_dim, + ), ) # === Instantiate model === # We instantiate the model first, because we need to preprocess large # embedding feature inputs using the distributed embedding layer defined # inside the model class. - logging.info("===== Initialising model =====") + logger.info("===== Initialising model =====") model = DLRMDCNV2( large_emb_feature_configs=feature_configs, small_emb_features=small_emb_features, - embedding_dim=embedding_dim, - bottom_mlp_dims=bottom_mlp_dims, - top_mlp_dims=top_mlp_dims, - num_dcn_layers=num_dcn_layers, - dcn_projection_dim=dcn_projection_dim, + embedding_dim=model_cfg.embedding_dim, + bottom_mlp_dims=model_cfg.bottom_mlp_dims, + top_mlp_dims=model_cfg.top_mlp_dims, + num_dcn_layers=model_cfg.num_dcn_layers, + dcn_projection_dim=model_cfg.dcn_projection_dim, seed=SEED, dtype="float32", name="dlrm_dcn_v2", ) model.compile( loss=keras.losses.BinaryCrossentropy(), - optimizer=keras.optimizers.Adagrad(learning_rate=learning_rate), + optimizer=keras.optimizers.Adagrad( + learning_rate=training_cfg.learning_rate + ), metrics=[keras.metrics.BinaryAccuracy()], ) - logging.info("Initialised model:\n%s", model) + logger.info("Initialised model:\n%s", model) # === Load dataset === - logging.info("===== Loading dataset =====") + logger.info("===== Loading dataset =====") train_ds = DataLoader( - file_pattern=file_pattern, - batch_size=global_batch_size, - file_batch_size=file_batch_size, - dense_features=dense_features, + file_pattern=ds_cfg.file_pattern, + batch_size=training_cfg.global_batch_size, + file_batch_size=ds_cfg.get("file_batch_size", None), + dense_features=ds_cfg.dense_features, large_emb_features=large_emb_features, small_emb_features=small_emb_features, - label=label, + label=ds_cfg.label, training=True, ).create_dataset( process_id=distribution._process_id, num_processes=num_processes, - shuffle_buffer=shuffle_buffer, + shuffle_buffer=ds_cfg.get("shuffle_buffer", None), ) # For the multi-host case, the dataset has to be distributed manually. # See note here: @@ -161,36 +169,40 @@ def generator(dataset, training=False): y = labels yield (x, y) - logging.info("=== Preprocessing large embedding tables ===") + logger.info("=== Preprocessing large embedding tables ===") train_generator = generator(train_ds, training=True) # eval_generator = generator(eval_ds, training=False) - logging.debug("Inspecting one batch of data...") + logger.debug("Inspecting one batch of data...") for first_batch in train_generator: - logging.debug("Dense inputs:%s", first_batch[0]["dense_input"]) - logging.debug( + logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) + logger.debug( "Small embedding inputs:%s", first_batch[0]["small_emb_inputs"]["cat_39_id"], ) - logging.debug( + logger.debug( "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] ) break - logging.info("=== Successfully preprocessed one batch of data ===") + logger.info("=== Successfully preprocessed one batch of data ===") # === Training === - logging.info("===== Training =====") + logger.info("===== Training =====") model.fit( train_generator, - steps_per_epoch=num_steps, + steps_per_epoch=training_cfg.num_steps, callbacks=[MetricLogger()], ) - logging.info("Training finished.") + logger.info("Training finished.") if __name__ == "__main__": - keras.config.disable_traceback_filtering() + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(name)s | %(levelname)s | %(message)s", + datefmt="%H:%M:%S", + ) - logging.info("====== Launching train script =======") + logger.info("===== Launching train script =====") parser = argparse.ArgumentParser( description=( "Benchmark the DLRM-DCNv2 model on the Criteo dataset (MLPerf)" @@ -201,86 +213,20 @@ def generator(dataset, training=False): ) args = parser.parse_args() - logging.info("===== Reading config from %s ======", args.config_name) + logger.info("===== Reading config from %s ======", args.config_name) config = importlib.import_module( f".configs.{args.config_name}", package=__package__ ).config - logging.info("Config:\n%s", config) + logger.info("Config:\n%s", config) - # === Unpack args from config === - - # == Dataset config == ds_cfg = config["dataset"] - # File path - file_pattern = ds_cfg["file_pattern"] - val_file_pattern = ds_cfg.get("val_file_pattern", None) - # File batch size - file_batch_size = ds_cfg.get("file_batch_size", None) - # Shuffling - shuffle_buffer = ds_cfg.get("shuffle_buffer", None) - # Features - label = ds_cfg["label"] - dense_features = ds_cfg["dense"] - emb_features = ds_cfg["lookup"] - - # == Model config == model_cfg = config["model"] - # Embedding - embedding_dim = model_cfg["embedding_dim"] - embedding_threshold = model_cfg["embedding_threshold"] - max_ids_per_partition = model_cfg["max_ids_per_partition"] - max_unique_ids_per_partition = model_cfg["max_unique_ids_per_partition"] - embedding_learning_rate = model_cfg["learning_rate"] - # MLP - bottom_mlp_dims = model_cfg["bottom_mlp_dims"] - top_mlp_dims = model_cfg["top_mlp_dims"] - # DCN - num_dcn_layers = model_cfg["num_dcn_layers"] - dcn_projection_dim = model_cfg["dcn_projection_dim"] - - # == Training config == training_cfg = config["training"] - learning_rate = training_cfg["learning_rate"] - global_batch_size = training_cfg["global_batch_size"] - num_steps = training_cfg["num_steps"] - - # For features which have vocabulary_size < embedding_threshold, we can - # just do a normal dense lookup for those instead of having distributed - # embeddings. We could ideally pass `placement = default_device` to - # `keras_rs.layers.TableConfig` directly (and wouldn't have to do this - # separation of features), but doing it that way will necessarily require - # a separate optimiser for the embedding layer. - small_emb_features = [] - large_emb_features = [] - for emb_feature in emb_features: - if emb_feature["vocabulary_size"] < embedding_threshold: - small_emb_features.append(emb_feature) - else: - large_emb_features.append(emb_feature) - - logging.debug("Large Embedding Features: %s", large_emb_features) - logging.debug("Small Embedding Features: %s", small_emb_features) main( - file_pattern, - val_file_pattern, - dense_features, - large_emb_features, - small_emb_features, - label, - shuffle_buffer, - embedding_dim, - max_ids_per_partition, - max_unique_ids_per_partition, - embedding_learning_rate, - bottom_mlp_dims, - top_mlp_dims, - num_dcn_layers, - dcn_projection_dim, - learning_rate, - global_batch_size, - file_batch_size, - num_steps, + ds_cfg=ds_cfg, + model_cfg=model_cfg, + training_cfg=training_cfg, ) - logging.info("Train script finished") + logger.info("Train script finished") From 0130641e4302df26816b28496116457f4c5047f1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 18:14:28 +0530 Subject: [PATCH 093/279] Fix warning log call --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index c0e53483..3520325b 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -36,7 +36,7 @@ def __init__( # Derived attributes. self._return_dummy_dataset = file_pattern is None if self._return_dummy_dataset: - logging.warning( + logger.warning( "`file_pattern` is `None`. Will use the dummy dataset." ) From 3d16e12b765082583abbcfe24f6227c023c17598 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 18:25:58 +0530 Subject: [PATCH 094/279] Fix cfg access in some places --- examples/ml_perf/dataloader.py | 4 ++-- examples/ml_perf/main.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 3520325b..5fbc1569 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -97,7 +97,7 @@ def _get_dummy_batch(self): def _create_dummy_dataset(self): """Creates a TF dummy dataset (randomly initialised).""" - logger.info("=== Creating dummy dataset ===") + logger.info("Creating dummy dataset...") dummy_data = self._get_dummy_batch() # Separate labels from features to create a `(features, labels)` tuple. @@ -182,7 +182,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): if self._return_dummy_dataset: return self._create_dummy_dataset() - logger.info("=== Loading the real dataset from files ===") + logger.info("Loading the real dataset from files...") # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 7becd4be..120f17f3 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -106,7 +106,7 @@ def main( # We instantiate the model first, because we need to preprocess large # embedding feature inputs using the distributed embedding layer defined # inside the model class. - logger.info("===== Initialising model =====") + logger.info("Initialising model...") model = DLRMDCNV2( large_emb_feature_configs=feature_configs, small_emb_features=small_emb_features, @@ -126,15 +126,15 @@ def main( ), metrics=[keras.metrics.BinaryAccuracy()], ) - logger.info("Initialised model:\n%s", model) + logger.info("Initialised model: %s", model) # === Load dataset === - logger.info("===== Loading dataset =====") + logger.info("Loading dataset...") train_ds = DataLoader( file_pattern=ds_cfg.file_pattern, batch_size=training_cfg.global_batch_size, file_batch_size=ds_cfg.get("file_batch_size", None), - dense_features=ds_cfg.dense_features, + dense_features=ds_cfg.dense, large_emb_features=large_emb_features, small_emb_features=small_emb_features, label=ds_cfg.label, @@ -169,7 +169,7 @@ def generator(dataset, training=False): y = labels yield (x, y) - logger.info("=== Preprocessing large embedding tables ===") + logger.info("Preprocessing large embedding tables...") train_generator = generator(train_ds, training=True) # eval_generator = generator(eval_ds, training=False) logger.debug("Inspecting one batch of data...") @@ -183,10 +183,10 @@ def generator(dataset, training=False): "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] ) break - logger.info("=== Successfully preprocessed one batch of data ===") + logger.info("Successfully preprocessed one batch of data") # === Training === - logger.info("===== Training =====") + logger.info("Training...") model.fit( train_generator, steps_per_epoch=training_cfg.num_steps, @@ -202,7 +202,7 @@ def generator(dataset, training=False): datefmt="%H:%M:%S", ) - logger.info("===== Launching train script =====") + logger.info("Launching train script...") parser = argparse.ArgumentParser( description=( "Benchmark the DLRM-DCNv2 model on the Criteo dataset (MLPerf)" @@ -213,11 +213,11 @@ def generator(dataset, training=False): ) args = parser.parse_args() - logger.info("===== Reading config from %s ======", args.config_name) + logger.info("Reading config from %s", args.config_name) config = importlib.import_module( f".configs.{args.config_name}", package=__package__ ).config - logger.info("Config:\n%s", config) + logger.info("Config: %s", config) ds_cfg = config["dataset"] model_cfg = config["model"] From 9f184fcd94671a59ec6c64fe12bad262c88e1e3f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 18:45:10 +0530 Subject: [PATCH 095/279] Clean up logging --- examples/ml_perf/main.py | 2 +- examples/ml_perf/model.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 120f17f3..dd94537a 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -192,7 +192,7 @@ def generator(dataset, training=False): steps_per_epoch=training_cfg.num_steps, callbacks=[MetricLogger()], ) - logger.info("Training finished.") + logger.info("Training finished") if __name__ == "__main__": diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 4f84bbbf..2e56576c 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -1,3 +1,4 @@ +import logging from typing import Any, TypeAlias import keras @@ -7,6 +8,8 @@ Tensor: TypeAlias = Any +logger = logging.getLogger(__name__) + def _clone_initializer( initializer: keras.initializers.Initializer, @@ -90,6 +93,10 @@ def __init__( name: The name of the layer. """ super().__init__(dtype=dtype, name=name, **kwargs) + + passed_args = locals() + logger.debug("Initialising `DLRMDCNV2` with: %s", passed_args) + self.seed = seed # === Layers ==== @@ -103,6 +110,7 @@ def __init__( ), name="bottom_mlp", ) + logging.debug("Initialised Bottom MLP: %s", self.bottom_mlp) # Distributed embeddings for large embedding tables self.embedding_layer = keras_rs.layers.DistributedEmbedding( feature_configs=large_emb_feature_configs, @@ -110,6 +118,9 @@ def __init__( dtype=dtype, name="embedding_layer", ) + logging.debug( + "Initialised `DistributedEmbedding` layer: %s", self.embedding_layer + ) # Embedding layers for small embedding tables self.small_embedding_layers = None if small_emb_features: @@ -124,6 +135,10 @@ def __init__( ) for i, small_emb_feature in enumerate(small_emb_features) ] + logging.debug( + "Initialised small embedding layers: %s", + self.small_embedding_layers, + ) # DCN for "interactions" self.dcn_block = DCNBlock( num_layers=num_dcn_layers, @@ -132,6 +147,7 @@ def __init__( dtype=dtype, name="dcn_block", ) + logging.debug("Initialised DCN block: %s", self.dcn_block) # Top MLP for predictions self.top_mlp = keras.Sequential( self._get_mlp_layers( @@ -141,6 +157,7 @@ def __init__( ), name="top_mlp", ) + logging.debug("Initialised Top MLP: %s", self.top_mlp) # === Passed attributes === self.large_emb_feature_configs = large_emb_feature_configs @@ -289,6 +306,9 @@ def __init__( """ super().__init__(dtype=dtype, name=name, **kwargs) + passed_args = locals() + logger.debug("Initialising `DCNBlock` with: %s", passed_args) + # Layers self.layers = [ keras_rs.layers.FeatureCross( From d83e4af44892e166985ba071ee09da4fa79995df Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 22:24:50 +0530 Subject: [PATCH 096/279] Add val logging --- .../ml_perf/configs/v6e_16_full_dataset.py | 1 + examples/ml_perf/main.py | 30 +++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index a1991972..78b400e6 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -206,6 +206,7 @@ # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 +training_config.eval_freq = 5 # === Assign all configs to the root config === config.dataset = dataset_config diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index dd94537a..2959ffde 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -144,12 +144,28 @@ def main( num_processes=num_processes, shuffle_buffer=ds_cfg.get("shuffle_buffer", None), ) + do_eval = False + if ds_cfg.val_file_pattern: + do_eval = True + eval_ds = DataLoader( + file_pattern=ds_cfg.val_file_pattern, + batch_size=training_cfg.global_batch_size, + file_batch_size=ds_cfg.get("file_batch_size", None), + dense_features=ds_cfg.dense, + large_emb_features=large_emb_features, + small_emb_features=small_emb_features, + label=ds_cfg.label, + training=False, + ).create_dataset( + process_id=distribution._process_id, + num_processes=num_processes, + ) # For the multi-host case, the dataset has to be distributed manually. # See note here: # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. if num_processes > 1: train_ds = distribution.distribute_dataset(train_ds) - # eval_ds = distribution.distribute_dataset(eval_ds) + eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False def generator(dataset, training=False): @@ -171,7 +187,8 @@ def generator(dataset, training=False): logger.info("Preprocessing large embedding tables...") train_generator = generator(train_ds, training=True) - # eval_generator = generator(eval_ds, training=False) + if do_eval: + eval_generator = generator(eval_ds, training=False) logger.debug("Inspecting one batch of data...") for first_batch in train_generator: logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) @@ -187,10 +204,17 @@ def generator(dataset, training=False): # === Training === logger.info("Training...") + # Keras does not have a straightforward way to log at a step-level instead + # of epoch-level. So, we do a workaround here. + steps_per_epoch = training_cfg.eval_freq + epochs = training_cfg.num_steps // training_cfg.eval_freq model.fit( train_generator, - steps_per_epoch=training_cfg.num_steps, + validation_data=eval_generator, + steps_per_epoch=steps_per_epoch, callbacks=[MetricLogger()], + validation_freq=1, + epochs=epochs, ) logger.info("Training finished") From bae26955a476a9c094f327613cb2b2585d52fa4f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 23:47:43 +0530 Subject: [PATCH 097/279] Fix data generator so that data is refreshed correctly --- .../ml_perf/configs/v6e_16_full_dataset.py | 1 + examples/ml_perf/dataloader.py | 9 +++++++ examples/ml_perf/main.py | 25 +++++++++++++------ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 78b400e6..55bcd8e2 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -207,6 +207,7 @@ # generator. training_config.num_steps = 10 training_config.eval_freq = 5 +training_config.num_eval_steps = 10 # === Assign all configs to the root config === config.dataset = dataset_config diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 5fbc1569..9aa238ef 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -18,6 +18,7 @@ def __init__( large_emb_features, small_emb_features, label, + num_steps, training=False, ): passed_args = locals() @@ -31,6 +32,7 @@ def __init__( self.large_emb_features = large_emb_features self.small_emb_features = small_emb_features self.label = label + self.num_steps = num_steps self.training = training # Derived attributes. @@ -200,6 +202,9 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): ) dataset = dataset.unbatch() + # Take only `num_steps * self.batch_size` examples. + dataset = dataset.take(self.num_steps * self.batch_size) + # Shuffle dataset if in training mode. Pass a seed so that all processes # have the same shuffle. if self.training and shuffle_buffer and shuffle_buffer > 0: @@ -211,6 +216,10 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): num_parallel_calls=tf.data.AUTOTUNE, ) + # Repeat the dataset infinite number of times so that the generator + # does not run out. + dataset = dataset.repeat() + dataset = dataset.prefetch(tf.data.AUTOTUNE) return dataset diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2959ffde..72d5900a 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -130,6 +130,18 @@ def main( # === Load dataset === logger.info("Loading dataset...") + + # Keras does not have a straightforward way to log at a step-level instead + # of epoch-level. So, we do a workaround here. + if ds_cfg.val_file_pattern: + steps_per_epoch = training_cfg.eval_freq + epochs = training_cfg.num_steps // training_cfg.eval_freq + do_eval = True + else: + steps_per_epoch = training_cfg.num_steps + epochs = 1 + do_eval = False + train_ds = DataLoader( file_pattern=ds_cfg.file_pattern, batch_size=training_cfg.global_batch_size, @@ -138,15 +150,14 @@ def main( large_emb_features=large_emb_features, small_emb_features=small_emb_features, label=ds_cfg.label, + steps=steps_per_epoch, training=True, ).create_dataset( process_id=distribution._process_id, num_processes=num_processes, shuffle_buffer=ds_cfg.get("shuffle_buffer", None), ) - do_eval = False - if ds_cfg.val_file_pattern: - do_eval = True + if do_eval: eval_ds = DataLoader( file_pattern=ds_cfg.val_file_pattern, batch_size=training_cfg.global_batch_size, @@ -155,6 +166,7 @@ def main( large_emb_features=large_emb_features, small_emb_features=small_emb_features, label=ds_cfg.label, + steps=training_cfg.num_eval_steps, training=False, ).create_dataset( process_id=distribution._process_id, @@ -204,17 +216,14 @@ def generator(dataset, training=False): # === Training === logger.info("Training...") - # Keras does not have a straightforward way to log at a step-level instead - # of epoch-level. So, we do a workaround here. - steps_per_epoch = training_cfg.eval_freq - epochs = training_cfg.num_steps // training_cfg.eval_freq model.fit( train_generator, validation_data=eval_generator, + epochs=epochs, steps_per_epoch=steps_per_epoch, callbacks=[MetricLogger()], + validation_steps=training_cfg.num_eval_steps, validation_freq=1, - epochs=epochs, ) logger.info("Training finished") From 5960be146e3b5fe82f14f629e2fbbab7a347620f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 22 Oct 2025 23:50:57 +0530 Subject: [PATCH 098/279] Fix --- examples/ml_perf/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 72d5900a..9b6356fe 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -150,7 +150,7 @@ def main( large_emb_features=large_emb_features, small_emb_features=small_emb_features, label=ds_cfg.label, - steps=steps_per_epoch, + num_steps=steps_per_epoch, training=True, ).create_dataset( process_id=distribution._process_id, @@ -166,7 +166,7 @@ def main( large_emb_features=large_emb_features, small_emb_features=small_emb_features, label=ds_cfg.label, - steps=training_cfg.num_eval_steps, + num_steps=training_cfg.num_eval_steps, training=False, ).create_dataset( process_id=distribution._process_id, From 14fcc7c828d5783934064bd6539c704ed0d9c51d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 00:02:31 +0530 Subject: [PATCH 099/279] Allow conditional data repetition --- examples/ml_perf/dataloader.py | 5 ++++- examples/ml_perf/main.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 9aa238ef..49708cd3 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -19,6 +19,7 @@ def __init__( small_emb_features, label, num_steps, + repeat=False, training=False, ): passed_args = locals() @@ -33,6 +34,7 @@ def __init__( self.small_emb_features = small_emb_features self.label = label self.num_steps = num_steps + self.repeat = repeat self.training = training # Derived attributes. @@ -218,7 +220,8 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Repeat the dataset infinite number of times so that the generator # does not run out. - dataset = dataset.repeat() + if self.repeat: + dataset = dataset.repeat() dataset = dataset.prefetch(tf.data.AUTOTUNE) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 9b6356fe..b17f37ba 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -167,6 +167,7 @@ def main( small_emb_features=small_emb_features, label=ds_cfg.label, num_steps=training_cfg.num_eval_steps, + repeat=True, training=False, ).create_dataset( process_id=distribution._process_id, From 8c9797a33d5d275d23f95374ffe91f7ffe21dd50 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 00:25:15 +0530 Subject: [PATCH 100/279] Temporarily remove validation --- examples/ml_perf/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index b17f37ba..7457439e 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -219,12 +219,12 @@ def generator(dataset, training=False): logger.info("Training...") model.fit( train_generator, - validation_data=eval_generator, + # validation_data=eval_generator, epochs=epochs, steps_per_epoch=steps_per_epoch, callbacks=[MetricLogger()], - validation_steps=training_cfg.num_eval_steps, - validation_freq=1, + # validation_steps=training_cfg.num_eval_steps, + # validation_freq=1, ) logger.info("Training finished") From 37e60feb5cff943284a762b1f3a9482d6ff05bcf Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 01:59:40 +0530 Subject: [PATCH 101/279] Uncomment stat updates --- .../embedding/jax/distributed_embedding.py | 94 +++++++++---------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index a6290da8..d6487a2a 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -595,53 +595,53 @@ def _sparsecore_preprocess( num_sc_per_device, ) - # if training: - # # Synchronize input statistics across all devices and update the - # # underlying stacked tables specs in the feature specs. - - # # Aggregate stats across all processes/devices via pmax. - # all_stats = multihost_utils.process_allgather(stats) - # aggregated_stats = jax.tree.map( - # lambda x: jnp.max(x, axis=0), all_stats - # ) - - # # Check if stats changed enough to warrant action. - # stacked_table_specs = embedding.get_stacked_table_specs( - # self._config.feature_specs - # ) - # changed = any( - # np.max(aggregated_stats.max_ids_per_partition[stack_name]) - # > spec.max_ids_per_partition - # or np.max( - # aggregated_stats.max_unique_ids_per_partition[stack_name] - # ) - # > spec.max_unique_ids_per_partition - # or ( - # np.max( - # aggregated_stats.required_buffer_size_per_sc[stack_name] - # ) - # * num_sc_per_device - # ) - # > (spec.suggested_coo_buffer_size_per_device or 0) - # for stack_name, spec in stacked_table_specs.items() - # ) - - # # Update configuration and repeat preprocessing if stats changed. - # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + + # Aggregate stats across all processes/devices via pmax. + all_stats = multihost_utils.process_allgather(stats) + aggregated_stats = jax.tree.map( + lambda x: jnp.max(x, axis=0), all_stats + ) + + # Check if stats changed enough to warrant action. + stacked_table_specs = embedding.get_stacked_table_specs( + self._config.feature_specs + ) + changed = any( + np.max(aggregated_stats.max_ids_per_partition[stack_name]) + > spec.max_ids_per_partition + or np.max( + aggregated_stats.max_unique_ids_per_partition[stack_name] + ) + > spec.max_unique_ids_per_partition + or ( + np.max( + aggregated_stats.required_buffer_size_per_sc[stack_name] + ) + * num_sc_per_device + ) + > (spec.suggested_coo_buffer_size_per_device or 0) + for stack_name, spec in stacked_table_specs.items() + ) + + # Update configuration and repeat preprocessing if stats changed. + if changed: + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) + + # Re-execute preprocessing with consistent input statistics. + preprocessed, _ = embedding_utils.stack_and_shard_samples( + self._config.feature_specs, + samples, + local_device_count, + global_device_count, + num_sc_per_device, + ) return {"inputs": preprocessed} From 03512dd3e950ae234c4241adf3c222b795b07cc3 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 02:02:26 +0530 Subject: [PATCH 102/279] Uncomment stat updates --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index d6487a2a..bd69a0be 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -9,6 +9,7 @@ import numpy as np from jax import numpy as jnp from jax.experimental import layout as jax_layout +from jax.experimental import multihost_utils from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec from jax_tpu_embedding.sparsecore.lib.nn import ( From 9dba0fc35f484adf60dcfe4628d39b41a9d75ff1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 02:10:42 +0530 Subject: [PATCH 103/279] Debug --- examples/ml_perf/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 7457439e..949f2963 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -222,9 +222,10 @@ def generator(dataset, training=False): # validation_data=eval_generator, epochs=epochs, steps_per_epoch=steps_per_epoch, - callbacks=[MetricLogger()], + # callbacks=[MetricLogger()], # validation_steps=training_cfg.num_eval_steps, # validation_freq=1, + verbose=0, ) logger.info("Training finished") From 61d3b1aba757e959521e844ed027bdeb0ceffe62 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 02:20:54 +0530 Subject: [PATCH 104/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 949f2963..4c956240 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -225,7 +225,7 @@ def generator(dataset, training=False): # callbacks=[MetricLogger()], # validation_steps=training_cfg.num_eval_steps, # validation_freq=1, - verbose=0, + # verbose=0, ) logger.info("Training finished") From a58a5096c42af75e3c714ad8002b5ae33214f5dc Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 02:27:17 +0530 Subject: [PATCH 105/279] Debug --- examples/ml_perf/main.py | 2 +- .../embedding/jax/distributed_embedding.py | 94 +++++++++---------- 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 4c956240..b27dddec 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -222,7 +222,7 @@ def generator(dataset, training=False): # validation_data=eval_generator, epochs=epochs, steps_per_epoch=steps_per_epoch, - # callbacks=[MetricLogger()], + callbacks=[MetricLogger()], # validation_steps=training_cfg.num_eval_steps, # validation_freq=1, # verbose=0, diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index bd69a0be..4bee8fc0 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -596,53 +596,53 @@ def _sparsecore_preprocess( num_sc_per_device, ) - if training: - # Synchronize input statistics across all devices and update the - # underlying stacked tables specs in the feature specs. - - # Aggregate stats across all processes/devices via pmax. - all_stats = multihost_utils.process_allgather(stats) - aggregated_stats = jax.tree.map( - lambda x: jnp.max(x, axis=0), all_stats - ) - - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(aggregated_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max( - aggregated_stats.max_unique_ids_per_partition[stack_name] - ) - > spec.max_unique_ids_per_partition - or ( - np.max( - aggregated_stats.required_buffer_size_per_sc[stack_name] - ) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # Re-execute preprocessing with consistent input statistics. - preprocessed, _ = embedding_utils.stack_and_shard_samples( - self._config.feature_specs, - samples, - local_device_count, - global_device_count, - num_sc_per_device, - ) + # if training: + # # Synchronize input statistics across all devices and update the + # # underlying stacked tables specs in the feature specs. + + # # Aggregate stats across all processes/devices via pmax. + # all_stats = multihost_utils.process_allgather(stats) + # aggregated_stats = jax.tree.map( + # lambda x: jnp.max(x, axis=0), all_stats + # ) + + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(aggregated_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max( + # aggregated_stats.max_unique_ids_per_partition[stack_name] + # ) + # > spec.max_unique_ids_per_partition + # or ( + # np.max( + # aggregated_stats.required_buffer_size_per_sc[stack_name] + # ) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 4a4abd38c7a822db6ea5676d176b4a8d90af2c9a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 02:36:11 +0530 Subject: [PATCH 106/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 55bcd8e2..3e6acb33 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -202,7 +202,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 +training_config.global_batch_size = 8448 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 19d82fdf64d6687a16f179322381c08b224909af Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 07:23:01 +0530 Subject: [PATCH 107/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 3e6acb33..55bcd8e2 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -202,7 +202,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 8448 +training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 2b32be0082b3b69ae9996ba9c4e4f9c9cf5c90b6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 07:27:49 +0530 Subject: [PATCH 108/279] Debug --- .../embedding/jax/distributed_embedding.py | 94 +++++++++---------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 4bee8fc0..bd69a0be 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -596,53 +596,53 @@ def _sparsecore_preprocess( num_sc_per_device, ) - # if training: - # # Synchronize input statistics across all devices and update the - # # underlying stacked tables specs in the feature specs. - - # # Aggregate stats across all processes/devices via pmax. - # all_stats = multihost_utils.process_allgather(stats) - # aggregated_stats = jax.tree.map( - # lambda x: jnp.max(x, axis=0), all_stats - # ) - - # # Check if stats changed enough to warrant action. - # stacked_table_specs = embedding.get_stacked_table_specs( - # self._config.feature_specs - # ) - # changed = any( - # np.max(aggregated_stats.max_ids_per_partition[stack_name]) - # > spec.max_ids_per_partition - # or np.max( - # aggregated_stats.max_unique_ids_per_partition[stack_name] - # ) - # > spec.max_unique_ids_per_partition - # or ( - # np.max( - # aggregated_stats.required_buffer_size_per_sc[stack_name] - # ) - # * num_sc_per_device - # ) - # > (spec.suggested_coo_buffer_size_per_device or 0) - # for stack_name, spec in stacked_table_specs.items() - # ) - - # # Update configuration and repeat preprocessing if stats changed. - # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + + # Aggregate stats across all processes/devices via pmax. + all_stats = multihost_utils.process_allgather(stats) + aggregated_stats = jax.tree.map( + lambda x: jnp.max(x, axis=0), all_stats + ) + + # Check if stats changed enough to warrant action. + stacked_table_specs = embedding.get_stacked_table_specs( + self._config.feature_specs + ) + changed = any( + np.max(aggregated_stats.max_ids_per_partition[stack_name]) + > spec.max_ids_per_partition + or np.max( + aggregated_stats.max_unique_ids_per_partition[stack_name] + ) + > spec.max_unique_ids_per_partition + or ( + np.max( + aggregated_stats.required_buffer_size_per_sc[stack_name] + ) + * num_sc_per_device + ) + > (spec.suggested_coo_buffer_size_per_device or 0) + for stack_name, spec in stacked_table_specs.items() + ) + + # Update configuration and repeat preprocessing if stats changed. + if changed: + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) + + # Re-execute preprocessing with consistent input statistics. + preprocessed, _ = embedding_utils.stack_and_shard_samples( + self._config.feature_specs, + samples, + local_device_count, + global_device_count, + num_sc_per_device, + ) return {"inputs": preprocessed} From 2c56aa7709ba8af0363eb860b223b89305dc09d6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 07:31:35 +0530 Subject: [PATCH 109/279] Debug --- examples/ml_perf/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index b27dddec..949f2963 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -222,10 +222,10 @@ def generator(dataset, training=False): # validation_data=eval_generator, epochs=epochs, steps_per_epoch=steps_per_epoch, - callbacks=[MetricLogger()], + # callbacks=[MetricLogger()], # validation_steps=training_cfg.num_eval_steps, # validation_freq=1, - # verbose=0, + verbose=0, ) logger.info("Training finished") From 6815d8bf5fd4313fd3b87fab3b204deaffae2282 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 23 Oct 2025 10:21:23 +0530 Subject: [PATCH 110/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 949f2963..8c13388b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -232,7 +232,7 @@ def generator(dataset, training=False): if __name__ == "__main__": logging.basicConfig( - level=logging.INFO, + level=logging.DEBUG, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s", datefmt="%H:%M:%S", ) From 86299be4a84b265b3f5ca779078e794e0f66967b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 24 Oct 2025 15:22:18 +0530 Subject: [PATCH 111/279] Debug --- .../embedding/jax/distributed_embedding.py | 82 +++++++++---------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index bd69a0be..068a1606 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -602,47 +602,47 @@ def _sparsecore_preprocess( # Aggregate stats across all processes/devices via pmax. all_stats = multihost_utils.process_allgather(stats) - aggregated_stats = jax.tree.map( - lambda x: jnp.max(x, axis=0), all_stats - ) - - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(aggregated_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max( - aggregated_stats.max_unique_ids_per_partition[stack_name] - ) - > spec.max_unique_ids_per_partition - or ( - np.max( - aggregated_stats.required_buffer_size_per_sc[stack_name] - ) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # Re-execute preprocessing with consistent input statistics. - preprocessed, _ = embedding_utils.stack_and_shard_samples( - self._config.feature_specs, - samples, - local_device_count, - global_device_count, - num_sc_per_device, - ) + # aggregated_stats = jax.tree.map( + # lambda x: jnp.max(x, axis=0), all_stats + # ) + + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(aggregated_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max( + # aggregated_stats.max_unique_ids_per_partition[stack_name] + # ) + # > spec.max_unique_ids_per_partition + # or ( + # np.max( + # aggregated_stats.required_buffer_size_per_sc[stack_name] + # ) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 1f98414ee94acf6aacf185499fd90f2c6a2942a1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 24 Oct 2025 15:34:54 +0530 Subject: [PATCH 112/279] Debug --- .../embedding/jax/distributed_embedding.py | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 068a1606..3775d07b 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -602,47 +602,47 @@ def _sparsecore_preprocess( # Aggregate stats across all processes/devices via pmax. all_stats = multihost_utils.process_allgather(stats) - # aggregated_stats = jax.tree.map( - # lambda x: jnp.max(x, axis=0), all_stats - # ) - - # # Check if stats changed enough to warrant action. - # stacked_table_specs = embedding.get_stacked_table_specs( - # self._config.feature_specs - # ) - # changed = any( - # np.max(aggregated_stats.max_ids_per_partition[stack_name]) - # > spec.max_ids_per_partition - # or np.max( - # aggregated_stats.max_unique_ids_per_partition[stack_name] - # ) - # > spec.max_unique_ids_per_partition - # or ( - # np.max( - # aggregated_stats.required_buffer_size_per_sc[stack_name] - # ) - # * num_sc_per_device - # ) - # > (spec.suggested_coo_buffer_size_per_device or 0) - # for stack_name, spec in stacked_table_specs.items() - # ) - - # # Update configuration and repeat preprocessing if stats changed. + aggregated_stats = jax.tree.map( + lambda x: jnp.max(x, axis=0), all_stats + ) + + # Check if stats changed enough to warrant action. + stacked_table_specs = embedding.get_stacked_table_specs( + self._config.feature_specs + ) + changed = any( + np.max(aggregated_stats.max_ids_per_partition[stack_name]) + > spec.max_ids_per_partition + or np.max( + aggregated_stats.max_unique_ids_per_partition[stack_name] + ) + > spec.max_unique_ids_per_partition + or ( + np.max( + aggregated_stats.required_buffer_size_per_sc[stack_name] + ) + * num_sc_per_device + ) + > (spec.suggested_coo_buffer_size_per_device or 0) + for stack_name, spec in stacked_table_specs.items() + ) + + # Update configuration and repeat preprocessing if stats changed. # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) + + # Re-execute preprocessing with consistent input statistics. + preprocessed, _ = embedding_utils.stack_and_shard_samples( + self._config.feature_specs, + samples, + local_device_count, + global_device_count, + num_sc_per_device, + ) return {"inputs": preprocessed} From 1b8645035e2dac8ec4547ad97b946e7eb8de2f8d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 24 Oct 2025 15:43:58 +0530 Subject: [PATCH 113/279] Debug --- .../embedding/jax/distributed_embedding.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 3775d07b..1629e291 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -627,22 +627,22 @@ def _sparsecore_preprocess( for stack_name, spec in stacked_table_specs.items() ) - # Update configuration and repeat preprocessing if stats changed. - # if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # Re-execute preprocessing with consistent input statistics. - preprocessed, _ = embedding_utils.stack_and_shard_samples( - self._config.feature_specs, - samples, - local_device_count, - global_device_count, - num_sc_per_device, - ) + # # Update configuration and repeat preprocessing if stats changed. + # # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From a504a2e2ab98eb8325d80a9aa7ed2ec234506b0f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 24 Oct 2025 15:48:49 +0530 Subject: [PATCH 114/279] Debug --- .../layers/embedding/jax/distributed_embedding.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 1629e291..2b9f1497 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -627,13 +627,13 @@ def _sparsecore_preprocess( for stack_name, spec in stacked_table_specs.items() ) - # # Update configuration and repeat preprocessing if stats changed. - # # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) + # Update configuration and repeat preprocessing if stats changed. + # if changed: + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) # # Re-execute preprocessing with consistent input statistics. # preprocessed, _ = embedding_utils.stack_and_shard_samples( From e7451f0efe50ff6901576c9f8537a2dca9fb7cd3 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 24 Oct 2025 16:04:11 +0530 Subject: [PATCH 115/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 8c13388b..949f2963 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -232,7 +232,7 @@ def generator(dataset, training=False): if __name__ == "__main__": logging.basicConfig( - level=logging.DEBUG, + level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s", datefmt="%H:%M:%S", ) From 81410864c77eedd71fd1dc27c0016fec9c63ae7e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 24 Oct 2025 16:20:34 +0530 Subject: [PATCH 116/279] Debug --- .../embedding/jax/distributed_embedding.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 2b9f1497..3775d07b 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -635,14 +635,14 @@ def _sparsecore_preprocess( num_sc_per_device, ) - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + # Re-execute preprocessing with consistent input statistics. + preprocessed, _ = embedding_utils.stack_and_shard_samples( + self._config.feature_specs, + samples, + local_device_count, + global_device_count, + num_sc_per_device, + ) return {"inputs": preprocessed} From 4a700dc373e63ba5a01f6fc911365d71869729f7 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 24 Oct 2025 16:23:52 +0530 Subject: [PATCH 117/279] Debug --- .../embedding/jax/distributed_embedding.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 3775d07b..1629e291 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -627,22 +627,22 @@ def _sparsecore_preprocess( for stack_name, spec in stacked_table_specs.items() ) - # Update configuration and repeat preprocessing if stats changed. - # if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # Re-execute preprocessing with consistent input statistics. - preprocessed, _ = embedding_utils.stack_and_shard_samples( - self._config.feature_specs, - samples, - local_device_count, - global_device_count, - num_sc_per_device, - ) + # # Update configuration and repeat preprocessing if stats changed. + # # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 8ef3782cbd6351b9d099e8adf43cc56edc8b3b2e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 24 Oct 2025 16:27:35 +0530 Subject: [PATCH 118/279] Debug --- .../layers/embedding/jax/distributed_embedding.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 1629e291..c6d28799 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -628,12 +628,12 @@ def _sparsecore_preprocess( ) # # Update configuration and repeat preprocessing if stats changed. - # # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) + if changed: + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) # # Re-execute preprocessing with consistent input statistics. # preprocessed, _ = embedding_utils.stack_and_shard_samples( From 17d5ec325e665896e081315db96de07a7492bdb9 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 00:25:33 +0530 Subject: [PATCH 119/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 949f2963..4c956240 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -225,7 +225,7 @@ def generator(dataset, training=False): # callbacks=[MetricLogger()], # validation_steps=training_cfg.num_eval_steps, # validation_freq=1, - verbose=0, + # verbose=0, ) logger.info("Training finished") From c2a6da2040461268370f37416530aaaa5bcf1f40 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 00:30:55 +0530 Subject: [PATCH 120/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 4c956240..2aa4fd5c 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -150,7 +150,7 @@ def main( large_emb_features=large_emb_features, small_emb_features=small_emb_features, label=ds_cfg.label, - num_steps=steps_per_epoch, + num_steps=steps_per_epoch + 20, training=True, ).create_dataset( process_id=distribution._process_id, From 9614833cd722e1469cbb0fd9cad3b2678a8e194e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 00:33:59 +0530 Subject: [PATCH 121/279] Debug --- examples/ml_perf/main.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2aa4fd5c..21cc0497 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -133,14 +133,18 @@ def main( # Keras does not have a straightforward way to log at a step-level instead # of epoch-level. So, we do a workaround here. - if ds_cfg.val_file_pattern: - steps_per_epoch = training_cfg.eval_freq - epochs = training_cfg.num_steps // training_cfg.eval_freq - do_eval = True - else: - steps_per_epoch = training_cfg.num_steps - epochs = 1 - do_eval = False + # if ds_cfg.val_file_pattern: + # steps_per_epoch = training_cfg.eval_freq + # epochs = training_cfg.num_steps // training_cfg.eval_freq + # do_eval = True + # else: + # steps_per_epoch = training_cfg.num_steps + # epochs = 1 + # do_eval = False + + steps_per_epoch = training_cfg.num_steps + epochs = 1 + do_eval = False train_ds = DataLoader( file_pattern=ds_cfg.file_pattern, From 90592c710d46f2479eca57ae8bb7a076860675e5 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 00:36:13 +0530 Subject: [PATCH 122/279] Debug --- examples/ml_perf/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 21cc0497..4677872b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -182,7 +182,8 @@ def main( # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. if num_processes > 1: train_ds = distribution.distribute_dataset(train_ds) - eval_ds = distribution.distribute_dataset(eval_ds) + if do_eval: + eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False def generator(dataset, training=False): From 68c320a5d503b2c38fa2ce450a45ef0dfec1ace8 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 07:40:31 +0530 Subject: [PATCH 123/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 5 +- examples/ml_perf/main.py | 20 ++-- .../embedding/jax/distributed_embedding.py | 94 +++++++++---------- 3 files changed, 56 insertions(+), 63 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 55bcd8e2..c92bfb79 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -12,10 +12,7 @@ "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" "train-0000[0-3]-of-01024tfrecord" ) -dataset_config.val_file_pattern = ( - "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" - "train-0000[0-3]-of-01024tfrecord" -) +dataset_config.val_file_pattern = None # The path which we are reading from already has the batched dataset. dataset_config.file_batch_size = 4224 diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 4677872b..6156c3f3 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -133,18 +133,14 @@ def main( # Keras does not have a straightforward way to log at a step-level instead # of epoch-level. So, we do a workaround here. - # if ds_cfg.val_file_pattern: - # steps_per_epoch = training_cfg.eval_freq - # epochs = training_cfg.num_steps // training_cfg.eval_freq - # do_eval = True - # else: - # steps_per_epoch = training_cfg.num_steps - # epochs = 1 - # do_eval = False - - steps_per_epoch = training_cfg.num_steps - epochs = 1 - do_eval = False + if ds_cfg.val_file_pattern: + steps_per_epoch = training_cfg.eval_freq + epochs = training_cfg.num_steps // training_cfg.eval_freq + do_eval = True + else: + steps_per_epoch = training_cfg.num_steps + epochs = 1 + do_eval = False train_ds = DataLoader( file_pattern=ds_cfg.file_pattern, diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index c6d28799..d57565c5 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -596,53 +596,53 @@ def _sparsecore_preprocess( num_sc_per_device, ) - if training: - # Synchronize input statistics across all devices and update the - # underlying stacked tables specs in the feature specs. - - # Aggregate stats across all processes/devices via pmax. - all_stats = multihost_utils.process_allgather(stats) - aggregated_stats = jax.tree.map( - lambda x: jnp.max(x, axis=0), all_stats - ) - - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(aggregated_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max( - aggregated_stats.max_unique_ids_per_partition[stack_name] - ) - > spec.max_unique_ids_per_partition - or ( - np.max( - aggregated_stats.required_buffer_size_per_sc[stack_name] - ) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + # if training: + # # Synchronize input statistics across all devices and update the + # # underlying stacked tables specs in the feature specs. + + # # Aggregate stats across all processes/devices via pmax. + # all_stats = multihost_utils.process_allgather(stats) + # aggregated_stats = jax.tree.map( + # lambda x: jnp.max(x, axis=0), all_stats + # ) + + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(aggregated_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max( + # aggregated_stats.max_unique_ids_per_partition[stack_name] + # ) + # > spec.max_unique_ids_per_partition + # or ( + # np.max( + # aggregated_stats.required_buffer_size_per_sc[stack_name] + # ) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 079499aed4e8065f3e67ce29080f198b3918bfd1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 07:54:57 +0530 Subject: [PATCH 124/279] Debug --- .../embedding/jax/distributed_embedding.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index d57565c5..4ffd67f6 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -596,15 +596,15 @@ def _sparsecore_preprocess( num_sc_per_device, ) - # if training: - # # Synchronize input statistics across all devices and update the - # # underlying stacked tables specs in the feature specs. - - # # Aggregate stats across all processes/devices via pmax. - # all_stats = multihost_utils.process_allgather(stats) - # aggregated_stats = jax.tree.map( - # lambda x: jnp.max(x, axis=0), all_stats - # ) + if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + + # Aggregate stats across all processes/devices via pmax. + all_stats = multihost_utils.process_allgather(stats) + aggregated_stats = jax.tree.map( + lambda x: jnp.max(x, axis=0), all_stats + ) # # Check if stats changed enough to warrant action. # stacked_table_specs = embedding.get_stacked_table_specs( From 1de870bcd1f9a860f261ba2a6f8bbba0a761d0c0 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 08:03:20 +0530 Subject: [PATCH 125/279] Debug --- .../embedding/jax/distributed_embedding.py | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 4ffd67f6..f274585a 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -606,43 +606,43 @@ def _sparsecore_preprocess( lambda x: jnp.max(x, axis=0), all_stats ) - # # Check if stats changed enough to warrant action. - # stacked_table_specs = embedding.get_stacked_table_specs( - # self._config.feature_specs - # ) - # changed = any( - # np.max(aggregated_stats.max_ids_per_partition[stack_name]) - # > spec.max_ids_per_partition - # or np.max( - # aggregated_stats.max_unique_ids_per_partition[stack_name] - # ) - # > spec.max_unique_ids_per_partition - # or ( - # np.max( - # aggregated_stats.required_buffer_size_per_sc[stack_name] - # ) - # * num_sc_per_device - # ) - # > (spec.suggested_coo_buffer_size_per_device or 0) - # for stack_name, spec in stacked_table_specs.items() - # ) - - # # # Update configuration and repeat preprocessing if stats changed. - # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + # Check if stats changed enough to warrant action. + stacked_table_specs = embedding.get_stacked_table_specs( + self._config.feature_specs + ) + changed = any( + np.max(aggregated_stats.max_ids_per_partition[stack_name]) + > spec.max_ids_per_partition + or np.max( + aggregated_stats.max_unique_ids_per_partition[stack_name] + ) + > spec.max_unique_ids_per_partition + or ( + np.max( + aggregated_stats.required_buffer_size_per_sc[stack_name] + ) + * num_sc_per_device + ) + > (spec.suggested_coo_buffer_size_per_device or 0) + for stack_name, spec in stacked_table_specs.items() + ) + + # # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 7251303a3dbe06ae2187cadc18fcdef32d16b01c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 08:11:08 +0530 Subject: [PATCH 126/279] Debug --- .../embedding/jax/distributed_embedding.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index f274585a..3775d07b 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -627,22 +627,22 @@ def _sparsecore_preprocess( for stack_name, spec in stacked_table_specs.items() ) - # # # Update configuration and repeat preprocessing if stats changed. + # Update configuration and repeat preprocessing if stats changed. # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) + + # Re-execute preprocessing with consistent input statistics. + preprocessed, _ = embedding_utils.stack_and_shard_samples( + self._config.feature_specs, + samples, + local_device_count, + global_device_count, + num_sc_per_device, + ) return {"inputs": preprocessed} From 1fabedb2157aaa62b04cdce85bb9be805edef8d1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 08:15:16 +0530 Subject: [PATCH 127/279] Debug --- .../embedding/jax/distributed_embedding.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 3775d07b..09535a6b 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -628,21 +628,21 @@ def _sparsecore_preprocess( ) # Update configuration and repeat preprocessing if stats changed. - # if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) + if changed: + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) - # Re-execute preprocessing with consistent input statistics. - preprocessed, _ = embedding_utils.stack_and_shard_samples( - self._config.feature_specs, - samples, - local_device_count, - global_device_count, - num_sc_per_device, - ) + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 8e881279e9465a9d2aef4e771aa80433833159aa Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sat, 25 Oct 2025 08:22:53 +0530 Subject: [PATCH 128/279] Debug --- .../embedding/jax/distributed_embedding.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 09535a6b..bd69a0be 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -635,14 +635,14 @@ def _sparsecore_preprocess( num_sc_per_device, ) - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + # Re-execute preprocessing with consistent input statistics. + preprocessed, _ = embedding_utils.stack_and_shard_samples( + self._config.feature_specs, + samples, + local_device_count, + global_device_count, + num_sc_per_device, + ) return {"inputs": preprocessed} From 2ac44973e107ff51f87144a831473037f24f3b7e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 02:41:24 +0530 Subject: [PATCH 129/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index bd69a0be..e85cb57c 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -601,10 +601,7 @@ def _sparsecore_preprocess( # underlying stacked tables specs in the feature specs. # Aggregate stats across all processes/devices via pmax. - all_stats = multihost_utils.process_allgather(stats) - aggregated_stats = jax.tree.map( - lambda x: jnp.max(x, axis=0), all_stats - ) + aggregated_stats = multihost_utils.process_allgather(stats) # Check if stats changed enough to warrant action. stacked_table_specs = embedding.get_stacked_table_specs( From 1816c914a68faa0a922ad6f812b8889476254774 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:00:21 +0530 Subject: [PATCH 130/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index e85cb57c..bd69a0be 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -601,7 +601,10 @@ def _sparsecore_preprocess( # underlying stacked tables specs in the feature specs. # Aggregate stats across all processes/devices via pmax. - aggregated_stats = multihost_utils.process_allgather(stats) + all_stats = multihost_utils.process_allgather(stats) + aggregated_stats = jax.tree.map( + lambda x: jnp.max(x, axis=0), all_stats + ) # Check if stats changed enough to warrant action. stacked_table_specs = embedding.get_stacked_table_specs( From 1f2a0f593de894a78b3b892bd4397f27f1fec8c1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:00:49 +0530 Subject: [PATCH 131/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index bd69a0be..8a79c881 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -602,9 +602,11 @@ def _sparsecore_preprocess( # Aggregate stats across all processes/devices via pmax. all_stats = multihost_utils.process_allgather(stats) + print("---> all_stats", all_stats) aggregated_stats = jax.tree.map( lambda x: jnp.max(x, axis=0), all_stats ) + print("---> aggregated_stats", aggregated_stats) # Check if stats changed enough to warrant action. stacked_table_specs = embedding.get_stacked_table_specs( From f74e4b53c292598ebc88ea6de3c86f8388be2c75 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:21:11 +0530 Subject: [PATCH 132/279] Debug --- .../embedding/jax/distributed_embedding.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 8a79c881..80cf33e1 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -193,6 +193,27 @@ def __call__( class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding): """JAX implementation of the TPU embedding layer.""" + def __init__(self, **kwargs: Any): + # Pull out `auto_stack_kwargs` from `kwargs`. + auto_stack_kwargs = kwargs.pop("auto_stack_kwargs", {}) + super().__init__(**kwargs) + + # For `max_ids_per_partition` and `max_unique_ids_per_partition`, JTE's + # `auto_stack_tables` expects callables. + + def _get_max_ids_per_partition(name: str, batch_size: int) -> int: + return auto_stack_kwargs["max_ids_per_partition"] + + def _get_max_unique_ids_per_partition(name: str, batch_size: int) -> int: + return auto_stack_kwargs["max_unique_ids_per_partition"] + + if "max_ids_per_partition" in auto_stack_kwargs: + auto_stack_kwargs["stack_to_max_ids_per_partition"] = _get_max_ids_per_partition + if "max_unique_ids_per_partition" in auto_stack_kwargs: + auto_stack_kwargs["stack_to_max_unique_ids_per_partition"] = _get_max_unique_ids_per_partition + + self._auto_stack_kwargs = auto_stack_kwargs + def _create_sparsecore_distribution( self, sparsecore_axis_name: str = "sparsecore" ) -> tuple[ @@ -402,7 +423,10 @@ def sparsecore_build( if isinstance(table_stacking, str): if table_stacking == "auto": jte_table_stacking.auto_stack_tables( - feature_specs, global_device_count, num_sc_per_device + feature_specs, + global_device_count, + num_sc_per_device, + **self._auto_stack_kwargs, ) else: raise ValueError( From 3e78c47155cd146ce5ed5e1fcd1274976cf55ef0 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:28:56 +0530 Subject: [PATCH 133/279] Debug --- examples/ml_perf/main.py | 4 ++++ examples/ml_perf/model.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 6156c3f3..492e99bf 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -115,6 +115,10 @@ def main( top_mlp_dims=model_cfg.top_mlp_dims, num_dcn_layers=model_cfg.num_dcn_layers, dcn_projection_dim=model_cfg.dcn_projection_dim, + auto_stack_kwargs={ + "max_ids_per_partition": max_ids_per_partition, + "max_unique_ids_per_partition": max_unique_ids_per_partition + }, seed=SEED, dtype="float32", name="dlrm_dcn_v2", diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 2e56576c..4e977ed6 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -49,6 +49,7 @@ def __init__( top_mlp_dims: list[int], num_dcn_layers: int, dcn_projection_dim: int, + auto_stack_kwargs: dict[str, Any], seed: int | keras.random.SeedGenerator | None = None, dtype: str | None = None, name: str | None = None, @@ -115,6 +116,7 @@ def __init__( self.embedding_layer = keras_rs.layers.DistributedEmbedding( feature_configs=large_emb_feature_configs, table_stacking="auto", + auto_stack_kwargs=auto_stack_kwargs, dtype=dtype, name="embedding_layer", ) @@ -167,6 +169,7 @@ def __init__( self.top_mlp_dims = top_mlp_dims self.num_dcn_layers = num_dcn_layers self.dcn_projection_dim = dcn_projection_dim + self.auto_stack_tables = auto_stack_kwargs def call(self, inputs: dict[str, Tensor]) -> Tensor: """Forward pass of the model. @@ -273,6 +276,7 @@ def get_config(self): "top_mlp_dims": self.top_mlp_dims, "num_dcn_layers": self.num_dcn_layers, "dcn_projection_dim": self.dcn_projection_dim, + "auto_stack_kwargs": auto_stack_kwargs, "seed": self.seed, } ) From 17a682564f8908fed344129139e817bb9489d644 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:31:15 +0530 Subject: [PATCH 134/279] Debug --- examples/ml_perf/main.py | 6 ++++-- examples/ml_perf/model.py | 4 ++-- .../layers/embedding/jax/distributed_embedding.py | 14 ++++++++++---- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 492e99bf..5d4fb6c4 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -116,8 +116,10 @@ def main( num_dcn_layers=model_cfg.num_dcn_layers, dcn_projection_dim=model_cfg.dcn_projection_dim, auto_stack_kwargs={ - "max_ids_per_partition": max_ids_per_partition, - "max_unique_ids_per_partition": max_unique_ids_per_partition + "max_ids_per_partition": model_cfg.max_ids_per_partition, + "max_unique_ids_per_partition": ( + model_cfg.max_unique_ids_per_partition + ), }, seed=SEED, dtype="float32", diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 4e977ed6..81ef05e4 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -169,7 +169,7 @@ def __init__( self.top_mlp_dims = top_mlp_dims self.num_dcn_layers = num_dcn_layers self.dcn_projection_dim = dcn_projection_dim - self.auto_stack_tables = auto_stack_kwargs + self.auto_stack_kwargs = auto_stack_kwargs def call(self, inputs: dict[str, Tensor]) -> Tensor: """Forward pass of the model. @@ -276,7 +276,7 @@ def get_config(self): "top_mlp_dims": self.top_mlp_dims, "num_dcn_layers": self.num_dcn_layers, "dcn_projection_dim": self.dcn_projection_dim, - "auto_stack_kwargs": auto_stack_kwargs, + "auto_stack_kwargs": self.auto_stack_kwargs, "seed": self.seed, } ) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 80cf33e1..6a0f088d 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -204,14 +204,20 @@ def __init__(self, **kwargs: Any): def _get_max_ids_per_partition(name: str, batch_size: int) -> int: return auto_stack_kwargs["max_ids_per_partition"] - def _get_max_unique_ids_per_partition(name: str, batch_size: int) -> int: + def _get_max_unique_ids_per_partition( + name: str, batch_size: int + ) -> int: return auto_stack_kwargs["max_unique_ids_per_partition"] if "max_ids_per_partition" in auto_stack_kwargs: - auto_stack_kwargs["stack_to_max_ids_per_partition"] = _get_max_ids_per_partition + auto_stack_kwargs["stack_to_max_ids_per_partition"] = ( + _get_max_ids_per_partition + ) if "max_unique_ids_per_partition" in auto_stack_kwargs: - auto_stack_kwargs["stack_to_max_unique_ids_per_partition"] = _get_max_unique_ids_per_partition - + auto_stack_kwargs["stack_to_max_unique_ids_per_partition"] = ( + _get_max_unique_ids_per_partition + ) + self._auto_stack_kwargs = auto_stack_kwargs def _create_sparsecore_distribution( From ea0414565da638563d77ca4c1df4acd2f481d909 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:36:02 +0530 Subject: [PATCH 135/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 6a0f088d..4b318012 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -196,11 +196,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding): def __init__(self, **kwargs: Any): # Pull out `auto_stack_kwargs` from `kwargs`. auto_stack_kwargs = kwargs.pop("auto_stack_kwargs", {}) - super().__init__(**kwargs) # For `max_ids_per_partition` and `max_unique_ids_per_partition`, JTE's # `auto_stack_tables` expects callables. - def _get_max_ids_per_partition(name: str, batch_size: int) -> int: return auto_stack_kwargs["max_ids_per_partition"] @@ -219,6 +217,7 @@ def _get_max_unique_ids_per_partition( ) self._auto_stack_kwargs = auto_stack_kwargs + super().__init__(**kwargs) def _create_sparsecore_distribution( self, sparsecore_axis_name: str = "sparsecore" From 181449769cd4b069b9514b20eedf2ba714fff335 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:40:31 +0530 Subject: [PATCH 136/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 4b318012..3bba9043 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -208,10 +208,12 @@ def _get_max_unique_ids_per_partition( return auto_stack_kwargs["max_unique_ids_per_partition"] if "max_ids_per_partition" in auto_stack_kwargs: + auto_stack_kwargs.pop("max_ids_per_partition") auto_stack_kwargs["stack_to_max_ids_per_partition"] = ( _get_max_ids_per_partition ) if "max_unique_ids_per_partition" in auto_stack_kwargs: + auto_stack_kwargs.pop("max_unique_ids_per_partition") auto_stack_kwargs["stack_to_max_unique_ids_per_partition"] = ( _get_max_unique_ids_per_partition ) From 7549ce376671c528517aa5cc0cc8edf0a44e4e1f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:44:52 +0530 Subject: [PATCH 137/279] Debug --- .../layers/embedding/jax/distributed_embedding.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 3bba9043..d040ed59 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -197,23 +197,28 @@ def __init__(self, **kwargs: Any): # Pull out `auto_stack_kwargs` from `kwargs`. auto_stack_kwargs = kwargs.pop("auto_stack_kwargs", {}) + auto_stack_max_ids_per_partition = auto_stack_kwargs.get( + "max_ids_per_partition", None + ) + auto_stack_max_unique_ids_per_partition = auto_stack_kwargs.get( + "max_unique_ids_per_partition", None + ) + # For `max_ids_per_partition` and `max_unique_ids_per_partition`, JTE's # `auto_stack_tables` expects callables. def _get_max_ids_per_partition(name: str, batch_size: int) -> int: - return auto_stack_kwargs["max_ids_per_partition"] + return auto_stack_max_ids_per_partition def _get_max_unique_ids_per_partition( name: str, batch_size: int ) -> int: - return auto_stack_kwargs["max_unique_ids_per_partition"] + return auto_stack_max_ids_per_partition - if "max_ids_per_partition" in auto_stack_kwargs: - auto_stack_kwargs.pop("max_ids_per_partition") + if auto_stack_max_ids_per_partition is not None: auto_stack_kwargs["stack_to_max_ids_per_partition"] = ( _get_max_ids_per_partition ) if "max_unique_ids_per_partition" in auto_stack_kwargs: - auto_stack_kwargs.pop("max_unique_ids_per_partition") auto_stack_kwargs["stack_to_max_unique_ids_per_partition"] = ( _get_max_unique_ids_per_partition ) From 521886e381428bad45338c12d3811aad67de0620 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:48:20 +0530 Subject: [PATCH 138/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index d040ed59..8ee8a56e 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -197,10 +197,10 @@ def __init__(self, **kwargs: Any): # Pull out `auto_stack_kwargs` from `kwargs`. auto_stack_kwargs = kwargs.pop("auto_stack_kwargs", {}) - auto_stack_max_ids_per_partition = auto_stack_kwargs.get( + auto_stack_max_ids_per_partition = auto_stack_kwargs.pop( "max_ids_per_partition", None ) - auto_stack_max_unique_ids_per_partition = auto_stack_kwargs.get( + auto_stack_max_unique_ids_per_partition = auto_stack_kwargs.pop( "max_unique_ids_per_partition", None ) @@ -224,6 +224,7 @@ def _get_max_unique_ids_per_partition( ) self._auto_stack_kwargs = auto_stack_kwargs + print(f"-------> {self._auto_stack_kwargs=}") super().__init__(**kwargs) def _create_sparsecore_distribution( From 3d26c36398fa0530adf2c641b5e2913ee1950a04 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 03:55:33 +0530 Subject: [PATCH 139/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 8ee8a56e..04d2b73b 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -212,7 +212,7 @@ def _get_max_ids_per_partition(name: str, batch_size: int) -> int: def _get_max_unique_ids_per_partition( name: str, batch_size: int ) -> int: - return auto_stack_max_ids_per_partition + return auto_stack_max_unique_ids_per_partition if auto_stack_max_ids_per_partition is not None: auto_stack_kwargs["stack_to_max_ids_per_partition"] = ( From 3c81da7c3c851bdb66278b731b58a69d138b6a1a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 07:54:42 +0530 Subject: [PATCH 140/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 04d2b73b..51cc4664 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -196,6 +196,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding): def __init__(self, **kwargs: Any): # Pull out `auto_stack_kwargs` from `kwargs`. auto_stack_kwargs = kwargs.pop("auto_stack_kwargs", {}) + print(f"{auto_stack_kwargs=}") auto_stack_max_ids_per_partition = auto_stack_kwargs.pop( "max_ids_per_partition", None From 5ec3caa80798988391967b063a5fcb5c0a5c90d7 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 07:57:36 +0530 Subject: [PATCH 141/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 51cc4664..06ee852a 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -219,7 +219,7 @@ def _get_max_unique_ids_per_partition( auto_stack_kwargs["stack_to_max_ids_per_partition"] = ( _get_max_ids_per_partition ) - if "max_unique_ids_per_partition" in auto_stack_kwargs: + if auto_stack_max_unique_ids_per_partition is not None: auto_stack_kwargs["stack_to_max_unique_ids_per_partition"] = ( _get_max_unique_ids_per_partition ) From b6e7a4879e2b0b68367dd772a622931a8052f230 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 08:01:34 +0530 Subject: [PATCH 142/279] Debug --- .../embedding/jax/distributed_embedding.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 06ee852a..90c3cbac 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -196,7 +196,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding): def __init__(self, **kwargs: Any): # Pull out `auto_stack_kwargs` from `kwargs`. auto_stack_kwargs = kwargs.pop("auto_stack_kwargs", {}) - print(f"{auto_stack_kwargs=}") auto_stack_max_ids_per_partition = auto_stack_kwargs.pop( "max_ids_per_partition", None @@ -225,7 +224,6 @@ def _get_max_unique_ids_per_partition( ) self._auto_stack_kwargs = auto_stack_kwargs - print(f"-------> {self._auto_stack_kwargs=}") super().__init__(**kwargs) def _create_sparsecore_distribution( @@ -675,14 +673,14 @@ def _sparsecore_preprocess( num_sc_per_device, ) - # Re-execute preprocessing with consistent input statistics. - preprocessed, _ = embedding_utils.stack_and_shard_samples( - self._config.feature_specs, - samples, - local_device_count, - global_device_count, - num_sc_per_device, - ) + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 0fa7497631f2d2c261db56d6d40d1485aba263a6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 08:07:36 +0530 Subject: [PATCH 143/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 90c3cbac..402f4b73 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -638,12 +638,9 @@ def _sparsecore_preprocess( # Aggregate stats across all processes/devices via pmax. all_stats = multihost_utils.process_allgather(stats) - print("---> all_stats", all_stats) aggregated_stats = jax.tree.map( - lambda x: jnp.max(x, axis=0), all_stats + lambda x: np.max(x, axis=0), all_stats ) - print("---> aggregated_stats", aggregated_stats) - # Check if stats changed enough to warrant action. stacked_table_specs = embedding.get_stacked_table_specs( self._config.feature_specs From c2d1a3bfcde66fd0e3e414c80b79dbb3a8e1dbb4 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 08:35:19 +0530 Subject: [PATCH 144/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 5d4fb6c4..d09d616c 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -156,7 +156,7 @@ def main( large_emb_features=large_emb_features, small_emb_features=small_emb_features, label=ds_cfg.label, - num_steps=steps_per_epoch + 20, + num_steps=steps_per_epoch + 2000, training=True, ).create_dataset( process_id=distribution._process_id, From dc9bfd87c24736e04a743e5358b75ddf80f84d7e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 13:15:04 +0530 Subject: [PATCH 145/279] Debug --- .../embedding/jax/distributed_embedding.py | 92 +++++++++---------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 402f4b73..1a7daac1 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -632,52 +632,52 @@ def _sparsecore_preprocess( num_sc_per_device, ) - if training: - # Synchronize input statistics across all devices and update the - # underlying stacked tables specs in the feature specs. - - # Aggregate stats across all processes/devices via pmax. - all_stats = multihost_utils.process_allgather(stats) - aggregated_stats = jax.tree.map( - lambda x: np.max(x, axis=0), all_stats - ) - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(aggregated_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max( - aggregated_stats.max_unique_ids_per_partition[stack_name] - ) - > spec.max_unique_ids_per_partition - or ( - np.max( - aggregated_stats.required_buffer_size_per_sc[stack_name] - ) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + # if training: + # # Synchronize input statistics across all devices and update the + # # underlying stacked tables specs in the feature specs. + + # # Aggregate stats across all processes/devices via pmax. + # all_stats = multihost_utils.process_allgather(stats) + # aggregated_stats = jax.tree.map( + # lambda x: np.max(x, axis=0), all_stats + # ) + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(aggregated_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max( + # aggregated_stats.max_unique_ids_per_partition[stack_name] + # ) + # > spec.max_unique_ids_per_partition + # or ( + # np.max( + # aggregated_stats.required_buffer_size_per_sc[stack_name] + # ) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # # Re-execute preprocessing with consistent input statistics. + # # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # # self._config.feature_specs, + # # samples, + # # local_device_count, + # # global_device_count, + # # num_sc_per_device, + # # ) return {"inputs": preprocessed} From 67c92041367cac4347f55efbf43ace82ce03ec6c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 13:36:58 +0530 Subject: [PATCH 146/279] Debug --- examples/ml_perf/main.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d09d616c..65775f5b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -137,16 +137,21 @@ def main( # === Load dataset === logger.info("Loading dataset...") + # Keras does not have a straightforward way to log at a step-level instead # of epoch-level. So, we do a workaround here. - if ds_cfg.val_file_pattern: - steps_per_epoch = training_cfg.eval_freq - epochs = training_cfg.num_steps // training_cfg.eval_freq - do_eval = True - else: - steps_per_epoch = training_cfg.num_steps - epochs = 1 - do_eval = False + # if ds_cfg.val_file_pattern: + # steps_per_epoch = training_cfg.eval_freq + # epochs = training_cfg.num_steps // training_cfg.eval_freq + # do_eval = True + # else: + # steps_per_epoch = training_cfg.num_steps + # epochs = 1 + # do_eval = False + + steps_per_epoch = training_cfg.num_steps + epochs = 2 + do_eval = False train_ds = DataLoader( file_pattern=ds_cfg.file_pattern, From ecab799c14069f507106efafe6cd476c50a84dde Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 14:11:23 +0530 Subject: [PATCH 147/279] Debug --- examples/ml_perf/main.py | 1 + .../embedding/jax/distributed_embedding.py | 95 ++++++++++--------- 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 65775f5b..8621c2c3 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -152,6 +152,7 @@ def main( steps_per_epoch = training_cfg.num_steps epochs = 2 do_eval = False + print(f"{steps_per_epoch=}, {epochs=}, {do_eval=}") train_ds = DataLoader( file_pattern=ds_cfg.file_pattern, diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 1a7daac1..f51f10c6 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -632,52 +632,55 @@ def _sparsecore_preprocess( num_sc_per_device, ) - # if training: - # # Synchronize input statistics across all devices and update the - # # underlying stacked tables specs in the feature specs. - - # # Aggregate stats across all processes/devices via pmax. - # all_stats = multihost_utils.process_allgather(stats) - # aggregated_stats = jax.tree.map( - # lambda x: np.max(x, axis=0), all_stats - # ) - # # Check if stats changed enough to warrant action. - # stacked_table_specs = embedding.get_stacked_table_specs( - # self._config.feature_specs - # ) - # changed = any( - # np.max(aggregated_stats.max_ids_per_partition[stack_name]) - # > spec.max_ids_per_partition - # or np.max( - # aggregated_stats.max_unique_ids_per_partition[stack_name] - # ) - # > spec.max_unique_ids_per_partition - # or ( - # np.max( - # aggregated_stats.required_buffer_size_per_sc[stack_name] - # ) - # * num_sc_per_device - # ) - # > (spec.suggested_coo_buffer_size_per_device or 0) - # for stack_name, spec in stacked_table_specs.items() - # ) - - # # Update configuration and repeat preprocessing if stats changed. - # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) - - # # # Re-execute preprocessing with consistent input statistics. - # # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # # self._config.feature_specs, - # # samples, - # # local_device_count, - # # global_device_count, - # # num_sc_per_device, - # # ) + if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + + # Aggregate stats across all processes/devices via pmax. + all_stats = multihost_utils.process_allgather(stats) + # print("### all_stats", all_stats) + # aggregated_stats = all_stats + aggregated_stats = jax.tree.map( + lambda x: jnp.max(x, axis=0), all_stats + ) + + # Check if stats changed enough to warrant action. + stacked_table_specs = embedding.get_stacked_table_specs( + self._config.feature_specs + ) + changed = any( + np.max(aggregated_stats.max_ids_per_partition[stack_name]) + > spec.max_ids_per_partition + or np.max( + aggregated_stats.max_unique_ids_per_partition[stack_name] + ) + > spec.max_unique_ids_per_partition + or ( + np.max( + aggregated_stats.required_buffer_size_per_sc[stack_name] + ) + * num_sc_per_device + ) + > (spec.suggested_coo_buffer_size_per_device or 0) + for stack_name, spec in stacked_table_specs.items() + ) + + # # Update configuration and repeat preprocessing if stats changed. + if changed: + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From e0a59cb5451bd6537d9c5e8151d37d6038fa33bb Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 14:54:26 +0530 Subject: [PATCH 148/279] Debug --- examples/ml_perf/main.py | 23 ++--- .../embedding/jax/distributed_embedding.py | 98 +++++++++---------- 2 files changed, 59 insertions(+), 62 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 8621c2c3..0dea4faf 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -140,19 +140,16 @@ def main( # Keras does not have a straightforward way to log at a step-level instead # of epoch-level. So, we do a workaround here. - # if ds_cfg.val_file_pattern: - # steps_per_epoch = training_cfg.eval_freq - # epochs = training_cfg.num_steps // training_cfg.eval_freq - # do_eval = True - # else: - # steps_per_epoch = training_cfg.num_steps - # epochs = 1 - # do_eval = False - - steps_per_epoch = training_cfg.num_steps - epochs = 2 - do_eval = False - print(f"{steps_per_epoch=}, {epochs=}, {do_eval=}") + if ds_cfg.val_file_pattern: + steps_per_epoch = training_cfg.eval_freq + epochs = training_cfg.num_steps // training_cfg.eval_freq + do_eval = True + else: + steps_per_epoch = training_cfg.num_steps + epochs = 1 + do_eval = False + + logger.info(f"{steps_per_epoch=}, {epochs=}, {do_eval=}") train_ds = DataLoader( file_pattern=ds_cfg.file_pattern, diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index f51f10c6..b98ce6b4 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -632,55 +632,55 @@ def _sparsecore_preprocess( num_sc_per_device, ) - if training: - # Synchronize input statistics across all devices and update the - # underlying stacked tables specs in the feature specs. - - # Aggregate stats across all processes/devices via pmax. - all_stats = multihost_utils.process_allgather(stats) - # print("### all_stats", all_stats) - # aggregated_stats = all_stats - aggregated_stats = jax.tree.map( - lambda x: jnp.max(x, axis=0), all_stats - ) - - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(aggregated_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max( - aggregated_stats.max_unique_ids_per_partition[stack_name] - ) - > spec.max_unique_ids_per_partition - or ( - np.max( - aggregated_stats.required_buffer_size_per_sc[stack_name] - ) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + # if training: + # # Synchronize input statistics across all devices and update the + # # underlying stacked tables specs in the feature specs. + + # # Aggregate stats across all processes/devices via pmax. + # all_stats = multihost_utils.process_allgather(stats) + # # print("### all_stats", all_stats) + # # aggregated_stats = all_stats + # aggregated_stats = jax.tree.map( + # lambda x: jnp.max(x, axis=0), all_stats + # ) + + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(aggregated_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max( + # aggregated_stats.max_unique_ids_per_partition[stack_name] + # ) + # > spec.max_unique_ids_per_partition + # or ( + # np.max( + # aggregated_stats.required_buffer_size_per_sc[stack_name] + # ) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From dcb1745b00a2ad562b778c1c33f5be97ddb57119 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Sun, 26 Oct 2025 14:57:59 +0530 Subject: [PATCH 149/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- examples/ml_perf/main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index c92bfb79..5f05751b 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -202,7 +202,7 @@ training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. -training_config.num_steps = 10 +training_config.num_steps = 1000 training_config.eval_freq = 5 training_config.num_eval_steps = 10 diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 0dea4faf..f6625e39 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -159,7 +159,7 @@ def main( large_emb_features=large_emb_features, small_emb_features=small_emb_features, label=ds_cfg.label, - num_steps=steps_per_epoch + 2000, + num_steps=steps_per_epoch, training=True, ).create_dataset( process_id=distribution._process_id, From 3d7f611d927942370c1bfb98d1135375690e0111 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 07:07:52 +0530 Subject: [PATCH 150/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 2 +- examples/ml_perf/main.py | 21 ++-- .../embedding/jax/distributed_embedding.py | 98 +++++++++---------- 3 files changed, 62 insertions(+), 59 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 5f05751b..c92bfb79 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -202,7 +202,7 @@ training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. -training_config.num_steps = 1000 +training_config.num_steps = 10 training_config.eval_freq = 5 training_config.num_eval_steps = 10 diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index f6625e39..e4fa985f 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -140,14 +140,17 @@ def main( # Keras does not have a straightforward way to log at a step-level instead # of epoch-level. So, we do a workaround here. - if ds_cfg.val_file_pattern: - steps_per_epoch = training_cfg.eval_freq - epochs = training_cfg.num_steps // training_cfg.eval_freq - do_eval = True - else: - steps_per_epoch = training_cfg.num_steps - epochs = 1 - do_eval = False + # if ds_cfg.val_file_pattern: + # steps_per_epoch = training_cfg.eval_freq + # epochs = training_cfg.num_steps // training_cfg.eval_freq + # do_eval = True + # else: + # steps_per_epoch = training_cfg.num_steps + # epochs = 1 + # do_eval = False + steps_per_epoch = training_cfg.num_steps + epochs = 2 + do_eval = False logger.info(f"{steps_per_epoch=}, {epochs=}, {do_eval=}") @@ -159,7 +162,7 @@ def main( large_emb_features=large_emb_features, small_emb_features=small_emb_features, label=ds_cfg.label, - num_steps=steps_per_epoch, + num_steps=steps_per_epoch + 20, training=True, ).create_dataset( process_id=distribution._process_id, diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index b98ce6b4..f51f10c6 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -632,55 +632,55 @@ def _sparsecore_preprocess( num_sc_per_device, ) - # if training: - # # Synchronize input statistics across all devices and update the - # # underlying stacked tables specs in the feature specs. - - # # Aggregate stats across all processes/devices via pmax. - # all_stats = multihost_utils.process_allgather(stats) - # # print("### all_stats", all_stats) - # # aggregated_stats = all_stats - # aggregated_stats = jax.tree.map( - # lambda x: jnp.max(x, axis=0), all_stats - # ) - - # # Check if stats changed enough to warrant action. - # stacked_table_specs = embedding.get_stacked_table_specs( - # self._config.feature_specs - # ) - # changed = any( - # np.max(aggregated_stats.max_ids_per_partition[stack_name]) - # > spec.max_ids_per_partition - # or np.max( - # aggregated_stats.max_unique_ids_per_partition[stack_name] - # ) - # > spec.max_unique_ids_per_partition - # or ( - # np.max( - # aggregated_stats.required_buffer_size_per_sc[stack_name] - # ) - # * num_sc_per_device - # ) - # > (spec.suggested_coo_buffer_size_per_device or 0) - # for stack_name, spec in stacked_table_specs.items() - # ) - - # # # Update configuration and repeat preprocessing if stats changed. - # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + + # Aggregate stats across all processes/devices via pmax. + all_stats = multihost_utils.process_allgather(stats) + # print("### all_stats", all_stats) + # aggregated_stats = all_stats + aggregated_stats = jax.tree.map( + lambda x: jnp.max(x, axis=0), all_stats + ) + + # Check if stats changed enough to warrant action. + stacked_table_specs = embedding.get_stacked_table_specs( + self._config.feature_specs + ) + changed = any( + np.max(aggregated_stats.max_ids_per_partition[stack_name]) + > spec.max_ids_per_partition + or np.max( + aggregated_stats.max_unique_ids_per_partition[stack_name] + ) + > spec.max_unique_ids_per_partition + or ( + np.max( + aggregated_stats.required_buffer_size_per_sc[stack_name] + ) + * num_sc_per_device + ) + > (spec.suggested_coo_buffer_size_per_device or 0) + for stack_name, spec in stacked_table_specs.items() + ) + + # # Update configuration and repeat preprocessing if stats changed. + if changed: + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 12d354e0bb925c59024ad65455884d09ec7e08b4 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 07:14:38 +0530 Subject: [PATCH 151/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index f51f10c6..b0d47cb6 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -641,7 +641,7 @@ def _sparsecore_preprocess( # print("### all_stats", all_stats) # aggregated_stats = all_stats aggregated_stats = jax.tree.map( - lambda x: jnp.max(x, axis=0), all_stats + lambda x: np.max(x, axis=0), all_stats ) # Check if stats changed enough to warrant action. From 6416fee37e84658d4ed33b61b1cda3b7cfe86317 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 09:18:09 +0530 Subject: [PATCH 152/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index c92bfb79..d09be6b7 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -24,157 +24,157 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "feature_list_length": 3, - "new_name": "cat_14", + "new_name": "14", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "feature_list_length": 2, - "new_name": "cat_15", + "new_name": "15", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "cat_16", + "new_name": "16", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "cat_17", + "new_name": "17", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "feature_list_length": 6, - "new_name": "cat_18", + "new_name": "18", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "feature_list_length": 1, - "new_name": "cat_19", + "new_name": "19", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "feature_list_length": 1, - "new_name": "cat_20", + "new_name": "20", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "feature_list_length": 1, - "new_name": "cat_21", + "new_name": "21", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "feature_list_length": 1, - "new_name": "cat_22", + "new_name": "22", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "feature_list_length": 7, - "new_name": "cat_23", + "new_name": "23", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "feature_list_length": 3, - "new_name": "cat_24", + "new_name": "24", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "feature_list_length": 8, - "new_name": "cat_25", + "new_name": "25", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "feature_list_length": 1, - "new_name": "cat_26", + "new_name": "26", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "feature_list_length": 6, - "new_name": "cat_27", + "new_name": "27", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "feature_list_length": 9, - "new_name": "cat_28", + "new_name": "28", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "feature_list_length": 5, - "new_name": "cat_29", + "new_name": "29", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "feature_list_length": 1, - "new_name": "cat_30", + "new_name": "30", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "feature_list_length": 1, - "new_name": "cat_31", + "new_name": "31", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "feature_list_length": 1, - "new_name": "cat_32", + "new_name": "32", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "feature_list_length": 12, - "new_name": "cat_33", + "new_name": "33", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "feature_list_length": 100, - "new_name": "cat_34", + "new_name": "34", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "feature_list_length": 27, - "new_name": "cat_35", + "new_name": "35", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "feature_list_length": 10, - "new_name": "cat_36", + "new_name": "36", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "feature_list_length": 3, - "new_name": "cat_37", + "new_name": "37", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "feature_list_length": 1, - "new_name": "cat_38", + "new_name": "38", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "feature_list_length": 1, - "new_name": "cat_39", + "new_name": "39", }, ] From 7c7d55d364b24e21db4b8d4bd300df5a9b72e155 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 09:21:19 +0530 Subject: [PATCH 153/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index e4fa985f..c8e13695 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -220,7 +220,7 @@ def generator(dataset, training=False): logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) logger.debug( "Small embedding inputs:%s", - first_batch[0]["small_emb_inputs"]["cat_39_id"], + first_batch[0]["small_emb_inputs"]["39_id"], ) logger.debug( "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] From 357d540b67de1796d437eb4f1e279dfd3346138d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 09:24:55 +0530 Subject: [PATCH 154/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index d09be6b7..24032a9e 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -24,157 +24,157 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "feature_list_length": 3, - "new_name": "14", + "new_name": "0", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "feature_list_length": 2, - "new_name": "15", + "new_name": "1", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "16", + "new_name": "2", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "17", + "new_name": "3", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "feature_list_length": 6, - "new_name": "18", + "new_name": "4", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "feature_list_length": 1, - "new_name": "19", + "new_name": "5", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "feature_list_length": 1, - "new_name": "20", + "new_name": "6", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "feature_list_length": 1, - "new_name": "21", + "new_name": "7", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "feature_list_length": 1, - "new_name": "22", + "new_name": "8", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "feature_list_length": 7, - "new_name": "23", + "new_name": "9", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "feature_list_length": 3, - "new_name": "24", + "new_name": "10", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "feature_list_length": 8, - "new_name": "25", + "new_name": "11", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "feature_list_length": 1, - "new_name": "26", + "new_name": "12", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "feature_list_length": 6, - "new_name": "27", + "new_name": "13", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "feature_list_length": 9, - "new_name": "28", + "new_name": "14", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "feature_list_length": 5, - "new_name": "29", + "new_name": "15", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "feature_list_length": 1, - "new_name": "30", + "new_name": "16", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "feature_list_length": 1, - "new_name": "31", + "new_name": "17", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "feature_list_length": 1, - "new_name": "32", + "new_name": "18", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "feature_list_length": 12, - "new_name": "33", + "new_name": "19", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "feature_list_length": 100, - "new_name": "34", + "new_name": "20", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "feature_list_length": 27, - "new_name": "35", + "new_name": "21", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "feature_list_length": 10, - "new_name": "36", + "new_name": "22", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "feature_list_length": 3, - "new_name": "37", + "new_name": "23", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "feature_list_length": 1, - "new_name": "38", + "new_name": "24", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "feature_list_length": 1, - "new_name": "39", + "new_name": "25", }, ] From 1a0326569be0ea64be6d8f3e51f976a0b60dd365 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 09:25:48 +0530 Subject: [PATCH 155/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index c8e13695..d29e3f44 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -70,7 +70,7 @@ def main( feature_list_length = large_emb_feature["feature_list_length"] table_config = keras_rs.layers.TableConfig( - name=f"{feature_name}_table", + name=feature_name, vocabulary_size=vocabulary_size, embedding_dim=model_cfg.embedding_dim, # TODO(abheesht): Verify. From d71e7e02cdbbe59ca33aa3415dcfd40394fb7aab Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 09:28:27 +0530 Subject: [PATCH 156/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d29e3f44..daebf25b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -220,7 +220,7 @@ def generator(dataset, training=False): logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) logger.debug( "Small embedding inputs:%s", - first_batch[0]["small_emb_inputs"]["39_id"], + first_batch[0]["small_emb_inputs"]["25_id"], ) logger.debug( "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] From eddff344e1fd17d5b14fe81b8cbb5a4e7fb7f2a2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 10:50:58 +0530 Subject: [PATCH 157/279] Debug --- examples/ml_perf/configs/v6e_16.py | 52 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 7837b7d9..4166ab3d 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -17,157 +17,157 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "feature_list_length": 3, - "new_name": "cat_14", + "new_name": "0", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "feature_list_length": 2, - "new_name": "cat_15", + "new_name": "1", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "cat_16", + "new_name": "2", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "cat_17", + "new_name": "3", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "feature_list_length": 6, - "new_name": "cat_18", + "new_name": "4", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "feature_list_length": 1, - "new_name": "cat_19", + "new_name": "5", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "feature_list_length": 1, - "new_name": "cat_20", + "new_name": "6", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "feature_list_length": 1, - "new_name": "cat_21", + "new_name": "7", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "feature_list_length": 1, - "new_name": "cat_22", + "new_name": "8", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "feature_list_length": 7, - "new_name": "cat_23", + "new_name": "9", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "feature_list_length": 3, - "new_name": "cat_24", + "new_name": "10", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "feature_list_length": 8, - "new_name": "cat_25", + "new_name": "11", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "feature_list_length": 1, - "new_name": "cat_26", + "new_name": "12", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "feature_list_length": 6, - "new_name": "cat_27", + "new_name": "13", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "feature_list_length": 9, - "new_name": "cat_28", + "new_name": "14", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "feature_list_length": 5, - "new_name": "cat_29", + "new_name": "15", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "feature_list_length": 1, - "new_name": "cat_30", + "new_name": "16", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "feature_list_length": 1, - "new_name": "cat_31", + "new_name": "17", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "feature_list_length": 1, - "new_name": "cat_32", + "new_name": "18", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "feature_list_length": 12, - "new_name": "cat_33", + "new_name": "19", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "feature_list_length": 100, - "new_name": "cat_34", + "new_name": "20", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "feature_list_length": 27, - "new_name": "cat_35", + "new_name": "21", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "feature_list_length": 10, - "new_name": "cat_36", + "new_name": "22", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "feature_list_length": 3, - "new_name": "cat_37", + "new_name": "23", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "feature_list_length": 1, - "new_name": "cat_38", + "new_name": "24", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "feature_list_length": 1, - "new_name": "cat_39", + "new_name": "25", }, ] From 187c53488eba57ab2dcbb014159ee0f949b0317e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 10:58:45 +0530 Subject: [PATCH 158/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index daebf25b..3c7667aa 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -70,7 +70,7 @@ def main( feature_list_length = large_emb_feature["feature_list_length"] table_config = keras_rs.layers.TableConfig( - name=feature_name, + name=f"{feature_name}t, vocabulary_size=vocabulary_size, embedding_dim=model_cfg.embedding_dim, # TODO(abheesht): Verify. From 1c13e1380b50086248b67f8ee2ecb65552aecb03 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 10:59:11 +0530 Subject: [PATCH 159/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 3c7667aa..001d0267 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -70,7 +70,7 @@ def main( feature_list_length = large_emb_feature["feature_list_length"] table_config = keras_rs.layers.TableConfig( - name=f"{feature_name}t, + name=f"{feature_name}t", vocabulary_size=vocabulary_size, embedding_dim=model_cfg.embedding_dim, # TODO(abheesht): Verify. From a26ef291a76e21ec3914ad592656a3a69b031e6f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 11:03:19 +0530 Subject: [PATCH 160/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 001d0267..838e532f 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -70,7 +70,7 @@ def main( feature_list_length = large_emb_feature["feature_list_length"] table_config = keras_rs.layers.TableConfig( - name=f"{feature_name}t", + name=f"{feature_name}_table", vocabulary_size=vocabulary_size, embedding_dim=model_cfg.embedding_dim, # TODO(abheesht): Verify. From dffcb968d8d0d8cf57b3acc40752caf07168f262 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 11:07:42 +0530 Subject: [PATCH 161/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 24032a9e..9c40c658 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -24,157 +24,157 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "feature_list_length": 3, - "new_name": "0", + "new_name": "cat_0", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "feature_list_length": 2, - "new_name": "1", + "new_name": "cat_1", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "2", + "new_name": "cat_2", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "3", + "new_name": "cat_3", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "feature_list_length": 6, - "new_name": "4", + "new_name": "cat_4", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "feature_list_length": 1, - "new_name": "5", + "new_name": "cat_5", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "feature_list_length": 1, - "new_name": "6", + "new_name": "cat_6", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "feature_list_length": 1, - "new_name": "7", + "new_name": "cat_7", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "feature_list_length": 1, - "new_name": "8", + "new_name": "cat_8", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "feature_list_length": 7, - "new_name": "9", + "new_name": "cat_9", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "feature_list_length": 3, - "new_name": "10", + "new_name": "cat_10", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "feature_list_length": 8, - "new_name": "11", + "new_name": "cat_11", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "feature_list_length": 1, - "new_name": "12", + "new_name": "cat_12", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "feature_list_length": 6, - "new_name": "13", + "new_name": "cat_13", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "feature_list_length": 9, - "new_name": "14", + "new_name": "cat_14", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "feature_list_length": 5, - "new_name": "15", + "new_name": "cat_15", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "feature_list_length": 1, - "new_name": "16", + "new_name": "cat_16", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "feature_list_length": 1, - "new_name": "17", + "new_name": "cat_17", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "feature_list_length": 1, - "new_name": "18", + "new_name": "cat_18", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "feature_list_length": 12, - "new_name": "19", + "new_name": "cat_19", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "feature_list_length": 100, - "new_name": "20", + "new_name": "cat_20", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "feature_list_length": 27, - "new_name": "21", + "new_name": "cat_21", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "feature_list_length": 10, - "new_name": "22", + "new_name": "cat_22", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "feature_list_length": 3, - "new_name": "23", + "new_name": "cat_23", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "feature_list_length": 1, - "new_name": "24", + "new_name": "cat_24", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "feature_list_length": 1, - "new_name": "25", + "new_name": "cat_25", }, ] From 1995e16ed5f46081b7cb841b3e506f0e218bbe0f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 11:10:33 +0530 Subject: [PATCH 162/279] Debug --- examples/ml_perf/main.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 838e532f..f442792f 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -215,17 +215,17 @@ def generator(dataset, training=False): train_generator = generator(train_ds, training=True) if do_eval: eval_generator = generator(eval_ds, training=False) - logger.debug("Inspecting one batch of data...") - for first_batch in train_generator: - logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) - logger.debug( - "Small embedding inputs:%s", - first_batch[0]["small_emb_inputs"]["25_id"], - ) - logger.debug( - "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] - ) - break + # logger.debug("Inspecting one batch of data...") + # for first_batch in train_generator: + # logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) + # logger.debug( + # "Small embedding inputs:%s", + # first_batch[0]["small_emb_inputs"]["25_id"], + # ) + # logger.debug( + # "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] + # ) + # break logger.info("Successfully preprocessed one batch of data") # === Training === From 2e42b53683d9c0b31cdd93eff0753f4decc7b6e2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 11:13:45 +0530 Subject: [PATCH 163/279] Debug --- .../embedding/jax/distributed_embedding.py | 98 +++++++++---------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index b0d47cb6..46d2c0fb 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -632,55 +632,55 @@ def _sparsecore_preprocess( num_sc_per_device, ) - if training: - # Synchronize input statistics across all devices and update the - # underlying stacked tables specs in the feature specs. - - # Aggregate stats across all processes/devices via pmax. - all_stats = multihost_utils.process_allgather(stats) - # print("### all_stats", all_stats) - # aggregated_stats = all_stats - aggregated_stats = jax.tree.map( - lambda x: np.max(x, axis=0), all_stats - ) - - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(aggregated_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max( - aggregated_stats.max_unique_ids_per_partition[stack_name] - ) - > spec.max_unique_ids_per_partition - or ( - np.max( - aggregated_stats.required_buffer_size_per_sc[stack_name] - ) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + # if training: + # # Synchronize input statistics across all devices and update the + # # underlying stacked tables specs in the feature specs. + + # # Aggregate stats across all processes/devices via pmax. + # all_stats = multihost_utils.process_allgather(stats) + # # print("### all_stats", all_stats) + # # aggregated_stats = all_stats + # aggregated_stats = jax.tree.map( + # lambda x: np.max(x, axis=0), all_stats + # ) + + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(aggregated_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max( + # aggregated_stats.max_unique_ids_per_partition[stack_name] + # ) + # > spec.max_unique_ids_per_partition + # or ( + # np.max( + # aggregated_stats.required_buffer_size_per_sc[stack_name] + # ) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # # Re-execute preprocessing with consistent input statistics. + # # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # # self._config.feature_specs, + # # samples, + # # local_device_count, + # # global_device_count, + # # num_sc_per_device, + # # ) return {"inputs": preprocessed} From 9e8d15381a5a53e795e7757d3992150190f9d357 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 12:09:47 +0530 Subject: [PATCH 164/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 300 +++++++++--------- 1 file changed, 150 insertions(+), 150 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 9c40c658..1bfe1f5f 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -26,156 +26,156 @@ "feature_list_length": 3, "new_name": "cat_0", }, - { - "name": "categorical-feature-15", - "vocabulary_size": 39060, - "feature_list_length": 2, - "new_name": "cat_1", - }, - { - "name": "categorical-feature-16", - "vocabulary_size": 17295, - "feature_list_length": 1, - "new_name": "cat_2", - }, - { - "name": "categorical-feature-17", - "vocabulary_size": 7424, - "feature_list_length": 2, - "new_name": "cat_3", - }, - { - "name": "categorical-feature-18", - "vocabulary_size": 20265, - "feature_list_length": 6, - "new_name": "cat_4", - }, - { - "name": "categorical-feature-19", - "vocabulary_size": 3, - "feature_list_length": 1, - "new_name": "cat_5", - }, - { - "name": "categorical-feature-20", - "vocabulary_size": 7122, - "feature_list_length": 1, - "new_name": "cat_6", - }, - { - "name": "categorical-feature-21", - "vocabulary_size": 1543, - "feature_list_length": 1, - "new_name": "cat_7", - }, - { - "name": "categorical-feature-22", - "vocabulary_size": 63, - "feature_list_length": 1, - "new_name": "cat_8", - }, - { - "name": "categorical-feature-23", - "vocabulary_size": 40000000, - "feature_list_length": 7, - "new_name": "cat_9", - }, - { - "name": "categorical-feature-24", - "vocabulary_size": 3067956, - "feature_list_length": 3, - "new_name": "cat_10", - }, - { - "name": "categorical-feature-25", - "vocabulary_size": 405282, - "feature_list_length": 8, - "new_name": "cat_11", - }, - { - "name": "categorical-feature-26", - "vocabulary_size": 10, - "feature_list_length": 1, - "new_name": "cat_12", - }, - { - "name": "categorical-feature-27", - "vocabulary_size": 2209, - "feature_list_length": 6, - "new_name": "cat_13", - }, - { - "name": "categorical-feature-28", - "vocabulary_size": 11938, - "feature_list_length": 9, - "new_name": "cat_14", - }, - { - "name": "categorical-feature-29", - "vocabulary_size": 155, - "feature_list_length": 5, - "new_name": "cat_15", - }, - { - "name": "categorical-feature-30", - "vocabulary_size": 4, - "feature_list_length": 1, - "new_name": "cat_16", - }, - { - "name": "categorical-feature-31", - "vocabulary_size": 976, - "feature_list_length": 1, - "new_name": "cat_17", - }, - { - "name": "categorical-feature-32", - "vocabulary_size": 14, - "feature_list_length": 1, - "new_name": "cat_18", - }, - { - "name": "categorical-feature-33", - "vocabulary_size": 40000000, - "feature_list_length": 12, - "new_name": "cat_19", - }, - { - "name": "categorical-feature-34", - "vocabulary_size": 40000000, - "feature_list_length": 100, - "new_name": "cat_20", - }, - { - "name": "categorical-feature-35", - "vocabulary_size": 40000000, - "feature_list_length": 27, - "new_name": "cat_21", - }, - { - "name": "categorical-feature-36", - "vocabulary_size": 590152, - "feature_list_length": 10, - "new_name": "cat_22", - }, - { - "name": "categorical-feature-37", - "vocabulary_size": 12973, - "feature_list_length": 3, - "new_name": "cat_23", - }, - { - "name": "categorical-feature-38", - "vocabulary_size": 108, - "feature_list_length": 1, - "new_name": "cat_24", - }, - { - "name": "categorical-feature-39", - "vocabulary_size": 36, - "feature_list_length": 1, - "new_name": "cat_25", - }, + # { + # "name": "categorical-feature-15", + # "vocabulary_size": 39060, + # "feature_list_length": 2, + # "new_name": "cat_1", + # }, + # { + # "name": "categorical-feature-16", + # "vocabulary_size": 17295, + # "feature_list_length": 1, + # "new_name": "cat_2", + # }, + # { + # "name": "categorical-feature-17", + # "vocabulary_size": 7424, + # "feature_list_length": 2, + # "new_name": "cat_3", + # }, + # { + # "name": "categorical-feature-18", + # "vocabulary_size": 20265, + # "feature_list_length": 6, + # "new_name": "cat_4", + # }, + # { + # "name": "categorical-feature-19", + # "vocabulary_size": 3, + # "feature_list_length": 1, + # "new_name": "cat_5", + # }, + # { + # "name": "categorical-feature-20", + # "vocabulary_size": 7122, + # "feature_list_length": 1, + # "new_name": "cat_6", + # }, + # { + # "name": "categorical-feature-21", + # "vocabulary_size": 1543, + # "feature_list_length": 1, + # "new_name": "cat_7", + # }, + # { + # "name": "categorical-feature-22", + # "vocabulary_size": 63, + # "feature_list_length": 1, + # "new_name": "cat_8", + # }, + # { + # "name": "categorical-feature-23", + # "vocabulary_size": 40000000, + # "feature_list_length": 7, + # "new_name": "cat_9", + # }, + # { + # "name": "categorical-feature-24", + # "vocabulary_size": 3067956, + # "feature_list_length": 3, + # "new_name": "cat_10", + # }, + # { + # "name": "categorical-feature-25", + # "vocabulary_size": 405282, + # "feature_list_length": 8, + # "new_name": "cat_11", + # }, + # { + # "name": "categorical-feature-26", + # "vocabulary_size": 10, + # "feature_list_length": 1, + # "new_name": "cat_12", + # }, + # { + # "name": "categorical-feature-27", + # "vocabulary_size": 2209, + # "feature_list_length": 6, + # "new_name": "cat_13", + # }, + # { + # "name": "categorical-feature-28", + # "vocabulary_size": 11938, + # "feature_list_length": 9, + # "new_name": "cat_14", + # }, + # { + # "name": "categorical-feature-29", + # "vocabulary_size": 155, + # "feature_list_length": 5, + # "new_name": "cat_15", + # }, + # { + # "name": "categorical-feature-30", + # "vocabulary_size": 4, + # "feature_list_length": 1, + # "new_name": "cat_16", + # }, + # { + # "name": "categorical-feature-31", + # "vocabulary_size": 976, + # "feature_list_length": 1, + # "new_name": "cat_17", + # }, + # { + # "name": "categorical-feature-32", + # "vocabulary_size": 14, + # "feature_list_length": 1, + # "new_name": "cat_18", + # }, + # { + # "name": "categorical-feature-33", + # "vocabulary_size": 40000000, + # "feature_list_length": 12, + # "new_name": "cat_19", + # }, + # { + # "name": "categorical-feature-34", + # "vocabulary_size": 40000000, + # "feature_list_length": 100, + # "new_name": "cat_20", + # }, + # { + # "name": "categorical-feature-35", + # "vocabulary_size": 40000000, + # "feature_list_length": 27, + # "new_name": "cat_21", + # }, + # { + # "name": "categorical-feature-36", + # "vocabulary_size": 590152, + # "feature_list_length": 10, + # "new_name": "cat_22", + # }, + # { + # "name": "categorical-feature-37", + # "vocabulary_size": 12973, + # "feature_list_length": 3, + # "new_name": "cat_23", + # }, + # { + # "name": "categorical-feature-38", + # "vocabulary_size": 108, + # "feature_list_length": 1, + # "new_name": "cat_24", + # }, + # { + # "name": "categorical-feature-39", + # "vocabulary_size": 36, + # "feature_list_length": 1, + # "new_name": "cat_25", + # }, ] # === Model === From e9c5ba3cbf285ea5e2786402276d7d14186862ac Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 12:12:10 +0530 Subject: [PATCH 165/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 1bfe1f5f..e6cbf069 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -38,12 +38,12 @@ # "feature_list_length": 1, # "new_name": "cat_2", # }, - # { - # "name": "categorical-feature-17", - # "vocabulary_size": 7424, - # "feature_list_length": 2, - # "new_name": "cat_3", - # }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "feature_list_length": 2, + "new_name": "cat_3", + }, # { # "name": "categorical-feature-18", # "vocabulary_size": 20265, From a34448779af94efff23191a958030011fde2cb43 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 12:42:05 +0530 Subject: [PATCH 166/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index e6cbf069..f2f95c15 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -74,12 +74,12 @@ # "feature_list_length": 1, # "new_name": "cat_8", # }, - # { - # "name": "categorical-feature-23", - # "vocabulary_size": 40000000, - # "feature_list_length": 7, - # "new_name": "cat_9", - # }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "feature_list_length": 7, + "new_name": "cat_9", + }, # { # "name": "categorical-feature-24", # "vocabulary_size": 3067956, From 05d422eb27c749f03400d54c0151625a5e98ed34 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 12:45:18 +0530 Subject: [PATCH 167/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index f2f95c15..6da95273 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -26,12 +26,12 @@ "feature_list_length": 3, "new_name": "cat_0", }, - # { - # "name": "categorical-feature-15", - # "vocabulary_size": 39060, - # "feature_list_length": 2, - # "new_name": "cat_1", - # }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "feature_list_length": 2, + "new_name": "cat_1", + }, # { # "name": "categorical-feature-16", # "vocabulary_size": 17295, From d35876d6d65c246bac70365f71ce3e79d500c940 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 12:46:04 +0530 Subject: [PATCH 168/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 120 +++++++++--------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 6da95273..7571d429 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -92,48 +92,48 @@ # "feature_list_length": 8, # "new_name": "cat_11", # }, - # { - # "name": "categorical-feature-26", - # "vocabulary_size": 10, - # "feature_list_length": 1, - # "new_name": "cat_12", - # }, - # { - # "name": "categorical-feature-27", - # "vocabulary_size": 2209, - # "feature_list_length": 6, - # "new_name": "cat_13", - # }, - # { - # "name": "categorical-feature-28", - # "vocabulary_size": 11938, - # "feature_list_length": 9, - # "new_name": "cat_14", - # }, - # { - # "name": "categorical-feature-29", - # "vocabulary_size": 155, - # "feature_list_length": 5, - # "new_name": "cat_15", - # }, - # { - # "name": "categorical-feature-30", - # "vocabulary_size": 4, - # "feature_list_length": 1, - # "new_name": "cat_16", - # }, - # { - # "name": "categorical-feature-31", - # "vocabulary_size": 976, - # "feature_list_length": 1, - # "new_name": "cat_17", - # }, - # { - # "name": "categorical-feature-32", - # "vocabulary_size": 14, - # "feature_list_length": 1, - # "new_name": "cat_18", - # }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "feature_list_length": 1, + "new_name": "cat_12", + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "feature_list_length": 6, + "new_name": "cat_13", + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "feature_list_length": 9, + "new_name": "cat_14", + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "feature_list_length": 5, + "new_name": "cat_15", + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "feature_list_length": 1, + "new_name": "cat_16", + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "feature_list_length": 1, + "new_name": "cat_17", + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "feature_list_length": 1, + "new_name": "cat_18", + }, # { # "name": "categorical-feature-33", # "vocabulary_size": 40000000, @@ -158,24 +158,24 @@ # "feature_list_length": 10, # "new_name": "cat_22", # }, - # { - # "name": "categorical-feature-37", - # "vocabulary_size": 12973, - # "feature_list_length": 3, - # "new_name": "cat_23", - # }, - # { - # "name": "categorical-feature-38", - # "vocabulary_size": 108, - # "feature_list_length": 1, - # "new_name": "cat_24", - # }, - # { - # "name": "categorical-feature-39", - # "vocabulary_size": 36, - # "feature_list_length": 1, - # "new_name": "cat_25", - # }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "feature_list_length": 3, + "new_name": "cat_23", + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "feature_list_length": 1, + "new_name": "cat_24", + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "feature_list_length": 1, + "new_name": "cat_25", + }, ] # === Model === From 7e41ec77e187c5e69e8f0c82d9df5e47617e0a30 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 12:50:56 +0530 Subject: [PATCH 169/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 96 +++++++++---------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7571d429..ad2ce86d 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -26,12 +26,12 @@ "feature_list_length": 3, "new_name": "cat_0", }, - { - "name": "categorical-feature-15", - "vocabulary_size": 39060, - "feature_list_length": 2, - "new_name": "cat_1", - }, + # { + # "name": "categorical-feature-15", + # "vocabulary_size": 39060, + # "feature_list_length": 2, + # "new_name": "cat_1", + # }, # { # "name": "categorical-feature-16", # "vocabulary_size": 17295, @@ -92,48 +92,48 @@ # "feature_list_length": 8, # "new_name": "cat_11", # }, - { - "name": "categorical-feature-26", - "vocabulary_size": 10, - "feature_list_length": 1, - "new_name": "cat_12", - }, - { - "name": "categorical-feature-27", - "vocabulary_size": 2209, - "feature_list_length": 6, - "new_name": "cat_13", - }, - { - "name": "categorical-feature-28", - "vocabulary_size": 11938, - "feature_list_length": 9, - "new_name": "cat_14", - }, - { - "name": "categorical-feature-29", - "vocabulary_size": 155, - "feature_list_length": 5, - "new_name": "cat_15", - }, - { - "name": "categorical-feature-30", - "vocabulary_size": 4, - "feature_list_length": 1, - "new_name": "cat_16", - }, - { - "name": "categorical-feature-31", - "vocabulary_size": 976, - "feature_list_length": 1, - "new_name": "cat_17", - }, - { - "name": "categorical-feature-32", - "vocabulary_size": 14, - "feature_list_length": 1, - "new_name": "cat_18", - }, + # { + # "name": "categorical-feature-26", + # "vocabulary_size": 10, + # "feature_list_length": 1, + # "new_name": "cat_12", + # }, + # { + # "name": "categorical-feature-27", + # "vocabulary_size": 2209, + # "feature_list_length": 6, + # "new_name": "cat_13", + # }, + # { + # "name": "categorical-feature-28", + # "vocabulary_size": 11938, + # "feature_list_length": 9, + # "new_name": "cat_14", + # }, + # { + # "name": "categorical-feature-29", + # "vocabulary_size": 155, + # "feature_list_length": 5, + # "new_name": "cat_15", + # }, + # { + # "name": "categorical-feature-30", + # "vocabulary_size": 4, + # "feature_list_length": 1, + # "new_name": "cat_16", + # }, + # { + # "name": "categorical-feature-31", + # "vocabulary_size": 976, + # "feature_list_length": 1, + # "new_name": "cat_17", + # }, + # { + # "name": "categorical-feature-32", + # "vocabulary_size": 14, + # "feature_list_length": 1, + # "new_name": "cat_18", + # }, # { # "name": "categorical-feature-33", # "vocabulary_size": 40000000, From 16019c34d9e725c4f451f43724b17aaae3ae543b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 12:53:38 +0530 Subject: [PATCH 170/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 258 +++++++++--------- 1 file changed, 129 insertions(+), 129 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index ad2ce86d..c92bfb79 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -24,157 +24,157 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "feature_list_length": 3, - "new_name": "cat_0", - }, - # { - # "name": "categorical-feature-15", - # "vocabulary_size": 39060, - # "feature_list_length": 2, - # "new_name": "cat_1", - # }, - # { - # "name": "categorical-feature-16", - # "vocabulary_size": 17295, - # "feature_list_length": 1, - # "new_name": "cat_2", - # }, + "new_name": "cat_14", + }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "feature_list_length": 2, + "new_name": "cat_15", + }, + { + "name": "categorical-feature-16", + "vocabulary_size": 17295, + "feature_list_length": 1, + "new_name": "cat_16", + }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "cat_3", - }, - # { - # "name": "categorical-feature-18", - # "vocabulary_size": 20265, - # "feature_list_length": 6, - # "new_name": "cat_4", - # }, - # { - # "name": "categorical-feature-19", - # "vocabulary_size": 3, - # "feature_list_length": 1, - # "new_name": "cat_5", - # }, - # { - # "name": "categorical-feature-20", - # "vocabulary_size": 7122, - # "feature_list_length": 1, - # "new_name": "cat_6", - # }, - # { - # "name": "categorical-feature-21", - # "vocabulary_size": 1543, - # "feature_list_length": 1, - # "new_name": "cat_7", - # }, - # { - # "name": "categorical-feature-22", - # "vocabulary_size": 63, - # "feature_list_length": 1, - # "new_name": "cat_8", - # }, + "new_name": "cat_17", + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "feature_list_length": 6, + "new_name": "cat_18", + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "feature_list_length": 1, + "new_name": "cat_19", + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "feature_list_length": 1, + "new_name": "cat_20", + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "feature_list_length": 1, + "new_name": "cat_21", + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "feature_list_length": 1, + "new_name": "cat_22", + }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "feature_list_length": 7, - "new_name": "cat_9", - }, - # { - # "name": "categorical-feature-24", - # "vocabulary_size": 3067956, - # "feature_list_length": 3, - # "new_name": "cat_10", - # }, - # { - # "name": "categorical-feature-25", - # "vocabulary_size": 405282, - # "feature_list_length": 8, - # "new_name": "cat_11", - # }, - # { - # "name": "categorical-feature-26", - # "vocabulary_size": 10, - # "feature_list_length": 1, - # "new_name": "cat_12", - # }, - # { - # "name": "categorical-feature-27", - # "vocabulary_size": 2209, - # "feature_list_length": 6, - # "new_name": "cat_13", - # }, - # { - # "name": "categorical-feature-28", - # "vocabulary_size": 11938, - # "feature_list_length": 9, - # "new_name": "cat_14", - # }, - # { - # "name": "categorical-feature-29", - # "vocabulary_size": 155, - # "feature_list_length": 5, - # "new_name": "cat_15", - # }, - # { - # "name": "categorical-feature-30", - # "vocabulary_size": 4, - # "feature_list_length": 1, - # "new_name": "cat_16", - # }, - # { - # "name": "categorical-feature-31", - # "vocabulary_size": 976, - # "feature_list_length": 1, - # "new_name": "cat_17", - # }, - # { - # "name": "categorical-feature-32", - # "vocabulary_size": 14, - # "feature_list_length": 1, - # "new_name": "cat_18", - # }, - # { - # "name": "categorical-feature-33", - # "vocabulary_size": 40000000, - # "feature_list_length": 12, - # "new_name": "cat_19", - # }, - # { - # "name": "categorical-feature-34", - # "vocabulary_size": 40000000, - # "feature_list_length": 100, - # "new_name": "cat_20", - # }, - # { - # "name": "categorical-feature-35", - # "vocabulary_size": 40000000, - # "feature_list_length": 27, - # "new_name": "cat_21", - # }, - # { - # "name": "categorical-feature-36", - # "vocabulary_size": 590152, - # "feature_list_length": 10, - # "new_name": "cat_22", - # }, + "new_name": "cat_23", + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "feature_list_length": 3, + "new_name": "cat_24", + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "feature_list_length": 8, + "new_name": "cat_25", + }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "feature_list_length": 1, + "new_name": "cat_26", + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "feature_list_length": 6, + "new_name": "cat_27", + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "feature_list_length": 9, + "new_name": "cat_28", + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "feature_list_length": 5, + "new_name": "cat_29", + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "feature_list_length": 1, + "new_name": "cat_30", + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "feature_list_length": 1, + "new_name": "cat_31", + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "feature_list_length": 1, + "new_name": "cat_32", + }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "feature_list_length": 12, + "new_name": "cat_33", + }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "feature_list_length": 100, + "new_name": "cat_34", + }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "feature_list_length": 27, + "new_name": "cat_35", + }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "feature_list_length": 10, + "new_name": "cat_36", + }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "feature_list_length": 3, - "new_name": "cat_23", + "new_name": "cat_37", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "feature_list_length": 1, - "new_name": "cat_24", + "new_name": "cat_38", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "feature_list_length": 1, - "new_name": "cat_25", + "new_name": "cat_39", }, ] From 799573321efe87791278d7b4fe3f640e7c5a1dc6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 12:56:19 +0530 Subject: [PATCH 171/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index c92bfb79..faadb62b 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -24,31 +24,31 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "feature_list_length": 3, - "new_name": "cat_14", + "new_name": "cat_0", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "feature_list_length": 2, - "new_name": "cat_15", + "new_name": "cat_1", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "cat_16", + "new_name": "cat_2", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "cat_17", + "new_name": "cat_3", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "feature_list_length": 6, - "new_name": "cat_18", + "new_name": "cat_4", }, { "name": "categorical-feature-19", From 3732bfdea9a563b2ac5d73ac0f311aeb2e306be6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:01:34 +0530 Subject: [PATCH 172/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index faadb62b..a1934b5f 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -36,19 +36,19 @@ "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "cat_2", + "new_name": "cat_16", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "cat_3", + "new_name": "cat_17", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "feature_list_length": 6, - "new_name": "cat_4", + "new_name": "cat_18", }, { "name": "categorical-feature-19", From 992014390e744cf4550199a661c009c964526e24 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:04:00 +0530 Subject: [PATCH 173/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index a1934b5f..ae23d112 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -36,7 +36,7 @@ "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "cat_16", + "new_name": "cat_2", }, { "name": "categorical-feature-17", From aadb2b19bce1e190cd9ab980b38e545f5c530d8f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:11:12 +0530 Subject: [PATCH 174/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 276 +++++++++--------- 1 file changed, 138 insertions(+), 138 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index ae23d112..fb0352ee 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -38,144 +38,144 @@ "feature_list_length": 1, "new_name": "cat_2", }, - { - "name": "categorical-feature-17", - "vocabulary_size": 7424, - "feature_list_length": 2, - "new_name": "cat_17", - }, - { - "name": "categorical-feature-18", - "vocabulary_size": 20265, - "feature_list_length": 6, - "new_name": "cat_18", - }, - { - "name": "categorical-feature-19", - "vocabulary_size": 3, - "feature_list_length": 1, - "new_name": "cat_19", - }, - { - "name": "categorical-feature-20", - "vocabulary_size": 7122, - "feature_list_length": 1, - "new_name": "cat_20", - }, - { - "name": "categorical-feature-21", - "vocabulary_size": 1543, - "feature_list_length": 1, - "new_name": "cat_21", - }, - { - "name": "categorical-feature-22", - "vocabulary_size": 63, - "feature_list_length": 1, - "new_name": "cat_22", - }, - { - "name": "categorical-feature-23", - "vocabulary_size": 40000000, - "feature_list_length": 7, - "new_name": "cat_23", - }, - { - "name": "categorical-feature-24", - "vocabulary_size": 3067956, - "feature_list_length": 3, - "new_name": "cat_24", - }, - { - "name": "categorical-feature-25", - "vocabulary_size": 405282, - "feature_list_length": 8, - "new_name": "cat_25", - }, - { - "name": "categorical-feature-26", - "vocabulary_size": 10, - "feature_list_length": 1, - "new_name": "cat_26", - }, - { - "name": "categorical-feature-27", - "vocabulary_size": 2209, - "feature_list_length": 6, - "new_name": "cat_27", - }, - { - "name": "categorical-feature-28", - "vocabulary_size": 11938, - "feature_list_length": 9, - "new_name": "cat_28", - }, - { - "name": "categorical-feature-29", - "vocabulary_size": 155, - "feature_list_length": 5, - "new_name": "cat_29", - }, - { - "name": "categorical-feature-30", - "vocabulary_size": 4, - "feature_list_length": 1, - "new_name": "cat_30", - }, - { - "name": "categorical-feature-31", - "vocabulary_size": 976, - "feature_list_length": 1, - "new_name": "cat_31", - }, - { - "name": "categorical-feature-32", - "vocabulary_size": 14, - "feature_list_length": 1, - "new_name": "cat_32", - }, - { - "name": "categorical-feature-33", - "vocabulary_size": 40000000, - "feature_list_length": 12, - "new_name": "cat_33", - }, - { - "name": "categorical-feature-34", - "vocabulary_size": 40000000, - "feature_list_length": 100, - "new_name": "cat_34", - }, - { - "name": "categorical-feature-35", - "vocabulary_size": 40000000, - "feature_list_length": 27, - "new_name": "cat_35", - }, - { - "name": "categorical-feature-36", - "vocabulary_size": 590152, - "feature_list_length": 10, - "new_name": "cat_36", - }, - { - "name": "categorical-feature-37", - "vocabulary_size": 12973, - "feature_list_length": 3, - "new_name": "cat_37", - }, - { - "name": "categorical-feature-38", - "vocabulary_size": 108, - "feature_list_length": 1, - "new_name": "cat_38", - }, - { - "name": "categorical-feature-39", - "vocabulary_size": 36, - "feature_list_length": 1, - "new_name": "cat_39", - }, + # { + # "name": "categorical-feature-17", + # "vocabulary_size": 7424, + # "feature_list_length": 2, + # "new_name": "cat_17", + # }, + # { + # "name": "categorical-feature-18", + # "vocabulary_size": 20265, + # "feature_list_length": 6, + # "new_name": "cat_18", + # }, + # { + # "name": "categorical-feature-19", + # "vocabulary_size": 3, + # "feature_list_length": 1, + # "new_name": "cat_19", + # }, + # { + # "name": "categorical-feature-20", + # "vocabulary_size": 7122, + # "feature_list_length": 1, + # "new_name": "cat_20", + # }, + # { + # "name": "categorical-feature-21", + # "vocabulary_size": 1543, + # "feature_list_length": 1, + # "new_name": "cat_21", + # }, + # { + # "name": "categorical-feature-22", + # "vocabulary_size": 63, + # "feature_list_length": 1, + # "new_name": "cat_22", + # }, + # { + # "name": "categorical-feature-23", + # "vocabulary_size": 40000000, + # "feature_list_length": 7, + # "new_name": "cat_23", + # }, + # { + # "name": "categorical-feature-24", + # "vocabulary_size": 3067956, + # "feature_list_length": 3, + # "new_name": "cat_24", + # }, + # { + # "name": "categorical-feature-25", + # "vocabulary_size": 405282, + # "feature_list_length": 8, + # "new_name": "cat_25", + # }, + # { + # "name": "categorical-feature-26", + # "vocabulary_size": 10, + # "feature_list_length": 1, + # "new_name": "cat_26", + # }, + # { + # "name": "categorical-feature-27", + # "vocabulary_size": 2209, + # "feature_list_length": 6, + # "new_name": "cat_27", + # }, + # { + # "name": "categorical-feature-28", + # "vocabulary_size": 11938, + # "feature_list_length": 9, + # "new_name": "cat_28", + # }, + # { + # "name": "categorical-feature-29", + # "vocabulary_size": 155, + # "feature_list_length": 5, + # "new_name": "cat_29", + # }, + # { + # "name": "categorical-feature-30", + # "vocabulary_size": 4, + # "feature_list_length": 1, + # "new_name": "cat_30", + # }, + # { + # "name": "categorical-feature-31", + # "vocabulary_size": 976, + # "feature_list_length": 1, + # "new_name": "cat_31", + # }, + # { + # "name": "categorical-feature-32", + # "vocabulary_size": 14, + # "feature_list_length": 1, + # "new_name": "cat_32", + # }, + # { + # "name": "categorical-feature-33", + # "vocabulary_size": 40000000, + # "feature_list_length": 12, + # "new_name": "cat_33", + # }, + # { + # "name": "categorical-feature-34", + # "vocabulary_size": 40000000, + # "feature_list_length": 100, + # "new_name": "cat_34", + # }, + # { + # "name": "categorical-feature-35", + # "vocabulary_size": 40000000, + # "feature_list_length": 27, + # "new_name": "cat_35", + # }, + # { + # "name": "categorical-feature-36", + # "vocabulary_size": 590152, + # "feature_list_length": 10, + # "new_name": "cat_36", + # }, + # { + # "name": "categorical-feature-37", + # "vocabulary_size": 12973, + # "feature_list_length": 3, + # "new_name": "cat_37", + # }, + # { + # "name": "categorical-feature-38", + # "vocabulary_size": 108, + # "feature_list_length": 1, + # "new_name": "cat_38", + # }, + # { + # "name": "categorical-feature-39", + # "vocabulary_size": 36, + # "feature_list_length": 1, + # "new_name": "cat_39", + # }, ] # === Model === From fc4a26aed213b5f484eca2b4732a13b6415ba45f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:11:29 +0530 Subject: [PATCH 175/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index fb0352ee..0ec5efd6 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -38,12 +38,12 @@ "feature_list_length": 1, "new_name": "cat_2", }, - # { - # "name": "categorical-feature-17", - # "vocabulary_size": 7424, - # "feature_list_length": 2, - # "new_name": "cat_17", - # }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "feature_list_length": 2, + "new_name": "cat_17", + }, # { # "name": "categorical-feature-18", # "vocabulary_size": 20265, From 059752f588b1bd4f284833cd42a640f14f8c59c7 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:17:35 +0530 Subject: [PATCH 176/279] Debug --- examples/ml_perf/main.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index f442792f..09a5a726 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -215,6 +215,20 @@ def generator(dataset, training=False): train_generator = generator(train_ds, training=True) if do_eval: eval_generator = generator(eval_ds, training=False) + for first_batch in train_generator: + logger.info("Dense inputs:%s", first_batch[0]["dense_input"].shape) + for k in first_batch[0]["small_emb_inputs"]: + logger.info( + "Small embedding inputs:%s %s", + k, first_batch[0]["small_emb_inputs"][k].shape, + ) + for k in first_batch[0]["large_emb_inputs"]: + logger.info( + "Large embedding inputs:%s %s", + k, first_batch[0]["large_emb_inputs"][k].shape, + ) + break + # logger.debug("Inspecting one batch of data...") # for first_batch in train_generator: # logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) From e5929f5dc133841ef368d185cb433db83fc382d9 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:20:34 +0530 Subject: [PATCH 177/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 09a5a726..2196f741 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -225,7 +225,7 @@ def generator(dataset, training=False): for k in first_batch[0]["large_emb_inputs"]: logger.info( "Large embedding inputs:%s %s", - k, first_batch[0]["large_emb_inputs"][k].shape, + k, first_batch[0]["large_emb_inputs"][k], ) break From 29a47ffc362f7590737968848fdc56518b1a3b98 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:23:49 +0530 Subject: [PATCH 178/279] Debug --- examples/ml_perf/main.py | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2196f741..27b09a7c 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -194,6 +194,20 @@ def main( eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False + for first_batch in train_ds.take(1): + logger.info("Dense inputs:%s", first_batch[0]["dense_input"].shape) + for k in first_batch[0]["small_emb_inputs"]: + logger.info( + "Small embedding inputs:%s %s", + k, first_batch[0]["small_emb_inputs"][k].shape, + ) + for k in first_batch[0]["large_emb_inputs"]: + logger.info( + "Large embedding inputs:%s %s", + k, first_batch[0]["large_emb_inputs"][k].shape, + ) + break + def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses large embedding features. @@ -215,19 +229,19 @@ def generator(dataset, training=False): train_generator = generator(train_ds, training=True) if do_eval: eval_generator = generator(eval_ds, training=False) - for first_batch in train_generator: - logger.info("Dense inputs:%s", first_batch[0]["dense_input"].shape) - for k in first_batch[0]["small_emb_inputs"]: - logger.info( - "Small embedding inputs:%s %s", - k, first_batch[0]["small_emb_inputs"][k].shape, - ) - for k in first_batch[0]["large_emb_inputs"]: - logger.info( - "Large embedding inputs:%s %s", - k, first_batch[0]["large_emb_inputs"][k], - ) - break + # for first_batch in train_generator: + # logger.info("Dense inputs:%s", first_batch[0]["dense_input"].shape) + # for k in first_batch[0]["small_emb_inputs"]: + # logger.info( + # "Small embedding inputs:%s %s", + # k, first_batch[0]["small_emb_inputs"][k].shape, + # ) + # for k in first_batch[0]["large_emb_inputs"]: + # logger.info( + # "Large embedding inputs:%s %s", + # k, first_batch[0]["large_emb_inputs"][k], + # ) + # break # logger.debug("Inspecting one batch of data...") # for first_batch in train_generator: From dbd896ef00530cc4f31fc9527671b3dbc6df9d17 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:29:24 +0530 Subject: [PATCH 179/279] Debug --- examples/ml_perf/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 27b09a7c..d90f3d55 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -64,10 +64,12 @@ def main( logger.debug("Small Embedding Features: %s", small_emb_features) feature_configs = {} + print("--->", large_emb_features) for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] vocabulary_size = large_emb_feature["vocabulary_size"] feature_list_length = large_emb_feature["feature_list_length"] + print("bruh:", feature_name, vocabulary_size, feature_list_length) table_config = keras_rs.layers.TableConfig( name=f"{feature_name}_table", From c04d6f0aee0da0574b5264d7d2d02ba118df42b5 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:34:09 +0530 Subject: [PATCH 180/279] Debug --- examples/ml_perf/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d90f3d55..392a1195 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -103,6 +103,7 @@ def main( model_cfg.embedding_dim, ), ) + print("bruh:", feature_configs) # === Instantiate model === # We instantiate the model first, because we need to preprocess large From 6af186d32d2a758a164da49b8c13b2aa5477786d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:38:22 +0530 Subject: [PATCH 181/279] Debug --- examples/ml_perf/main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 392a1195..13e17238 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -15,6 +15,9 @@ # Set random seed. SEED = 1337 +import jax +jax.config.update("jax_debug_nans", True) + logger = logging.getLogger(__name__) keras.utils.set_random_seed(SEED) From 1a59ca6e110766f48c67bf5cee4be606d363e539 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 13:54:04 +0530 Subject: [PATCH 182/279] Debug --- examples/ml_perf/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 81ef05e4..275850ad 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -208,6 +208,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Predictions outputs = self.top_mlp(x) + jax.debug.print("--------> {}", outputs) return outputs def _get_mlp_layers( From 23e08b3efbe6057c17ccba22fa5e26cf7ae73643 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 14:04:11 +0530 Subject: [PATCH 183/279] Debug --- examples/ml_perf/main.py | 4 ++-- examples/ml_perf/model.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 13e17238..ce1a884b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -15,8 +15,8 @@ # Set random seed. SEED = 1337 -import jax -jax.config.update("jax_debug_nans", True) +# import jax +# jax.config.update("jax_debug_nans", True) logger = logging.getLogger(__name__) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 275850ad..777519e3 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -5,6 +5,7 @@ from keras import ops import keras_rs +import jax Tensor: TypeAlias = Any From ab75ffbc234b606ea41c4f4a3ec20f30b9ef4a42 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 14:08:08 +0530 Subject: [PATCH 184/279] Debug --- examples/ml_perf/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 777519e3..7704a0d3 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -185,8 +185,9 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Embed features. dense_output = self.bottom_mlp(dense_input) - # jax.debug.print("dense_ouput {}", dense_output.shape) + jax.debug.print("dense_output {}", dense_output.shape) large_embeddings = self.embedding_layer(large_emb_inputs) + jax.debug.print("large_embeddings {}", large_embeddings) small_embeddings = None if self.small_emb_features: small_embeddings = [] @@ -200,6 +201,8 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: small_embeddings = ops.concatenate(small_embeddings, axis=-1) + jax.debug.print("small_emebddings {}", small_embeddings) + # Interaction to_concatenate = [dense_output, *large_embeddings.values()] if small_embeddings is not None: From 8b26c090ed19e403b962b66a609b46d52344d6eb Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 14:11:09 +0530 Subject: [PATCH 185/279] Debug --- examples/ml_perf/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 7704a0d3..a99a2369 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -185,7 +185,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Embed features. dense_output = self.bottom_mlp(dense_input) - jax.debug.print("dense_output {}", dense_output.shape) + jax.debug.print("dense_output {}", dense_output) large_embeddings = self.embedding_layer(large_emb_inputs) jax.debug.print("large_embeddings {}", large_embeddings) small_embeddings = None From 41cae788247b9760648acdff3f401d0a16f19be3 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 14:12:30 +0530 Subject: [PATCH 186/279] Debug --- examples/ml_perf/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index a99a2369..3952c6dc 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -185,9 +185,9 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Embed features. dense_output = self.bottom_mlp(dense_input) - jax.debug.print("dense_output {}", dense_output) + jax.debug.print("dense_output {}", jnp.any(jnp.isnan(dense_output))) large_embeddings = self.embedding_layer(large_emb_inputs) - jax.debug.print("large_embeddings {}", large_embeddings) + jax.debug.print("large_embeddings {}", jnp.any(jnp.isnan(large_embeddings))) small_embeddings = None if self.small_emb_features: small_embeddings = [] @@ -201,7 +201,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: small_embeddings = ops.concatenate(small_embeddings, axis=-1) - jax.debug.print("small_emebddings {}", small_embeddings) + jax.debug.print("small_embeddings {}", jnp.any(jnp.isnan(small_embeddings))) # Interaction to_concatenate = [dense_output, *large_embeddings.values()] @@ -212,7 +212,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Predictions outputs = self.top_mlp(x) - jax.debug.print("--------> {}", outputs) + jax.debug.print("outputs --------> {}", jnp.any(jnp.isnan(outputs))) return outputs def _get_mlp_layers( From 04c93b0ed161b467f46291dbe2b09761430ebb02 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 14:35:52 +0530 Subject: [PATCH 187/279] Debug --- examples/ml_perf/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 3952c6dc..be0abbe1 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -187,7 +187,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: dense_output = self.bottom_mlp(dense_input) jax.debug.print("dense_output {}", jnp.any(jnp.isnan(dense_output))) large_embeddings = self.embedding_layer(large_emb_inputs) - jax.debug.print("large_embeddings {}", jnp.any(jnp.isnan(large_embeddings))) + jax.debug.print("large_embeddings {}", [jnp.any(jnp.isnan(large_emb)) for large_emb in large_embeddings.values()]) small_embeddings = None if self.small_emb_features: small_embeddings = [] From dc3251e9ea129d20725c74e132dc06c1978b2f07 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 14:37:57 +0530 Subject: [PATCH 188/279] Debug --- examples/ml_perf/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index be0abbe1..afffb089 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -6,6 +6,7 @@ import keras_rs import jax +import jax.numpy as jnp Tensor: TypeAlias = Any From d199d4f78e0bb67685d2d022e1463db2d97d051a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 14:53:00 +0530 Subject: [PATCH 189/279] Debug --- examples/ml_perf/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index afffb089..af596bd7 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -139,7 +139,7 @@ def __init__( ) for i, small_emb_feature in enumerate(small_emb_features) ] - logging.debug( + logging.info( "Initialised small embedding layers: %s", self.small_embedding_layers, ) From a920936c47b44272fa129c07ce3a026f37348185 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 14:57:42 +0530 Subject: [PATCH 190/279] Debug --- examples/ml_perf/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index af596bd7..70922fc6 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -128,6 +128,7 @@ def __init__( # Embedding layers for small embedding tables self.small_embedding_layers = None if small_emb_features: + print(f"{small_emb_features=}") self.small_embedding_layers = [ keras.layers.Embedding( input_dim=small_emb_feature["vocabulary_size"], From 369c4d2f1583192779633875c5af50df2721d1f0 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:06:22 +0530 Subject: [PATCH 191/279] Debug --- examples/ml_perf/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 70922fc6..bc9dd09a 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -198,7 +198,9 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: small_emb_inputs.values(), self.small_embedding_layers ): embedding = embedding_layer(small_emb_input) + jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) embedding = ops.sum(embedding, axis=-2) + jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) small_embeddings.append(embedding) small_embeddings = ops.concatenate(small_embeddings, axis=-1) From 6599a1023013803c9bbc271e179080b8593db52b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:10:03 +0530 Subject: [PATCH 192/279] Debug --- examples/ml_perf/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index bc9dd09a..53eba5ed 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -197,6 +197,8 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: for small_emb_input, embedding_layer in zip( small_emb_inputs.values(), self.small_embedding_layers ): + jax.debug.print("embedding layer: {}", embedding_layer) + jax.debug.print("small_embeddings input {}", jnp.any(jnp.isnan(small_emb_input))) embedding = embedding_layer(small_emb_input) jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) embedding = ops.sum(embedding, axis=-2) From 2c7c4e0b3fa1e9f239c04c1a3651bd01329a31ee Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:13:03 +0530 Subject: [PATCH 193/279] Debug --- examples/ml_perf/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 53eba5ed..e4457d08 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -197,7 +197,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: for small_emb_input, embedding_layer in zip( small_emb_inputs.values(), self.small_embedding_layers ): - jax.debug.print("embedding layer: {}", embedding_layer) + # jax.debug.print("embedding layer: {}", embedding_layer) jax.debug.print("small_embeddings input {}", jnp.any(jnp.isnan(small_emb_input))) embedding = embedding_layer(small_emb_input) jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) From 9e094163300a2cd7b1d089183f43b7725ee7339a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:22:26 +0530 Subject: [PATCH 194/279] Debug --- examples/ml_perf/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index e4457d08..871577df 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -200,6 +200,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # jax.debug.print("embedding layer: {}", embedding_layer) jax.debug.print("small_embeddings input {}", jnp.any(jnp.isnan(small_emb_input))) embedding = embedding_layer(small_emb_input) + jax.debug.print("small_embeddings input max {}", jnp.max(small_emb_input)) jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) embedding = ops.sum(embedding, axis=-2) jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) From a626141690ecb2a6fc0f4bc3ee87fda13446ed8c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:31:52 +0530 Subject: [PATCH 195/279] Debug --- examples/ml_perf/model.py | 56 ++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 871577df..54d90699 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -128,18 +128,21 @@ def __init__( # Embedding layers for small embedding tables self.small_embedding_layers = None if small_emb_features: - print(f"{small_emb_features=}") - self.small_embedding_layers = [ - keras.layers.Embedding( - input_dim=small_emb_feature["vocabulary_size"], - output_dim=embedding_dim, - embeddings_initializer=keras.initializers.LecunNormal( - seed=self.seed, - ), - name=f"small_embedding_layer_{i}", + self.small_embedding_layers = {} + for small_emb_feature in small_emb_features: + name = small_emb_feature["name"] + new_name = small_emb_feature["new_name"] + vocabulary_size = small_emb_feature["vocabulary_size"] + self.small_embedding_layers[new_name] = ( + keras.layers.Embedding( + input_dim=vocabulary_size, + output_dim=embedding_dim, + embeddings_initializer=keras.initializers.LecunNormal( + seed=self.seed, + ), + name=f"small_embedding_layer_{new_name}", + ) ) - for i, small_emb_feature in enumerate(small_emb_features) - ] logging.info( "Initialised small embedding layers: %s", self.small_embedding_layers, @@ -194,21 +197,32 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: if self.small_emb_features: small_embeddings = [] small_emb_inputs = inputs["small_emb_inputs"] - for small_emb_input, embedding_layer in zip( - small_emb_inputs.values(), self.small_embedding_layers - ): - # jax.debug.print("embedding layer: {}", embedding_layer) - jax.debug.print("small_embeddings input {}", jnp.any(jnp.isnan(small_emb_input))) + for small_emb_feature in small_emb_inputs.keys(): + small_emb_input = small_emb_inputs[small_emb_feature] + embedding_layer = self.small_embedding_layers[small_emb_feature] + embedding = embedding_layer(small_emb_input) - jax.debug.print("small_embeddings input max {}", jnp.max(small_emb_input)) - jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) embedding = ops.sum(embedding, axis=-2) - jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) - small_embeddings.append(embedding) + small_embeddings.append(embedding) + small_embeddings = ops.concatenate(small_embeddings, axis=-1) - jax.debug.print("small_embeddings {}", jnp.any(jnp.isnan(small_embeddings))) + # for small_emb_input, embedding_layer in zip( + # small_emb_inputs.values(), self.small_embedding_layers + # ): + # # jax.debug.print("embedding layer: {}", embedding_layer) + # jax.debug.print("small_embeddings input {}", jnp.any(jnp.isnan(small_emb_input))) + # embedding = embedding_layer(small_emb_input) + # jax.debug.print("small_embeddings input max {}", jnp.max(small_emb_input)) + # jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) + # embedding = ops.sum(embedding, axis=-2) + # jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) + # small_embeddings.append(embedding) + + # small_embeddings = ops.concatenate(small_embeddings, axis=-1) + + # jax.debug.print("small_embeddings {}", jnp.any(jnp.isnan(small_embeddings))) # Interaction to_concatenate = [dense_output, *large_embeddings.values()] From 89aa78f55d3af974ab15a68a1f6e32e53b8d0e0b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:35:42 +0530 Subject: [PATCH 196/279] Debug --- examples/ml_perf/model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 54d90699..6e736cff 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -199,14 +199,14 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: small_emb_inputs = inputs["small_emb_inputs"] for small_emb_feature in small_emb_inputs.keys(): small_emb_input = small_emb_inputs[small_emb_feature] - embedding_layer = self.small_embedding_layers[small_emb_feature] + # embedding_layer = self.small_embedding_layers[small_emb_feature] - embedding = embedding_layer(small_emb_input) - embedding = ops.sum(embedding, axis=-2) + # embedding = embedding_layer(small_emb_input) + # embedding = ops.sum(embedding, axis=-2) - small_embeddings.append(embedding) + # small_embeddings.append(embedding) - small_embeddings = ops.concatenate(small_embeddings, axis=-1) + # small_embeddings = ops.concatenate(small_embeddings, axis=-1) # for small_emb_input, embedding_layer in zip( # small_emb_inputs.values(), self.small_embedding_layers From d06329ae45971096a7d4ddf22ff36883716f5b98 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:38:47 +0530 Subject: [PATCH 197/279] Debug --- examples/ml_perf/model.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 6e736cff..1cfd27ca 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -133,7 +133,7 @@ def __init__( name = small_emb_feature["name"] new_name = small_emb_feature["new_name"] vocabulary_size = small_emb_feature["vocabulary_size"] - self.small_embedding_layers[new_name] = ( + self.small_embedding_layers[f"{new_name}_id"] = ( keras.layers.Embedding( input_dim=vocabulary_size, output_dim=embedding_dim, @@ -199,30 +199,14 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: small_emb_inputs = inputs["small_emb_inputs"] for small_emb_feature in small_emb_inputs.keys(): small_emb_input = small_emb_inputs[small_emb_feature] - # embedding_layer = self.small_embedding_layers[small_emb_feature] + embedding_layer = self.small_embedding_layers[small_emb_feature] - # embedding = embedding_layer(small_emb_input) - # embedding = ops.sum(embedding, axis=-2) + embedding = embedding_layer(small_emb_input) + embedding = ops.sum(embedding, axis=-2) - # small_embeddings.append(embedding) + small_embeddings.append(embedding) - # small_embeddings = ops.concatenate(small_embeddings, axis=-1) - - # for small_emb_input, embedding_layer in zip( - # small_emb_inputs.values(), self.small_embedding_layers - # ): - # # jax.debug.print("embedding layer: {}", embedding_layer) - # jax.debug.print("small_embeddings input {}", jnp.any(jnp.isnan(small_emb_input))) - # embedding = embedding_layer(small_emb_input) - # jax.debug.print("small_embeddings input max {}", jnp.max(small_emb_input)) - # jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) - # embedding = ops.sum(embedding, axis=-2) - # jax.debug.print("small_embeddings embedding {}", jnp.any(jnp.isnan(embedding))) - # small_embeddings.append(embedding) - - # small_embeddings = ops.concatenate(small_embeddings, axis=-1) - - # jax.debug.print("small_embeddings {}", jnp.any(jnp.isnan(small_embeddings))) + small_embeddings = ops.concatenate(small_embeddings, axis=-1) # Interaction to_concatenate = [dense_output, *large_embeddings.values()] From b6b841323932a6dbf6b020304cfb7707d1321370 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:43:47 +0530 Subject: [PATCH 198/279] Debug --- examples/ml_perf/main.py | 53 +++++++++------------------------------ examples/ml_perf/model.py | 8 +----- 2 files changed, 13 insertions(+), 48 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index ce1a884b..2d9de218 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -67,12 +67,10 @@ def main( logger.debug("Small Embedding Features: %s", small_emb_features) feature_configs = {} - print("--->", large_emb_features) for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] vocabulary_size = large_emb_feature["vocabulary_size"] feature_list_length = large_emb_feature["feature_list_length"] - print("bruh:", feature_name, vocabulary_size, feature_list_length) table_config = keras_rs.layers.TableConfig( name=f"{feature_name}_table", @@ -106,7 +104,6 @@ def main( model_cfg.embedding_dim, ), ) - print("bruh:", feature_configs) # === Instantiate model === # We instantiate the model first, because we need to preprocess large @@ -200,19 +197,6 @@ def main( eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False - for first_batch in train_ds.take(1): - logger.info("Dense inputs:%s", first_batch[0]["dense_input"].shape) - for k in first_batch[0]["small_emb_inputs"]: - logger.info( - "Small embedding inputs:%s %s", - k, first_batch[0]["small_emb_inputs"][k].shape, - ) - for k in first_batch[0]["large_emb_inputs"]: - logger.info( - "Large embedding inputs:%s %s", - k, first_batch[0]["large_emb_inputs"][k].shape, - ) - break def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses @@ -235,31 +219,18 @@ def generator(dataset, training=False): train_generator = generator(train_ds, training=True) if do_eval: eval_generator = generator(eval_ds, training=False) - # for first_batch in train_generator: - # logger.info("Dense inputs:%s", first_batch[0]["dense_input"].shape) - # for k in first_batch[0]["small_emb_inputs"]: - # logger.info( - # "Small embedding inputs:%s %s", - # k, first_batch[0]["small_emb_inputs"][k].shape, - # ) - # for k in first_batch[0]["large_emb_inputs"]: - # logger.info( - # "Large embedding inputs:%s %s", - # k, first_batch[0]["large_emb_inputs"][k], - # ) - # break - - # logger.debug("Inspecting one batch of data...") - # for first_batch in train_generator: - # logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) - # logger.debug( - # "Small embedding inputs:%s", - # first_batch[0]["small_emb_inputs"]["25_id"], - # ) - # logger.debug( - # "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] - # ) - # break + + logger.debug("Inspecting one batch of data...") + for first_batch in train_generator: + logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) + logger.debug( + "Small embedding inputs:%s", + first_batch[0]["small_emb_inputs"]["25_id"], + ) + logger.debug( + "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] + ) + break logger.info("Successfully preprocessed one batch of data") # === Training === diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 1cfd27ca..6620de3a 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -143,7 +143,7 @@ def __init__( name=f"small_embedding_layer_{new_name}", ) ) - logging.info( + logging.debug( "Initialised small embedding layers: %s", self.small_embedding_layers, ) @@ -190,9 +190,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Embed features. dense_output = self.bottom_mlp(dense_input) - jax.debug.print("dense_output {}", jnp.any(jnp.isnan(dense_output))) large_embeddings = self.embedding_layer(large_emb_inputs) - jax.debug.print("large_embeddings {}", [jnp.any(jnp.isnan(large_emb)) for large_emb in large_embeddings.values()]) small_embeddings = None if self.small_emb_features: small_embeddings = [] @@ -200,12 +198,9 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: for small_emb_feature in small_emb_inputs.keys(): small_emb_input = small_emb_inputs[small_emb_feature] embedding_layer = self.small_embedding_layers[small_emb_feature] - embedding = embedding_layer(small_emb_input) embedding = ops.sum(embedding, axis=-2) - small_embeddings.append(embedding) - small_embeddings = ops.concatenate(small_embeddings, axis=-1) # Interaction @@ -217,7 +212,6 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Predictions outputs = self.top_mlp(x) - jax.debug.print("outputs --------> {}", jnp.any(jnp.isnan(outputs))) return outputs def _get_mlp_layers( From 7b10edf7a3ce488fcc5819938f8f8609fcaa5b05 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:44:50 +0530 Subject: [PATCH 199/279] Fix small emb order --- examples/ml_perf/main.py | 5 ----- examples/ml_perf/model.py | 2 -- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 1 - 3 files changed, 8 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2d9de218..5d94ddcc 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -15,9 +15,6 @@ # Set random seed. SEED = 1337 -# import jax -# jax.config.update("jax_debug_nans", True) - logger = logging.getLogger(__name__) keras.utils.set_random_seed(SEED) @@ -140,7 +137,6 @@ def main( # === Load dataset === logger.info("Loading dataset...") - # Keras does not have a straightforward way to log at a step-level instead # of epoch-level. So, we do a workaround here. # if ds_cfg.val_file_pattern: @@ -197,7 +193,6 @@ def main( eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False - def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses large embedding features. diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 6620de3a..9861c472 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -5,8 +5,6 @@ from keras import ops import keras_rs -import jax -import jax.numpy as jnp Tensor: TypeAlias = Any diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 46d2c0fb..b13baad2 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -9,7 +9,6 @@ import numpy as np from jax import numpy as jnp from jax.experimental import layout as jax_layout -from jax.experimental import multihost_utils from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec from jax_tpu_embedding.sparsecore.lib.nn import ( From f7b8407848e9a61c383150e68d7086fb4ec2e3df Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:55:46 +0530 Subject: [PATCH 200/279] Fix small emb order --- examples/ml_perf/main.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 5d94ddcc..200d91ac 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -215,17 +215,17 @@ def generator(dataset, training=False): if do_eval: eval_generator = generator(eval_ds, training=False) - logger.debug("Inspecting one batch of data...") - for first_batch in train_generator: - logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) - logger.debug( - "Small embedding inputs:%s", - first_batch[0]["small_emb_inputs"]["25_id"], - ) - logger.debug( - "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] - ) - break + # logger.debug("Inspecting one batch of data...") + # for first_batch in train_generator: + # logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) + # logger.debug( + # "Small embedding inputs:%s", + # first_batch[0]["small_emb_inputs"]["25_id"], + # ) + # logger.debug( + # "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] + # ) + # break logger.info("Successfully preprocessed one batch of data") # === Training === From 2879d4bb7ba4bc08d188505d779793e7f02b9595 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 15:59:47 +0530 Subject: [PATCH 201/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 274 +++++++++--------- 1 file changed, 137 insertions(+), 137 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 0ec5efd6..24032a9e 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -24,158 +24,158 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "feature_list_length": 3, - "new_name": "cat_0", + "new_name": "0", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "feature_list_length": 2, - "new_name": "cat_1", + "new_name": "1", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "cat_2", + "new_name": "2", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "cat_17", - }, - # { - # "name": "categorical-feature-18", - # "vocabulary_size": 20265, - # "feature_list_length": 6, - # "new_name": "cat_18", - # }, - # { - # "name": "categorical-feature-19", - # "vocabulary_size": 3, - # "feature_list_length": 1, - # "new_name": "cat_19", - # }, - # { - # "name": "categorical-feature-20", - # "vocabulary_size": 7122, - # "feature_list_length": 1, - # "new_name": "cat_20", - # }, - # { - # "name": "categorical-feature-21", - # "vocabulary_size": 1543, - # "feature_list_length": 1, - # "new_name": "cat_21", - # }, - # { - # "name": "categorical-feature-22", - # "vocabulary_size": 63, - # "feature_list_length": 1, - # "new_name": "cat_22", - # }, - # { - # "name": "categorical-feature-23", - # "vocabulary_size": 40000000, - # "feature_list_length": 7, - # "new_name": "cat_23", - # }, - # { - # "name": "categorical-feature-24", - # "vocabulary_size": 3067956, - # "feature_list_length": 3, - # "new_name": "cat_24", - # }, - # { - # "name": "categorical-feature-25", - # "vocabulary_size": 405282, - # "feature_list_length": 8, - # "new_name": "cat_25", - # }, - # { - # "name": "categorical-feature-26", - # "vocabulary_size": 10, - # "feature_list_length": 1, - # "new_name": "cat_26", - # }, - # { - # "name": "categorical-feature-27", - # "vocabulary_size": 2209, - # "feature_list_length": 6, - # "new_name": "cat_27", - # }, - # { - # "name": "categorical-feature-28", - # "vocabulary_size": 11938, - # "feature_list_length": 9, - # "new_name": "cat_28", - # }, - # { - # "name": "categorical-feature-29", - # "vocabulary_size": 155, - # "feature_list_length": 5, - # "new_name": "cat_29", - # }, - # { - # "name": "categorical-feature-30", - # "vocabulary_size": 4, - # "feature_list_length": 1, - # "new_name": "cat_30", - # }, - # { - # "name": "categorical-feature-31", - # "vocabulary_size": 976, - # "feature_list_length": 1, - # "new_name": "cat_31", - # }, - # { - # "name": "categorical-feature-32", - # "vocabulary_size": 14, - # "feature_list_length": 1, - # "new_name": "cat_32", - # }, - # { - # "name": "categorical-feature-33", - # "vocabulary_size": 40000000, - # "feature_list_length": 12, - # "new_name": "cat_33", - # }, - # { - # "name": "categorical-feature-34", - # "vocabulary_size": 40000000, - # "feature_list_length": 100, - # "new_name": "cat_34", - # }, - # { - # "name": "categorical-feature-35", - # "vocabulary_size": 40000000, - # "feature_list_length": 27, - # "new_name": "cat_35", - # }, - # { - # "name": "categorical-feature-36", - # "vocabulary_size": 590152, - # "feature_list_length": 10, - # "new_name": "cat_36", - # }, - # { - # "name": "categorical-feature-37", - # "vocabulary_size": 12973, - # "feature_list_length": 3, - # "new_name": "cat_37", - # }, - # { - # "name": "categorical-feature-38", - # "vocabulary_size": 108, - # "feature_list_length": 1, - # "new_name": "cat_38", - # }, - # { - # "name": "categorical-feature-39", - # "vocabulary_size": 36, - # "feature_list_length": 1, - # "new_name": "cat_39", - # }, + "new_name": "3", + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "feature_list_length": 6, + "new_name": "4", + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "feature_list_length": 1, + "new_name": "5", + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "feature_list_length": 1, + "new_name": "6", + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "feature_list_length": 1, + "new_name": "7", + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "feature_list_length": 1, + "new_name": "8", + }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "feature_list_length": 7, + "new_name": "9", + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "feature_list_length": 3, + "new_name": "10", + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "feature_list_length": 8, + "new_name": "11", + }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "feature_list_length": 1, + "new_name": "12", + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "feature_list_length": 6, + "new_name": "13", + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "feature_list_length": 9, + "new_name": "14", + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "feature_list_length": 5, + "new_name": "15", + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "feature_list_length": 1, + "new_name": "16", + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "feature_list_length": 1, + "new_name": "17", + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "feature_list_length": 1, + "new_name": "18", + }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "feature_list_length": 12, + "new_name": "19", + }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "feature_list_length": 100, + "new_name": "20", + }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "feature_list_length": 27, + "new_name": "21", + }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "feature_list_length": 10, + "new_name": "22", + }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "feature_list_length": 3, + "new_name": "23", + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "feature_list_length": 1, + "new_name": "24", + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "feature_list_length": 1, + "new_name": "25", + }, ] # === Model === From 6c3f8ced5a900d32fb4b3542ecb01c25ea3d4dca Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 16:02:36 +0530 Subject: [PATCH 202/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 200d91ac..3db87c35 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -70,7 +70,7 @@ def main( feature_list_length = large_emb_feature["feature_list_length"] table_config = keras_rs.layers.TableConfig( - name=f"{feature_name}_table", + name=feature_name, vocabulary_size=vocabulary_size, embedding_dim=model_cfg.embedding_dim, # TODO(abheesht): Verify. From 8a5cfcabf744f5654596e6de113fd029be766489 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 16:05:53 +0530 Subject: [PATCH 203/279] Debug --- examples/ml_perf/model.py | 2 ++ keras_rs/src/layers/embedding/jax/distributed_embedding.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 9861c472..e956a000 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -196,9 +196,11 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: for small_emb_feature in small_emb_inputs.keys(): small_emb_input = small_emb_inputs[small_emb_feature] embedding_layer = self.small_embedding_layers[small_emb_feature] + embedding = embedding_layer(small_emb_input) embedding = ops.sum(embedding, axis=-2) small_embeddings.append(embedding) + small_embeddings = ops.concatenate(small_embeddings, axis=-1) # Interaction diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index b13baad2..ae90f816 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -412,6 +412,8 @@ def sparsecore_build( feature_specs = config_conversion.keras_to_jte_feature_configs( self._sc_feature_configs ) + print(f"--->{self._sc_feature_configs=}") + print(f"--->{feature_specs=}") # Distribution for sparsecore operations. sparsecore_distribution, sparsecore_layout = ( From b8b493547865a6021b410da0c18eec5b1d7a2f1f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 23:18:09 +0530 Subject: [PATCH 204/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 24032a9e..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 +training_config.global_batch_size = 16896 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From b349d2616853f4c4ea40f31dfa8f0601b1e74dd5 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 23:27:20 +0530 Subject: [PATCH 205/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..46e1ade0 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -24,157 +24,157 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "feature_list_length": 3, - "new_name": "0", + "new_name": "cat_14", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "feature_list_length": 2, - "new_name": "1", + "new_name": "cat_15", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "2", + "new_name": "cat_16", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "3", + "new_name": "cat_17", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "feature_list_length": 6, - "new_name": "4", + "new_name": "cat_18", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "feature_list_length": 1, - "new_name": "5", + "new_name": "cat_19", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "feature_list_length": 1, - "new_name": "6", + "new_name": "cat_20", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "feature_list_length": 1, - "new_name": "7", + "new_name": "cat_21", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "feature_list_length": 1, - "new_name": "8", + "new_name": "cat_22", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "feature_list_length": 7, - "new_name": "9", + "new_name": "cat_23", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "feature_list_length": 3, - "new_name": "10", + "new_name": "cat_24", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "feature_list_length": 8, - "new_name": "11", + "new_name": "cat_25", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "feature_list_length": 1, - "new_name": "12", + "new_name": "cat_26", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "feature_list_length": 6, - "new_name": "13", + "new_name": "cat_27", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "feature_list_length": 9, - "new_name": "14", + "new_name": "cat_28", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "feature_list_length": 5, - "new_name": "15", + "new_name": "cat_29", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "feature_list_length": 1, - "new_name": "16", + "new_name": "cat_30", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "feature_list_length": 1, - "new_name": "17", + "new_name": "cat_31", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "feature_list_length": 1, - "new_name": "18", + "new_name": "cat_32", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "feature_list_length": 12, - "new_name": "19", + "new_name": "cat_33", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "feature_list_length": 100, - "new_name": "20", + "new_name": "cat_34", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "feature_list_length": 27, - "new_name": "21", + "new_name": "cat_35", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "feature_list_length": 10, - "new_name": "22", + "new_name": "cat_36", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "feature_list_length": 3, - "new_name": "23", + "new_name": "cat_37", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "feature_list_length": 1, - "new_name": "24", + "new_name": "cat_38", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "feature_list_length": 1, - "new_name": "25", + "new_name": "cat_39", }, ] From c6b9ce0d3682d2bd02e1539d1dd5f9df602872f1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 23:49:23 +0530 Subject: [PATCH 206/279] Debug --- examples/ml_perf/main.py | 110 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 106 insertions(+), 4 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 3db87c35..205bd9b2 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -20,6 +20,89 @@ keras.utils.set_random_seed(SEED) keras.config.disable_traceback_filtering() +from typing import Any, Callable, List, Mapping +from jax_tpu_embedding.sparsecore.lib.flax import embed +from jax_tpu_embedding.sparsecore.lib.flax import embed_optimizer +from jax_tpu_embedding.sparsecore.lib.nn import embedding +from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec + +import jax +import jax.numpy as jnp + +_EMBEDDING_THRESHOLD = 21_000 +_LEARNING_RATE = 0.034 +_MAX_IDS_PER_PARTITION = 8192 +_MAX_UNIQUE_IDS_PER_PARTITION = 4096 +_EMBEDDING_SIZE = 128 +_BATCH_SIZE = 16896 + + +VOCAB_SIZES = [ + 40000000, 39060, 17295, 7424, 20265, 3, 7122, 1543, 63, 40000000, + 3067956, 405282, 10, 2209, 11938, 155, 4, 976, 14, 40000000, 40000000, + 40000000, 590152, 12973, 108, 36 +] + +from typing import Any, Callable, List, Mapping +from jax_tpu_embedding.sparsecore.lib.flax import embed +from jax_tpu_embedding.sparsecore.lib.flax import embed_optimizer +from jax_tpu_embedding.sparsecore.lib.nn import embedding +from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec + +import jax +import jax.numpy as jnp + + +def uniform_init(bound: float): + def init(key, shape, dtype=jnp.float_): + return jax.random.uniform( + key, + shape=shape, + dtype=dtype, + minval=-bound, + maxval=bound + ) + return init + + +def create_feature_specs( + vocab_sizes: List[int], +) -> tuple[ + Mapping[str, embedding_spec.TableSpec], + Mapping[str, embedding_spec.FeatureSpec], +]: + """Creates the feature specs for the DLRM model.""" + table_specs = {} + feature_specs = {} + for i, vocab_size in enumerate(vocab_sizes): + if vocab_size <= _EMBEDDING_THRESHOLD: + continue + + table_name = f"{i}" + feature_name = f"{i}" + bound = jnp.sqrt(1.0 / vocab_size) + table_spec = embedding_spec.TableSpec( + vocabulary_size=vocab_size, + embedding_dim=_EMBEDDING_SIZE, + initializer=uniform_init(bound), + optimizer=embedding_spec.AdagradOptimizerSpec( + learning_rate=_LEARNING_RATE + ), + combiner="sum", + name=table_name, + max_ids_per_partition=_MAX_IDS_PER_PARTITION, + max_unique_ids_per_partition=_MAX_UNIQUE_IDS_PER_PARTITION, + ) + feature_spec = embedding_spec.FeatureSpec( + table_spec=table_spec, + input_shape=(_BATCH_SIZE, 1), + output_shape=(_BATCH_SIZE, _EMBEDDING_SIZE), + name=feature_name, + ) + feature_specs[feature_name] = feature_spec + table_specs[table_name] = table_spec + return table_specs, feature_specs + class MetricLogger(keras.callbacks.Callback): def on_train_batch_end(self, batch, logs=None): @@ -271,10 +354,29 @@ def generator(dataset, training=False): model_cfg = config["model"] training_cfg = config["training"] - main( - ds_cfg=ds_cfg, - model_cfg=model_cfg, - training_cfg=training_cfg, + _, feature_spec = create_feature_specs( + vocab_sizes=vocab_sizes + ) + print("feature_spec:", feature_spec) + + # main( + # ds_cfg=ds_cfg, + # model_cfg=model_cfg, + # training_cfg=training_cfg, + # ) + + def _get_max_ids_per_partition(name: str, batch_size: int) -> int: + return _MAX_IDS_PER_PARTITION + + def _get_max_unique_ids_per_partition(name: str, batch_size: int) -> int: + return _MAX_UNIQUE_IDS_PER_PARTITION + + embedding.auto_stack_tables( + feature_spec, + global_device_count=jax.device_count(), + stack_to_max_ids_per_partition=_get_max_ids_per_partition, + stack_to_max_unique_ids_per_partition=_get_max_unique_ids_per_partition, + num_sc_per_device=2, ) logger.info("Train script finished") From 3bb4335ffccd64257299fac70e5826348f830f39 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 23:51:26 +0530 Subject: [PATCH 207/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 205bd9b2..8b7e881b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -355,7 +355,7 @@ def generator(dataset, training=False): training_cfg = config["training"] _, feature_spec = create_feature_specs( - vocab_sizes=vocab_sizes + vocab_sizes=VOCAB_SIZES ) print("feature_spec:", feature_spec) From d43447b00623978f452b4351dff3392ddb07ace3 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 27 Oct 2025 23:58:09 +0530 Subject: [PATCH 208/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 52 ++++----- examples/ml_perf/main.py | 110 +----------------- 2 files changed, 30 insertions(+), 132 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 46e1ade0..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -24,157 +24,157 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "feature_list_length": 3, - "new_name": "cat_14", + "new_name": "0", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "feature_list_length": 2, - "new_name": "cat_15", + "new_name": "1", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, - "new_name": "cat_16", + "new_name": "2", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "feature_list_length": 2, - "new_name": "cat_17", + "new_name": "3", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "feature_list_length": 6, - "new_name": "cat_18", + "new_name": "4", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "feature_list_length": 1, - "new_name": "cat_19", + "new_name": "5", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "feature_list_length": 1, - "new_name": "cat_20", + "new_name": "6", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "feature_list_length": 1, - "new_name": "cat_21", + "new_name": "7", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "feature_list_length": 1, - "new_name": "cat_22", + "new_name": "8", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "feature_list_length": 7, - "new_name": "cat_23", + "new_name": "9", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "feature_list_length": 3, - "new_name": "cat_24", + "new_name": "10", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "feature_list_length": 8, - "new_name": "cat_25", + "new_name": "11", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "feature_list_length": 1, - "new_name": "cat_26", + "new_name": "12", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "feature_list_length": 6, - "new_name": "cat_27", + "new_name": "13", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "feature_list_length": 9, - "new_name": "cat_28", + "new_name": "14", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "feature_list_length": 5, - "new_name": "cat_29", + "new_name": "15", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "feature_list_length": 1, - "new_name": "cat_30", + "new_name": "16", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "feature_list_length": 1, - "new_name": "cat_31", + "new_name": "17", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "feature_list_length": 1, - "new_name": "cat_32", + "new_name": "18", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "feature_list_length": 12, - "new_name": "cat_33", + "new_name": "19", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "feature_list_length": 100, - "new_name": "cat_34", + "new_name": "20", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "feature_list_length": 27, - "new_name": "cat_35", + "new_name": "21", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "feature_list_length": 10, - "new_name": "cat_36", + "new_name": "22", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "feature_list_length": 3, - "new_name": "cat_37", + "new_name": "23", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "feature_list_length": 1, - "new_name": "cat_38", + "new_name": "24", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "feature_list_length": 1, - "new_name": "cat_39", + "new_name": "25", }, ] diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 8b7e881b..3db87c35 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -20,89 +20,6 @@ keras.utils.set_random_seed(SEED) keras.config.disable_traceback_filtering() -from typing import Any, Callable, List, Mapping -from jax_tpu_embedding.sparsecore.lib.flax import embed -from jax_tpu_embedding.sparsecore.lib.flax import embed_optimizer -from jax_tpu_embedding.sparsecore.lib.nn import embedding -from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec - -import jax -import jax.numpy as jnp - -_EMBEDDING_THRESHOLD = 21_000 -_LEARNING_RATE = 0.034 -_MAX_IDS_PER_PARTITION = 8192 -_MAX_UNIQUE_IDS_PER_PARTITION = 4096 -_EMBEDDING_SIZE = 128 -_BATCH_SIZE = 16896 - - -VOCAB_SIZES = [ - 40000000, 39060, 17295, 7424, 20265, 3, 7122, 1543, 63, 40000000, - 3067956, 405282, 10, 2209, 11938, 155, 4, 976, 14, 40000000, 40000000, - 40000000, 590152, 12973, 108, 36 -] - -from typing import Any, Callable, List, Mapping -from jax_tpu_embedding.sparsecore.lib.flax import embed -from jax_tpu_embedding.sparsecore.lib.flax import embed_optimizer -from jax_tpu_embedding.sparsecore.lib.nn import embedding -from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec - -import jax -import jax.numpy as jnp - - -def uniform_init(bound: float): - def init(key, shape, dtype=jnp.float_): - return jax.random.uniform( - key, - shape=shape, - dtype=dtype, - minval=-bound, - maxval=bound - ) - return init - - -def create_feature_specs( - vocab_sizes: List[int], -) -> tuple[ - Mapping[str, embedding_spec.TableSpec], - Mapping[str, embedding_spec.FeatureSpec], -]: - """Creates the feature specs for the DLRM model.""" - table_specs = {} - feature_specs = {} - for i, vocab_size in enumerate(vocab_sizes): - if vocab_size <= _EMBEDDING_THRESHOLD: - continue - - table_name = f"{i}" - feature_name = f"{i}" - bound = jnp.sqrt(1.0 / vocab_size) - table_spec = embedding_spec.TableSpec( - vocabulary_size=vocab_size, - embedding_dim=_EMBEDDING_SIZE, - initializer=uniform_init(bound), - optimizer=embedding_spec.AdagradOptimizerSpec( - learning_rate=_LEARNING_RATE - ), - combiner="sum", - name=table_name, - max_ids_per_partition=_MAX_IDS_PER_PARTITION, - max_unique_ids_per_partition=_MAX_UNIQUE_IDS_PER_PARTITION, - ) - feature_spec = embedding_spec.FeatureSpec( - table_spec=table_spec, - input_shape=(_BATCH_SIZE, 1), - output_shape=(_BATCH_SIZE, _EMBEDDING_SIZE), - name=feature_name, - ) - feature_specs[feature_name] = feature_spec - table_specs[table_name] = table_spec - return table_specs, feature_specs - class MetricLogger(keras.callbacks.Callback): def on_train_batch_end(self, batch, logs=None): @@ -354,29 +271,10 @@ def generator(dataset, training=False): model_cfg = config["model"] training_cfg = config["training"] - _, feature_spec = create_feature_specs( - vocab_sizes=VOCAB_SIZES - ) - print("feature_spec:", feature_spec) - - # main( - # ds_cfg=ds_cfg, - # model_cfg=model_cfg, - # training_cfg=training_cfg, - # ) - - def _get_max_ids_per_partition(name: str, batch_size: int) -> int: - return _MAX_IDS_PER_PARTITION - - def _get_max_unique_ids_per_partition(name: str, batch_size: int) -> int: - return _MAX_UNIQUE_IDS_PER_PARTITION - - embedding.auto_stack_tables( - feature_spec, - global_device_count=jax.device_count(), - stack_to_max_ids_per_partition=_get_max_ids_per_partition, - stack_to_max_unique_ids_per_partition=_get_max_unique_ids_per_partition, - num_sc_per_device=2, + main( + ds_cfg=ds_cfg, + model_cfg=model_cfg, + training_cfg=training_cfg, ) logger.info("Train script finished") From dd716a8ba029d24bd9d98f40940b7e7cf2aad42a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 00:03:15 +0530 Subject: [PATCH 209/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 3db87c35..93c19f3d 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -238,7 +238,7 @@ def generator(dataset, training=False): # callbacks=[MetricLogger()], # validation_steps=training_cfg.num_eval_steps, # validation_freq=1, - # verbose=0, + verbose=0, ) logger.info("Training finished") From a39c3f7094c414a3bc7e2a5e2947044de55836a6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 00:23:43 +0530 Subject: [PATCH 210/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..24032a9e 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index ae90f816..b13baad2 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -412,8 +412,6 @@ def sparsecore_build( feature_specs = config_conversion.keras_to_jte_feature_configs( self._sc_feature_configs ) - print(f"--->{self._sc_feature_configs=}") - print(f"--->{feature_specs=}") # Distribution for sparsecore operations. sparsecore_distribution, sparsecore_layout = ( From 8da410885d856d580907b29035cec2da83e4f166 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 01:13:58 +0530 Subject: [PATCH 211/279] Debug --- examples/ml_perf/dataloader.py | 3 ++- examples/ml_perf/main.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 49708cd3..e7b92d51 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -190,6 +190,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) + dataset = dataset.shard(num_shards=process_count, index=process_index) logger.info("List of input files: %s", [f for f in dataset]) dataset = tf.data.TFRecordDataset( @@ -213,7 +214,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.shuffle(shuffle_buffer, seed=SEED) dataset = dataset.batch( - self.batch_size, + self.batch_size // num_processes, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE, ) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 93c19f3d..22812d50 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -188,7 +188,7 @@ def main( # See note here: # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. if num_processes > 1: - train_ds = distribution.distribute_dataset(train_ds) + # train_ds = distribution.distribute_dataset(train_ds) if do_eval: eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False From 02a1e6513062a51911c7064b749aa85b4bd49ea6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 01:15:58 +0530 Subject: [PATCH 212/279] Debug --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index e7b92d51..c9fda693 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -190,7 +190,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - dataset = dataset.shard(num_shards=process_count, index=process_index) + dataset = dataset.shard(num_shards=num_processes, index=process_index) logger.info("List of input files: %s", [f for f in dataset]) dataset = tf.data.TFRecordDataset( From 443daf9004c069976a6a983bd8e1d7984d892c6c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 01:18:11 +0530 Subject: [PATCH 213/279] Debug --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index c9fda693..c66c3bb4 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -190,7 +190,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - dataset = dataset.shard(num_shards=num_processes, index=process_index) + dataset = dataset.shard(num_shards=num_processes, index=process_id) logger.info("List of input files: %s", [f for f in dataset]) dataset = tf.data.TFRecordDataset( From 02d1d2c2b07f68064d064eca2b8f45bb962b43e6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 01:21:30 +0530 Subject: [PATCH 214/279] Debug --- examples/ml_perf/dataloader.py | 4 ++-- examples/ml_perf/main.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index c66c3bb4..78b4a0a3 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -190,7 +190,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - dataset = dataset.shard(num_shards=num_processes, index=process_id) + # dataset = dataset.shard(num_shards=num_processes, index=process_id) logger.info("List of input files: %s", [f for f in dataset]) dataset = tf.data.TFRecordDataset( @@ -214,7 +214,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.shuffle(shuffle_buffer, seed=SEED) dataset = dataset.batch( - self.batch_size // num_processes, + self.batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE, ) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 22812d50..93c19f3d 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -188,7 +188,7 @@ def main( # See note here: # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. if num_processes > 1: - # train_ds = distribution.distribute_dataset(train_ds) + train_ds = distribution.distribute_dataset(train_ds) if do_eval: eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False From 819b9c846c12f05f9622626ad6be04ca049e9d86 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 01:24:56 +0530 Subject: [PATCH 215/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 24032a9e..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 +training_config.global_batch_size = 16896 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From d6a8b76a2b92c3c957dffdad6fc15da683235e31 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 01:28:11 +0530 Subject: [PATCH 216/279] Debug --- examples/ml_perf/dataloader.py | 4 ++-- examples/ml_perf/main.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 78b4a0a3..c66c3bb4 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -190,7 +190,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - # dataset = dataset.shard(num_shards=num_processes, index=process_id) + dataset = dataset.shard(num_shards=num_processes, index=process_id) logger.info("List of input files: %s", [f for f in dataset]) dataset = tf.data.TFRecordDataset( @@ -214,7 +214,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.shuffle(shuffle_buffer, seed=SEED) dataset = dataset.batch( - self.batch_size, + self.batch_size // num_processes, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE, ) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 93c19f3d..22812d50 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -188,7 +188,7 @@ def main( # See note here: # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. if num_processes > 1: - train_ds = distribution.distribute_dataset(train_ds) + # train_ds = distribution.distribute_dataset(train_ds) if do_eval: eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False From c8c1eccff9f84a895a27d2a8ac29dd346ecb7e62 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 07:33:33 +0530 Subject: [PATCH 217/279] Debug --- examples/ml_perf/main.py | 8 ++++---- examples/ml_perf/model.py | 27 +++++++++++++++------------ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 22812d50..a80b06c1 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -198,13 +198,13 @@ def generator(dataset, training=False): large embedding features. """ for features, labels in dataset: - preprocessed_large_embeddings = model.embedding_layer.preprocess( - features["large_emb_inputs"], training=training - ) + # preprocessed_large_embeddings = model.embedding_layer.preprocess( + # features["large_emb_inputs"], training=training + # ) x = { "dense_input": features["dense_input"], - "large_emb_inputs": preprocessed_large_embeddings, + # "large_emb_inputs": preprocessed_large_embeddings, "small_emb_inputs": features["small_emb_inputs"], } y = labels diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index e956a000..5f860825 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -113,16 +113,16 @@ def __init__( ) logging.debug("Initialised Bottom MLP: %s", self.bottom_mlp) # Distributed embeddings for large embedding tables - self.embedding_layer = keras_rs.layers.DistributedEmbedding( - feature_configs=large_emb_feature_configs, - table_stacking="auto", - auto_stack_kwargs=auto_stack_kwargs, - dtype=dtype, - name="embedding_layer", - ) - logging.debug( - "Initialised `DistributedEmbedding` layer: %s", self.embedding_layer - ) + # self.embedding_layer = keras_rs.layers.DistributedEmbedding( + # feature_configs=large_emb_feature_configs, + # table_stacking="auto", + # auto_stack_kwargs=auto_stack_kwargs, + # dtype=dtype, + # name="embedding_layer", + # ) + # logging.debug( + # "Initialised `DistributedEmbedding` layer: %s", self.embedding_layer + # ) # Embedding layers for small embedding tables self.small_embedding_layers = None if small_emb_features: @@ -188,7 +188,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Embed features. dense_output = self.bottom_mlp(dense_input) - large_embeddings = self.embedding_layer(large_emb_inputs) + # large_embeddings = self.embedding_layer(large_emb_inputs) small_embeddings = None if self.small_emb_features: small_embeddings = [] @@ -204,7 +204,10 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: small_embeddings = ops.concatenate(small_embeddings, axis=-1) # Interaction - to_concatenate = [dense_output, *large_embeddings.values()] + to_concatenate = [ + dense_output, + # *large_embeddings.values() + ] if small_embeddings is not None: to_concatenate += [small_embeddings] x = ops.concatenate(to_concatenate, axis=-1) From 60a0c93fc120e01ffaae93c083e02a88a275c881 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 08:03:57 +0530 Subject: [PATCH 218/279] Debug --- examples/ml_perf/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 5f860825..e00bf5d9 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -184,7 +184,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: """ # Inputs dense_input = inputs["dense_input"] - large_emb_inputs = inputs["large_emb_inputs"] + # large_emb_inputs = inputs["large_emb_inputs"] # Embed features. dense_output = self.bottom_mlp(dense_input) From c79a08ad7f0db01fc6a3097b9e76617235f86a17 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 08:07:43 +0530 Subject: [PATCH 219/279] Debug --- examples/ml_perf/main.py | 8 ++++---- examples/ml_perf/model.py | 26 +++++++++++++------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index a80b06c1..22812d50 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -198,13 +198,13 @@ def generator(dataset, training=False): large embedding features. """ for features, labels in dataset: - # preprocessed_large_embeddings = model.embedding_layer.preprocess( - # features["large_emb_inputs"], training=training - # ) + preprocessed_large_embeddings = model.embedding_layer.preprocess( + features["large_emb_inputs"], training=training + ) x = { "dense_input": features["dense_input"], - # "large_emb_inputs": preprocessed_large_embeddings, + "large_emb_inputs": preprocessed_large_embeddings, "small_emb_inputs": features["small_emb_inputs"], } y = labels diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index e00bf5d9..51172966 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -113,16 +113,16 @@ def __init__( ) logging.debug("Initialised Bottom MLP: %s", self.bottom_mlp) # Distributed embeddings for large embedding tables - # self.embedding_layer = keras_rs.layers.DistributedEmbedding( - # feature_configs=large_emb_feature_configs, - # table_stacking="auto", - # auto_stack_kwargs=auto_stack_kwargs, - # dtype=dtype, - # name="embedding_layer", - # ) - # logging.debug( - # "Initialised `DistributedEmbedding` layer: %s", self.embedding_layer - # ) + self.embedding_layer = keras_rs.layers.DistributedEmbedding( + feature_configs=large_emb_feature_configs, + table_stacking="auto", + auto_stack_kwargs=auto_stack_kwargs, + dtype=dtype, + name="embedding_layer", + ) + logging.debug( + "Initialised `DistributedEmbedding` layer: %s", self.embedding_layer + ) # Embedding layers for small embedding tables self.small_embedding_layers = None if small_emb_features: @@ -184,11 +184,11 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: """ # Inputs dense_input = inputs["dense_input"] - # large_emb_inputs = inputs["large_emb_inputs"] + large_emb_inputs = inputs["large_emb_inputs"] # Embed features. dense_output = self.bottom_mlp(dense_input) - # large_embeddings = self.embedding_layer(large_emb_inputs) + large_embeddings = self.embedding_layer(large_emb_inputs) small_embeddings = None if self.small_emb_features: small_embeddings = [] @@ -206,7 +206,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Interaction to_concatenate = [ dense_output, - # *large_embeddings.values() + *large_embeddings.values() ] if small_embeddings is not None: to_concatenate += [small_embeddings] From 815528094e65e4a286075ee4f6ab2dc4ed9a6f76 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 08:25:29 +0530 Subject: [PATCH 220/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..24032a9e 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From f34ba98d0601cee6dc8473feacce78bfabb1846c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 08:30:57 +0530 Subject: [PATCH 221/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 24032a9e..f7328a22 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 +training_config.global_batch_size = 16986 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 8541e6eb34a32401305a236b01394427a25fac2a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 08:35:12 +0530 Subject: [PATCH 222/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index f7328a22..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16986 +training_config.global_batch_size = 16896 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 6f31887cb43a2b62ac88327c582349d5251a0b56 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 10:00:35 +0530 Subject: [PATCH 223/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 22812d50..db5e1e4a 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -238,7 +238,7 @@ def generator(dataset, training=False): # callbacks=[MetricLogger()], # validation_steps=training_cfg.num_eval_steps, # validation_freq=1, - verbose=0, + # verbose=0, ) logger.info("Training finished") From 12e50233a0c6aed7620cad2af38f5953f78b5074 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 10:04:42 +0530 Subject: [PATCH 224/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index db5e1e4a..22812d50 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -238,7 +238,7 @@ def generator(dataset, training=False): # callbacks=[MetricLogger()], # validation_steps=training_cfg.num_eval_steps, # validation_freq=1, - # verbose=0, + verbose=0, ) logger.info("Training finished") From 4cd51b7e7927ab0ed8e82ca0905b2d533177e94e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 10:15:06 +0530 Subject: [PATCH 225/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..24032a9e 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 3e104eee9b680ba8907b69de9fe8ee859e8692b9 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 10:25:17 +0530 Subject: [PATCH 226/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 24032a9e..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 +training_config.global_batch_size = 16896 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From c599cf90f8f1c7bbce44d510b2b3c668a62e84ed Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 13:58:12 +0530 Subject: [PATCH 227/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 288 +++++++++--------- 1 file changed, 144 insertions(+), 144 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..fa887062 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -26,156 +26,156 @@ "feature_list_length": 3, "new_name": "0", }, - { - "name": "categorical-feature-15", - "vocabulary_size": 39060, - "feature_list_length": 2, - "new_name": "1", - }, + # { + # "name": "categorical-feature-15", + # "vocabulary_size": 39060, + # "feature_list_length": 2, + # "new_name": "1", + # }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "feature_list_length": 1, "new_name": "2", }, - { - "name": "categorical-feature-17", - "vocabulary_size": 7424, - "feature_list_length": 2, - "new_name": "3", - }, - { - "name": "categorical-feature-18", - "vocabulary_size": 20265, - "feature_list_length": 6, - "new_name": "4", - }, - { - "name": "categorical-feature-19", - "vocabulary_size": 3, - "feature_list_length": 1, - "new_name": "5", - }, - { - "name": "categorical-feature-20", - "vocabulary_size": 7122, - "feature_list_length": 1, - "new_name": "6", - }, - { - "name": "categorical-feature-21", - "vocabulary_size": 1543, - "feature_list_length": 1, - "new_name": "7", - }, - { - "name": "categorical-feature-22", - "vocabulary_size": 63, - "feature_list_length": 1, - "new_name": "8", - }, - { - "name": "categorical-feature-23", - "vocabulary_size": 40000000, - "feature_list_length": 7, - "new_name": "9", - }, - { - "name": "categorical-feature-24", - "vocabulary_size": 3067956, - "feature_list_length": 3, - "new_name": "10", - }, - { - "name": "categorical-feature-25", - "vocabulary_size": 405282, - "feature_list_length": 8, - "new_name": "11", - }, - { - "name": "categorical-feature-26", - "vocabulary_size": 10, - "feature_list_length": 1, - "new_name": "12", - }, - { - "name": "categorical-feature-27", - "vocabulary_size": 2209, - "feature_list_length": 6, - "new_name": "13", - }, - { - "name": "categorical-feature-28", - "vocabulary_size": 11938, - "feature_list_length": 9, - "new_name": "14", - }, - { - "name": "categorical-feature-29", - "vocabulary_size": 155, - "feature_list_length": 5, - "new_name": "15", - }, - { - "name": "categorical-feature-30", - "vocabulary_size": 4, - "feature_list_length": 1, - "new_name": "16", - }, - { - "name": "categorical-feature-31", - "vocabulary_size": 976, - "feature_list_length": 1, - "new_name": "17", - }, - { - "name": "categorical-feature-32", - "vocabulary_size": 14, - "feature_list_length": 1, - "new_name": "18", - }, - { - "name": "categorical-feature-33", - "vocabulary_size": 40000000, - "feature_list_length": 12, - "new_name": "19", - }, - { - "name": "categorical-feature-34", - "vocabulary_size": 40000000, - "feature_list_length": 100, - "new_name": "20", - }, - { - "name": "categorical-feature-35", - "vocabulary_size": 40000000, - "feature_list_length": 27, - "new_name": "21", - }, - { - "name": "categorical-feature-36", - "vocabulary_size": 590152, - "feature_list_length": 10, - "new_name": "22", - }, - { - "name": "categorical-feature-37", - "vocabulary_size": 12973, - "feature_list_length": 3, - "new_name": "23", - }, - { - "name": "categorical-feature-38", - "vocabulary_size": 108, - "feature_list_length": 1, - "new_name": "24", - }, - { - "name": "categorical-feature-39", - "vocabulary_size": 36, - "feature_list_length": 1, - "new_name": "25", - }, + # { + # "name": "categorical-feature-17", + # "vocabulary_size": 7424, + # "feature_list_length": 2, + # "new_name": "3", + # }, + # { + # "name": "categorical-feature-18", + # "vocabulary_size": 20265, + # "feature_list_length": 6, + # "new_name": "4", + # }, + # { + # "name": "categorical-feature-19", + # "vocabulary_size": 3, + # "feature_list_length": 1, + # "new_name": "5", + # }, + # { + # "name": "categorical-feature-20", + # "vocabulary_size": 7122, + # "feature_list_length": 1, + # "new_name": "6", + # }, + # { + # "name": "categorical-feature-21", + # "vocabulary_size": 1543, + # "feature_list_length": 1, + # "new_name": "7", + # }, + # { + # "name": "categorical-feature-22", + # "vocabulary_size": 63, + # "feature_list_length": 1, + # "new_name": "8", + # }, + # { + # "name": "categorical-feature-23", + # "vocabulary_size": 40000000, + # "feature_list_length": 7, + # "new_name": "9", + # }, + # { + # "name": "categorical-feature-24", + # "vocabulary_size": 3067956, + # "feature_list_length": 3, + # "new_name": "10", + # }, + # { + # "name": "categorical-feature-25", + # "vocabulary_size": 405282, + # "feature_list_length": 8, + # "new_name": "11", + # }, + # { + # "name": "categorical-feature-26", + # "vocabulary_size": 10, + # "feature_list_length": 1, + # "new_name": "12", + # }, + # { + # "name": "categorical-feature-27", + # "vocabulary_size": 2209, + # "feature_list_length": 6, + # "new_name": "13", + # }, + # { + # "name": "categorical-feature-28", + # "vocabulary_size": 11938, + # "feature_list_length": 9, + # "new_name": "14", + # }, + # { + # "name": "categorical-feature-29", + # "vocabulary_size": 155, + # "feature_list_length": 5, + # "new_name": "15", + # }, + # { + # "name": "categorical-feature-30", + # "vocabulary_size": 4, + # "feature_list_length": 1, + # "new_name": "16", + # }, + # { + # "name": "categorical-feature-31", + # "vocabulary_size": 976, + # "feature_list_length": 1, + # "new_name": "17", + # }, + # { + # "name": "categorical-feature-32", + # "vocabulary_size": 14, + # "feature_list_length": 1, + # "new_name": "18", + # }, + # { + # "name": "categorical-feature-33", + # "vocabulary_size": 40000000, + # "feature_list_length": 12, + # "new_name": "19", + # }, + # { + # "name": "categorical-feature-34", + # "vocabulary_size": 40000000, + # "feature_list_length": 100, + # "new_name": "20", + # }, + # { + # "name": "categorical-feature-35", + # "vocabulary_size": 40000000, + # "feature_list_length": 27, + # "new_name": "21", + # }, + # { + # "name": "categorical-feature-36", + # "vocabulary_size": 590152, + # "feature_list_length": 10, + # "new_name": "22", + # }, + # { + # "name": "categorical-feature-37", + # "vocabulary_size": 12973, + # "feature_list_length": 3, + # "new_name": "23", + # }, + # { + # "name": "categorical-feature-38", + # "vocabulary_size": 108, + # "feature_list_length": 1, + # "new_name": "24", + # }, + # { + # "name": "categorical-feature-39", + # "vocabulary_size": 36, + # "feature_list_length": 1, + # "new_name": "25", + # }, ] # === Model === From fe1fb190795287cf5c283444080eb8d7265a0b9c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 14:01:49 +0530 Subject: [PATCH 228/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index fa887062..fd49c290 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -74,24 +74,24 @@ # "feature_list_length": 1, # "new_name": "8", # }, - # { - # "name": "categorical-feature-23", - # "vocabulary_size": 40000000, - # "feature_list_length": 7, - # "new_name": "9", - # }, - # { - # "name": "categorical-feature-24", - # "vocabulary_size": 3067956, - # "feature_list_length": 3, - # "new_name": "10", - # }, - # { - # "name": "categorical-feature-25", - # "vocabulary_size": 405282, - # "feature_list_length": 8, - # "new_name": "11", - # }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "feature_list_length": 7, + "new_name": "9", + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "feature_list_length": 3, + "new_name": "10", + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "feature_list_length": 8, + "new_name": "11", + }, # { # "name": "categorical-feature-26", # "vocabulary_size": 10, From 355154ff5ceefaac749f12a6c7a08e9122b5e3fb Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 15:38:12 +0530 Subject: [PATCH 229/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 192 +++++++++--------- 1 file changed, 96 insertions(+), 96 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index fd49c290..7eca174a 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -38,42 +38,42 @@ "feature_list_length": 1, "new_name": "2", }, - # { - # "name": "categorical-feature-17", - # "vocabulary_size": 7424, - # "feature_list_length": 2, - # "new_name": "3", - # }, - # { - # "name": "categorical-feature-18", - # "vocabulary_size": 20265, - # "feature_list_length": 6, - # "new_name": "4", - # }, - # { - # "name": "categorical-feature-19", - # "vocabulary_size": 3, - # "feature_list_length": 1, - # "new_name": "5", - # }, - # { - # "name": "categorical-feature-20", - # "vocabulary_size": 7122, - # "feature_list_length": 1, - # "new_name": "6", - # }, - # { - # "name": "categorical-feature-21", - # "vocabulary_size": 1543, - # "feature_list_length": 1, - # "new_name": "7", - # }, - # { - # "name": "categorical-feature-22", - # "vocabulary_size": 63, - # "feature_list_length": 1, - # "new_name": "8", - # }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "feature_list_length": 2, + "new_name": "3", + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "feature_list_length": 6, + "new_name": "4", + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "feature_list_length": 1, + "new_name": "5", + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "feature_list_length": 1, + "new_name": "6", + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "feature_list_length": 1, + "new_name": "7", + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "feature_list_length": 1, + "new_name": "8", + }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, @@ -92,48 +92,48 @@ "feature_list_length": 8, "new_name": "11", }, - # { - # "name": "categorical-feature-26", - # "vocabulary_size": 10, - # "feature_list_length": 1, - # "new_name": "12", - # }, - # { - # "name": "categorical-feature-27", - # "vocabulary_size": 2209, - # "feature_list_length": 6, - # "new_name": "13", - # }, - # { - # "name": "categorical-feature-28", - # "vocabulary_size": 11938, - # "feature_list_length": 9, - # "new_name": "14", - # }, - # { - # "name": "categorical-feature-29", - # "vocabulary_size": 155, - # "feature_list_length": 5, - # "new_name": "15", - # }, - # { - # "name": "categorical-feature-30", - # "vocabulary_size": 4, - # "feature_list_length": 1, - # "new_name": "16", - # }, - # { - # "name": "categorical-feature-31", - # "vocabulary_size": 976, - # "feature_list_length": 1, - # "new_name": "17", - # }, - # { - # "name": "categorical-feature-32", - # "vocabulary_size": 14, - # "feature_list_length": 1, - # "new_name": "18", - # }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "feature_list_length": 1, + "new_name": "12", + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "feature_list_length": 6, + "new_name": "13", + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "feature_list_length": 9, + "new_name": "14", + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "feature_list_length": 5, + "new_name": "15", + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "feature_list_length": 1, + "new_name": "16", + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "feature_list_length": 1, + "new_name": "17", + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "feature_list_length": 1, + "new_name": "18", + }, # { # "name": "categorical-feature-33", # "vocabulary_size": 40000000, @@ -158,24 +158,24 @@ # "feature_list_length": 10, # "new_name": "22", # }, - # { - # "name": "categorical-feature-37", - # "vocabulary_size": 12973, - # "feature_list_length": 3, - # "new_name": "23", - # }, - # { - # "name": "categorical-feature-38", - # "vocabulary_size": 108, - # "feature_list_length": 1, - # "new_name": "24", - # }, - # { - # "name": "categorical-feature-39", - # "vocabulary_size": 36, - # "feature_list_length": 1, - # "new_name": "25", - # }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "feature_list_length": 3, + "new_name": "23", + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "feature_list_length": 1, + "new_name": "24", + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "feature_list_length": 1, + "new_name": "25", + }, ] # === Model === From b1171c3b9102de0b0cfad030e406e0775e350ef9 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 15:41:09 +0530 Subject: [PATCH 230/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7eca174a..36a79225 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -26,12 +26,12 @@ "feature_list_length": 3, "new_name": "0", }, - # { - # "name": "categorical-feature-15", - # "vocabulary_size": 39060, - # "feature_list_length": 2, - # "new_name": "1", - # }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "feature_list_length": 2, + "new_name": "1", + }, { "name": "categorical-feature-16", "vocabulary_size": 17295, From ff94989d568080ddfce86a7b358be213f6f671da Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 15:43:40 +0530 Subject: [PATCH 231/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 36a79225..c195de85 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -152,12 +152,12 @@ # "feature_list_length": 27, # "new_name": "21", # }, - # { - # "name": "categorical-feature-36", - # "vocabulary_size": 590152, - # "feature_list_length": 10, - # "new_name": "22", - # }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "feature_list_length": 10, + "new_name": "22", + }, { "name": "categorical-feature-37", "vocabulary_size": 12973, From 83598f744b2ac85f2662527a5e586e645762282d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 15:46:22 +0530 Subject: [PATCH 232/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index c195de85..bc02a273 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -134,12 +134,12 @@ "feature_list_length": 1, "new_name": "18", }, - # { - # "name": "categorical-feature-33", - # "vocabulary_size": 40000000, - # "feature_list_length": 12, - # "new_name": "19", - # }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "feature_list_length": 12, + "new_name": "19", + }, # { # "name": "categorical-feature-34", # "vocabulary_size": 40000000, From d5d5946af3833ae02f30d8a81d8500d80cf7d918 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 15:48:58 +0530 Subject: [PATCH 233/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index bc02a273..4b2f8339 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -140,12 +140,12 @@ "feature_list_length": 12, "new_name": "19", }, - # { - # "name": "categorical-feature-34", - # "vocabulary_size": 40000000, - # "feature_list_length": 100, - # "new_name": "20", - # }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "feature_list_length": 100, + "new_name": "20", + }, # { # "name": "categorical-feature-35", # "vocabulary_size": 40000000, From fe77e0874c8f524b8a015ecd00bb3924833722b0 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 15:52:06 +0530 Subject: [PATCH 234/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 4b2f8339..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -146,12 +146,12 @@ "feature_list_length": 100, "new_name": "20", }, - # { - # "name": "categorical-feature-35", - # "vocabulary_size": 40000000, - # "feature_list_length": 27, - # "new_name": "21", - # }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "feature_list_length": 27, + "new_name": "21", + }, { "name": "categorical-feature-36", "vocabulary_size": 590152, From dfe4019ee3b002cacbf1707e390e356d516345cf Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 15:54:56 +0530 Subject: [PATCH 235/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..93523691 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -158,24 +158,24 @@ "feature_list_length": 10, "new_name": "22", }, - { - "name": "categorical-feature-37", - "vocabulary_size": 12973, - "feature_list_length": 3, - "new_name": "23", - }, - { - "name": "categorical-feature-38", - "vocabulary_size": 108, - "feature_list_length": 1, - "new_name": "24", - }, - { - "name": "categorical-feature-39", - "vocabulary_size": 36, - "feature_list_length": 1, - "new_name": "25", - }, + # { + # "name": "categorical-feature-37", + # "vocabulary_size": 12973, + # "feature_list_length": 3, + # "new_name": "23", + # }, + # { + # "name": "categorical-feature-38", + # "vocabulary_size": 108, + # "feature_list_length": 1, + # "new_name": "24", + # }, + # { + # "name": "categorical-feature-39", + # "vocabulary_size": 36, + # "feature_list_length": 1, + # "new_name": "25", + # }, ] # === Model === From af25d5968f9fcb60598ae4d74dd5be6236e0a86d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 16:00:21 +0530 Subject: [PATCH 236/279] Debug --- .../ml_perf/configs/v6e_16_full_dataset.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 93523691..6d1d9d03 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -152,30 +152,30 @@ "feature_list_length": 27, "new_name": "21", }, - { - "name": "categorical-feature-36", - "vocabulary_size": 590152, - "feature_list_length": 10, - "new_name": "22", - }, - # { - # "name": "categorical-feature-37", - # "vocabulary_size": 12973, - # "feature_list_length": 3, - # "new_name": "23", - # }, # { - # "name": "categorical-feature-38", - # "vocabulary_size": 108, - # "feature_list_length": 1, - # "new_name": "24", - # }, - # { - # "name": "categorical-feature-39", - # "vocabulary_size": 36, - # "feature_list_length": 1, - # "new_name": "25", + # "name": "categorical-feature-36", + # "vocabulary_size": 590152, + # "feature_list_length": 10, + # "new_name": "22", # }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "feature_list_length": 3, + "new_name": "23", + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "feature_list_length": 1, + "new_name": "24", + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "feature_list_length": 1, + "new_name": "25", + }, ] # === Model === From 1c53c29717b83e61bb638ae39a0096b3e4ba63b3 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 16:05:55 +0530 Subject: [PATCH 237/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 6d1d9d03..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -152,12 +152,12 @@ "feature_list_length": 27, "new_name": "21", }, - # { - # "name": "categorical-feature-36", - # "vocabulary_size": 590152, - # "feature_list_length": 10, - # "new_name": "22", - # }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "feature_list_length": 10, + "new_name": "22", + }, { "name": "categorical-feature-37", "vocabulary_size": 12973, From b1f478ec4021ccc99d6069bd44ae78c47293f62c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 16:19:13 +0530 Subject: [PATCH 238/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..a6b1b721 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -134,12 +134,12 @@ "feature_list_length": 1, "new_name": "18", }, - { - "name": "categorical-feature-33", - "vocabulary_size": 40000000, - "feature_list_length": 12, - "new_name": "19", - }, + # { + # "name": "categorical-feature-33", + # "vocabulary_size": 40000000, + # "feature_list_length": 12, + # "new_name": "19", + # }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, From 46f57330240efd0c0df628120f7f5a277f7eeba7 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 17:04:37 +0530 Subject: [PATCH 239/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index a6b1b721..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -134,12 +134,12 @@ "feature_list_length": 1, "new_name": "18", }, - # { - # "name": "categorical-feature-33", - # "vocabulary_size": 40000000, - # "feature_list_length": 12, - # "new_name": "19", - # }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "feature_list_length": 12, + "new_name": "19", + }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, From 534e7e31d8ad1fb26bfe91c8643963a33716f08d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 17:05:14 +0530 Subject: [PATCH 240/279] Debug --- examples/ml_perf/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 51172966..fbaf7cf3 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -5,6 +5,7 @@ from keras import ops import keras_rs +import jax Tensor: TypeAlias = Any @@ -215,6 +216,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Predictions outputs = self.top_mlp(x) + jax.debug.print("outputs={}", outputs) return outputs def _get_mlp_layers( From b39c6d847faed738ce33e95dfec720b27e5bbaef Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 17:09:04 +0530 Subject: [PATCH 241/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 22812d50..319d3bfe 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -130,7 +130,7 @@ def main( optimizer=keras.optimizers.Adagrad( learning_rate=training_cfg.learning_rate ), - metrics=[keras.metrics.BinaryAccuracy()], + # metrics=[keras.metrics.BinaryAccuracy()], ) logger.info("Initialised model: %s", model) From b9e7d71b40cef95954af0b06a0be99306f82acc2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 17:21:32 +0530 Subject: [PATCH 242/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..328bb9e0 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 32768 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 929d7fd39a4431c35876f264ec4c88b7a0520906 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 17:28:13 +0530 Subject: [PATCH 243/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 328bb9e0..f5ee4ca4 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -184,8 +184,8 @@ model_config.embedding_dim = 128 model_config.allow_id_dropping = True model_config.embedding_threshold = 21000 -model_config.max_ids_per_partition = 8192 -model_config.max_unique_ids_per_partition = 4096 +model_config.max_ids_per_partition = 8192 * 2 +model_config.max_unique_ids_per_partition = 4096 * 2 model_config.learning_rate = 0.0034 # MLP @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 32768 +training_config.global_batch_size = 16896 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From e51b434069978a8a194c85b183b44d42caea86d5 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 17:38:43 +0530 Subject: [PATCH 244/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index f5ee4ca4..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -184,8 +184,8 @@ model_config.embedding_dim = 128 model_config.allow_id_dropping = True model_config.embedding_threshold = 21000 -model_config.max_ids_per_partition = 8192 * 2 -model_config.max_unique_ids_per_partition = 4096 * 2 +model_config.max_ids_per_partition = 8192 +model_config.max_unique_ids_per_partition = 4096 model_config.learning_rate = 0.0034 # MLP From f1f9fed4e3230835d9ebb1345bb48cc41cf5230b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 20:18:02 +0530 Subject: [PATCH 245/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..24032a9e 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From eb78ec534aea95d63c6e07171f2d9586723f6f91 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 20:21:15 +0530 Subject: [PATCH 246/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 24032a9e..7128ca26 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 +training_config.global_batch_size = 16896 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 6f9ab3aadbcc0f7e279582afe6cfee090bcf23a1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 21:11:25 +0530 Subject: [PATCH 247/279] Debug --- examples/ml_perf/configs/v6e_16.py | 16 ++++++++++------ examples/ml_perf/dataloader.py | 4 ++-- examples/ml_perf/main.py | 2 +- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 4166ab3d..e5e37198 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -3,12 +3,14 @@ config = Config() # === Experiment metadata === -config.experiment_name = "v6e_16" -config.model_dir = "./v6e_16" +config.experiment_name = "v6e_16_full_dataset" +config.model_dir = "./v6e_16_full_dataset" # === Dataset === dataset_config = Config() dataset_config.file_pattern = None +dataset_config.val_file_pattern = None + # Features dataset_config.label = "clicked" dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] @@ -192,10 +194,12 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 -# Set `num_steps` in the main config file instead of num_epochs, because we are -# using a Python generator. -training_config.num_steps = 20 +training_config.global_batch_size = 16896 +# Set `num_steps` instead of `num_epochs`, because we are using a Python +# generator. +training_config.num_steps = 10 +training_config.eval_freq = 5 +training_config.num_eval_steps = 10 # === Assign all configs to the root config === config.dataset = dataset_config diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index c66c3bb4..78b4a0a3 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -190,7 +190,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Important to specify shuffle = False here to ensure all processes have # the same order. dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) - dataset = dataset.shard(num_shards=num_processes, index=process_id) + # dataset = dataset.shard(num_shards=num_processes, index=process_id) logger.info("List of input files: %s", [f for f in dataset]) dataset = tf.data.TFRecordDataset( @@ -214,7 +214,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.shuffle(shuffle_buffer, seed=SEED) dataset = dataset.batch( - self.batch_size // num_processes, + self.batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE, ) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 319d3bfe..43a20d78 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -188,7 +188,7 @@ def main( # See note here: # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. if num_processes > 1: - # train_ds = distribution.distribute_dataset(train_ds) + train_ds = distribution.distribute_dataset(train_ds) if do_eval: eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False From 7c65c027b3421546283a80355cce20eae1ae9e48 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 21:14:54 +0530 Subject: [PATCH 248/279] Debug --- examples/ml_perf/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index fbaf7cf3..9e477f22 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -216,7 +216,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Predictions outputs = self.top_mlp(x) - jax.debug.print("outputs={}", outputs) + # jax.debug.print("outputs={}", outputs) return outputs def _get_mlp_layers( From fa79280d0155307f1f82f90111b6089e53547379 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 28 Oct 2025 23:58:21 +0530 Subject: [PATCH 249/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7128ca26..e66c55a1 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 16384 * 2 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 6a9978c79551e3e7dd7fdd671bd36a584baef057 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 00:36:05 +0530 Subject: [PATCH 250/279] Debug --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index e5e37198..9bf5590f 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -194,7 +194,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 16384 * 2 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From aa5c7aaafa4e5f344d6b821a743846e305a5e009 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 00:44:32 +0530 Subject: [PATCH 251/279] Debug --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 9bf5590f..e5e37198 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -194,7 +194,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 * 2 +training_config.global_batch_size = 16896 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 24e0267a99141b51b070793370eaacf2f839d18d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 05:59:54 +0530 Subject: [PATCH 252/279] Debug --- examples/ml_perf/main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 43a20d78..63de1c9c 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -9,6 +9,9 @@ import keras_rs +import jax +jax.config.update("jax_debug_nans", True) + from .dataloader import DataLoader from .model import DLRMDCNV2 From 8d8bea9b1ce065fbac33cfe4f3e93af08b1bcd3d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 06:23:14 +0530 Subject: [PATCH 253/279] Full run --- .../ml_perf/configs/v6e_16_full_dataset.py | 6 ++--- examples/ml_perf/main.py | 24 +++++++------------ examples/ml_perf/model.py | 7 +----- 3 files changed, 13 insertions(+), 24 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index e66c55a1..7a7ed65d 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -10,7 +10,7 @@ dataset_config = Config() dataset_config.file_pattern = ( "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" - "train-0000[0-3]-of-01024tfrecord" + "train-*-of-01024tfrecord" ) dataset_config.val_file_pattern = None # The path which we are reading from already has the batched dataset. @@ -199,10 +199,10 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 * 2 +training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. -training_config.num_steps = 10 +training_config.num_steps = 28000 training_config.eval_freq = 5 training_config.num_eval_steps = 10 diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 63de1c9c..3627800f 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -9,9 +9,6 @@ import keras_rs -import jax -jax.config.update("jax_debug_nans", True) - from .dataloader import DataLoader from .model import DLRMDCNV2 @@ -142,17 +139,14 @@ def main( # Keras does not have a straightforward way to log at a step-level instead # of epoch-level. So, we do a workaround here. - # if ds_cfg.val_file_pattern: - # steps_per_epoch = training_cfg.eval_freq - # epochs = training_cfg.num_steps // training_cfg.eval_freq - # do_eval = True - # else: - # steps_per_epoch = training_cfg.num_steps - # epochs = 1 - # do_eval = False - steps_per_epoch = training_cfg.num_steps - epochs = 2 - do_eval = False + if ds_cfg.val_file_pattern: + steps_per_epoch = training_cfg.eval_freq + epochs = training_cfg.num_steps // training_cfg.eval_freq + do_eval = True + else: + steps_per_epoch = training_cfg.num_steps + epochs = 1 + do_eval = False logger.info(f"{steps_per_epoch=}, {epochs=}, {do_eval=}") @@ -241,7 +235,7 @@ def generator(dataset, training=False): # callbacks=[MetricLogger()], # validation_steps=training_cfg.num_eval_steps, # validation_freq=1, - verbose=0, + # verbose=0, ) logger.info("Training finished") diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 9e477f22..e956a000 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -5,7 +5,6 @@ from keras import ops import keras_rs -import jax Tensor: TypeAlias = Any @@ -205,10 +204,7 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: small_embeddings = ops.concatenate(small_embeddings, axis=-1) # Interaction - to_concatenate = [ - dense_output, - *large_embeddings.values() - ] + to_concatenate = [dense_output, *large_embeddings.values()] if small_embeddings is not None: to_concatenate += [small_embeddings] x = ops.concatenate(to_concatenate, axis=-1) @@ -216,7 +212,6 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: # Predictions outputs = self.top_mlp(x) - # jax.debug.print("outputs={}", outputs) return outputs def _get_mlp_layers( From 248b1504045f66c2fb374e57dfc3d2ae114c6731 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 09:46:35 +0530 Subject: [PATCH 254/279] Full run --- examples/ml_perf/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 3627800f..83be5760 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -130,7 +130,10 @@ def main( optimizer=keras.optimizers.Adagrad( learning_rate=training_cfg.learning_rate ), - # metrics=[keras.metrics.BinaryAccuracy()], + metrics=[ + keras.metrics.BinaryAccuracy(), + keras.metrics.AUC() + ], ) logger.info("Initialised model: %s", model) From 4d73c5c728111773d85dcf6196abbadac3c42894 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 10:00:07 +0530 Subject: [PATCH 255/279] Full run --- examples/ml_perf/main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 83be5760..b03187a0 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -2,6 +2,7 @@ import importlib import logging import os +import time os.environ["KERAS_BACKEND"] = "jax" @@ -132,7 +133,7 @@ def main( ), metrics=[ keras.metrics.BinaryAccuracy(), - keras.metrics.AUC() + keras.metrics.AUC(), ], ) logger.info("Initialised model: %s", model) @@ -230,6 +231,7 @@ def generator(dataset, training=False): # === Training === logger.info("Training...") + t0 = time.perf_counter() model.fit( train_generator, # validation_data=eval_generator, @@ -240,7 +242,7 @@ def generator(dataset, training=False): # validation_freq=1, # verbose=0, ) - logger.info("Training finished") + logger.info("Training finished in %s seconds", time.perf_counter() - t0) if __name__ == "__main__": From 6ec11c184d52257d1f8e33cb6058b44274ff7ce9 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 14:43:04 +0530 Subject: [PATCH 256/279] Add profiling statements --- examples/ml_perf/main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index b03187a0..ccbf4793 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -9,6 +9,7 @@ import keras import keras_rs +import jax from .dataloader import DataLoader from .model import DLRMDCNV2 @@ -232,6 +233,7 @@ def generator(dataset, training=False): # === Training === logger.info("Training...") t0 = time.perf_counter() + jax.profiler.start_trace("/tmp/ml-perf-benchmarking") model.fit( train_generator, # validation_data=eval_generator, @@ -242,6 +244,7 @@ def generator(dataset, training=False): # validation_freq=1, # verbose=0, ) + jax.profiler.stop_trace() logger.info("Training finished in %s seconds", time.perf_counter() - t0) From dc738a276cafee6e022aa560c38cf2f693f9314e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 14:47:00 +0530 Subject: [PATCH 257/279] Add profiling statements --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 7a7ed65d..230236b6 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -202,7 +202,7 @@ training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. -training_config.num_steps = 28000 +training_config.num_steps = 10 # 28000 training_config.eval_freq = 5 training_config.num_eval_steps = 10 From 7b6b03649f0f176c4b87407629ea191fe9e3298f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 19:44:26 +0530 Subject: [PATCH 258/279] Some dataloader options --- examples/ml_perf/dataloader.py | 6 ++++++ examples/ml_perf/main.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 78b4a0a3..83af0cdf 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -226,4 +226,10 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.prefetch(tf.data.AUTOTUNE) + # Random try, lol + options = tf.data.Options() + options.deterministic = False + options.threading.private_threadpool_size = 96 + dataset = dataset.with_options(options) + return dataset diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index ccbf4793..d75b234d 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -233,7 +233,7 @@ def generator(dataset, training=False): # === Training === logger.info("Training...") t0 = time.perf_counter() - jax.profiler.start_trace("/tmp/ml-perf-benchmarking") + # jax.profiler.start_trace("/tmp/ml-perf-benchmarking") model.fit( train_generator, # validation_data=eval_generator, @@ -244,7 +244,7 @@ def generator(dataset, training=False): # validation_freq=1, # verbose=0, ) - jax.profiler.stop_trace() + # jax.profiler.stop_trace() logger.info("Training finished in %s seconds", time.perf_counter() - t0) From 4e950f06bd0fba5bd2337870248e4e2fba0c4a8b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 19:53:11 +0530 Subject: [PATCH 259/279] Some dataloader options --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 230236b6..d7073dac 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -202,7 +202,7 @@ training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. -training_config.num_steps = 10 # 28000 +training_config.num_steps = 28000 # 28000 training_config.eval_freq = 5 training_config.num_eval_steps = 10 From 657e8d5e63dab52f99a905090dc273324c73bd3d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 20:00:22 +0530 Subject: [PATCH 260/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- examples/ml_perf/dataloader.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index d7073dac..eb07c7e7 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -202,7 +202,7 @@ training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. -training_config.num_steps = 28000 # 28000 +training_config.num_steps = 5000 # 28000 training_config.eval_freq = 5 training_config.num_eval_steps = 10 diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 83af0cdf..78bedf3f 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -226,10 +226,10 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.prefetch(tf.data.AUTOTUNE) - # Random try, lol - options = tf.data.Options() - options.deterministic = False - options.threading.private_threadpool_size = 96 - dataset = dataset.with_options(options) + # # Random try, lol + # options = tf.data.Options() + # options.deterministic = False + # options.threading.private_threadpool_size = 96 + # dataset = dataset.with_options(options) return dataset From 0ed121be68fb9e8c40441cb106e1ca3c1cac7eb9 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 20:11:13 +0530 Subject: [PATCH 261/279] Debug --- examples/ml_perf/dataloader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 78bedf3f..83af0cdf 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -226,10 +226,10 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.prefetch(tf.data.AUTOTUNE) - # # Random try, lol - # options = tf.data.Options() - # options.deterministic = False - # options.threading.private_threadpool_size = 96 - # dataset = dataset.with_options(options) + # Random try, lol + options = tf.data.Options() + options.deterministic = False + options.threading.private_threadpool_size = 96 + dataset = dataset.with_options(options) return dataset From 0de1f0ba33679b71760ed78e86cf6c36d176fedf Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 20:17:54 +0530 Subject: [PATCH 262/279] Debug --- examples/ml_perf/dataloader.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 83af0cdf..0dbe6acf 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -225,11 +225,4 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.repeat() dataset = dataset.prefetch(tf.data.AUTOTUNE) - - # Random try, lol - options = tf.data.Options() - options.deterministic = False - options.threading.private_threadpool_size = 96 - dataset = dataset.with_options(options) - return dataset From 6cfc3fc5239690c8bda63be76838b4bed557d652 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 29 Oct 2025 20:19:12 +0530 Subject: [PATCH 263/279] Debug --- examples/ml_perf/configs/v6e_16_full_dataset.py | 2 +- examples/ml_perf/main.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index eb07c7e7..8301debc 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -202,7 +202,7 @@ training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. -training_config.num_steps = 5000 # 28000 +training_config.num_steps = 1000 # 28000 training_config.eval_freq = 5 training_config.num_eval_steps = 10 diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d75b234d..c6f7c51a 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -233,7 +233,7 @@ def generator(dataset, training=False): # === Training === logger.info("Training...") t0 = time.perf_counter() - # jax.profiler.start_trace("/tmp/ml-perf-benchmarking") + jax.profiler.start_trace("/tmp/ml-perf-benchmarking/1000_steps") model.fit( train_generator, # validation_data=eval_generator, @@ -244,7 +244,7 @@ def generator(dataset, training=False): # validation_freq=1, # verbose=0, ) - # jax.profiler.stop_trace() + jax.profiler.stop_trace() logger.info("Training finished in %s seconds", time.perf_counter() - t0) From a1c65c4e0930e02354cd6507f1df2bc8d32d1677 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 00:30:36 +0530 Subject: [PATCH 264/279] Debug --- .../embedding/jax/distributed_embedding.py | 100 +++++++++--------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index b13baad2..6cbefd1c 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -437,7 +437,7 @@ def sparsecore_build( feature_specs, global_device_count, num_sc_per_device, - **self._auto_stack_kwargs, + # **self._auto_stack_kwargs, ) else: raise ValueError( @@ -631,55 +631,55 @@ def _sparsecore_preprocess( num_sc_per_device, ) - # if training: - # # Synchronize input statistics across all devices and update the - # # underlying stacked tables specs in the feature specs. - - # # Aggregate stats across all processes/devices via pmax. - # all_stats = multihost_utils.process_allgather(stats) - # # print("### all_stats", all_stats) - # # aggregated_stats = all_stats - # aggregated_stats = jax.tree.map( - # lambda x: np.max(x, axis=0), all_stats - # ) - - # # Check if stats changed enough to warrant action. - # stacked_table_specs = embedding.get_stacked_table_specs( - # self._config.feature_specs - # ) - # changed = any( - # np.max(aggregated_stats.max_ids_per_partition[stack_name]) - # > spec.max_ids_per_partition - # or np.max( - # aggregated_stats.max_unique_ids_per_partition[stack_name] - # ) - # > spec.max_unique_ids_per_partition - # or ( - # np.max( - # aggregated_stats.required_buffer_size_per_sc[stack_name] - # ) - # * num_sc_per_device - # ) - # > (spec.suggested_coo_buffer_size_per_device or 0) - # for stack_name, spec in stacked_table_specs.items() - # ) - - # # # Update configuration and repeat preprocessing if stats changed. - # if changed: - # embedding.update_preprocessing_parameters( - # self._config.feature_specs, - # aggregated_stats, - # num_sc_per_device, - # ) - - # # # Re-execute preprocessing with consistent input statistics. - # # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # # self._config.feature_specs, - # # samples, - # # local_device_count, - # # global_device_count, - # # num_sc_per_device, - # # ) + if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + + # Aggregate stats across all processes/devices via pmax. + all_stats = multihost_utils.process_allgather(stats) + # print("### all_stats", all_stats) + # aggregated_stats = all_stats + aggregated_stats = jax.tree.map( + lambda x: np.max(x, axis=0), all_stats + ) + + # Check if stats changed enough to warrant action. + stacked_table_specs = embedding.get_stacked_table_specs( + self._config.feature_specs + ) + changed = any( + np.max(aggregated_stats.max_ids_per_partition[stack_name]) + > spec.max_ids_per_partition + or np.max( + aggregated_stats.max_unique_ids_per_partition[stack_name] + ) + > spec.max_unique_ids_per_partition + or ( + np.max( + aggregated_stats.required_buffer_size_per_sc[stack_name] + ) + * num_sc_per_device + ) + > (spec.suggested_coo_buffer_size_per_device or 0) + for stack_name, spec in stacked_table_specs.items() + ) + + # # Update configuration and repeat preprocessing if stats changed. + if changed: + embedding.update_preprocessing_parameters( + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, + ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 7b4b18da6883cedd88a90292b57cb989df4a3589 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 00:32:33 +0530 Subject: [PATCH 265/279] Debug --- examples/ml_perf/configs/v6e_16.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index e5e37198..7e879b84 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -3,8 +3,8 @@ config = Config() # === Experiment metadata === -config.experiment_name = "v6e_16_full_dataset" -config.model_dir = "./v6e_16_full_dataset" +config.experiment_name = "v6e_16" +config.model_dir = "./v6e_16" # === Dataset === dataset_config = Config() @@ -194,7 +194,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16896 +training_config.global_batch_size = 16384 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 4e5259fc006fdb1e43a7bc74d71960d2f44bb4c5 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 00:34:32 +0530 Subject: [PATCH 266/279] Debug --- examples/ml_perf/configs/v6e_16.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 7e879b84..1c6eb45c 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -147,12 +147,12 @@ "feature_list_length": 27, "new_name": "21", }, - { - "name": "categorical-feature-36", - "vocabulary_size": 590152, - "feature_list_length": 10, - "new_name": "22", - }, + # { + # "name": "categorical-feature-36", + # "vocabulary_size": 590152, + # "feature_list_length": 10, + # "new_name": "22", + # }, { "name": "categorical-feature-37", "vocabulary_size": 12973, From 18b97f33c0abea57500950495bc90a0b692b4b96 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 00:35:28 +0530 Subject: [PATCH 267/279] Debug --- keras_rs/src/layers/embedding/jax/distributed_embedding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 6cbefd1c..a2b6c254 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -5,6 +5,7 @@ from typing import Any, Mapping, Sequence, Union import jax +from jax.experimental import multihost_utils import keras import numpy as np from jax import numpy as jnp From 5c56e1be856024760130b621e5ecfeaa6f6145a0 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 00:44:14 +0530 Subject: [PATCH 268/279] Debug --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 1c6eb45c..4e9b5cd9 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -194,7 +194,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 +training_config.global_batch_size = 16384 * 2 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 37345dd8e30bedc6da2ab79add831db0d53fe8e0 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 00:46:34 +0530 Subject: [PATCH 269/279] Debug --- examples/ml_perf/configs/v6e_16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index 4e9b5cd9..e9b1313a 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -194,7 +194,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 * 2 +training_config.global_batch_size = 16896 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 10 From 06a0a79686981478cd4a867cd0254a50312ac0cd Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 07:18:46 +0530 Subject: [PATCH 270/279] Remove auto stack kwargs --- examples/ml_perf/main.py | 6 ---- examples/ml_perf/model.py | 4 --- .../embedding/jax/distributed_embedding.py | 34 ------------------- 3 files changed, 44 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index c6f7c51a..5463670b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -117,12 +117,6 @@ def main( top_mlp_dims=model_cfg.top_mlp_dims, num_dcn_layers=model_cfg.num_dcn_layers, dcn_projection_dim=model_cfg.dcn_projection_dim, - auto_stack_kwargs={ - "max_ids_per_partition": model_cfg.max_ids_per_partition, - "max_unique_ids_per_partition": ( - model_cfg.max_unique_ids_per_partition - ), - }, seed=SEED, dtype="float32", name="dlrm_dcn_v2", diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index e956a000..7f119c66 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -49,7 +49,6 @@ def __init__( top_mlp_dims: list[int], num_dcn_layers: int, dcn_projection_dim: int, - auto_stack_kwargs: dict[str, Any], seed: int | keras.random.SeedGenerator | None = None, dtype: str | None = None, name: str | None = None, @@ -116,7 +115,6 @@ def __init__( self.embedding_layer = keras_rs.layers.DistributedEmbedding( feature_configs=large_emb_feature_configs, table_stacking="auto", - auto_stack_kwargs=auto_stack_kwargs, dtype=dtype, name="embedding_layer", ) @@ -173,7 +171,6 @@ def __init__( self.top_mlp_dims = top_mlp_dims self.num_dcn_layers = num_dcn_layers self.dcn_projection_dim = dcn_projection_dim - self.auto_stack_kwargs = auto_stack_kwargs def call(self, inputs: dict[str, Tensor]) -> Tensor: """Forward pass of the model. @@ -280,7 +277,6 @@ def get_config(self): "top_mlp_dims": self.top_mlp_dims, "num_dcn_layers": self.num_dcn_layers, "dcn_projection_dim": self.dcn_projection_dim, - "auto_stack_kwargs": self.auto_stack_kwargs, "seed": self.seed, } ) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index a2b6c254..b8c91527 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -193,39 +193,6 @@ def __call__( class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding): """JAX implementation of the TPU embedding layer.""" - def __init__(self, **kwargs: Any): - # Pull out `auto_stack_kwargs` from `kwargs`. - auto_stack_kwargs = kwargs.pop("auto_stack_kwargs", {}) - - auto_stack_max_ids_per_partition = auto_stack_kwargs.pop( - "max_ids_per_partition", None - ) - auto_stack_max_unique_ids_per_partition = auto_stack_kwargs.pop( - "max_unique_ids_per_partition", None - ) - - # For `max_ids_per_partition` and `max_unique_ids_per_partition`, JTE's - # `auto_stack_tables` expects callables. - def _get_max_ids_per_partition(name: str, batch_size: int) -> int: - return auto_stack_max_ids_per_partition - - def _get_max_unique_ids_per_partition( - name: str, batch_size: int - ) -> int: - return auto_stack_max_unique_ids_per_partition - - if auto_stack_max_ids_per_partition is not None: - auto_stack_kwargs["stack_to_max_ids_per_partition"] = ( - _get_max_ids_per_partition - ) - if auto_stack_max_unique_ids_per_partition is not None: - auto_stack_kwargs["stack_to_max_unique_ids_per_partition"] = ( - _get_max_unique_ids_per_partition - ) - - self._auto_stack_kwargs = auto_stack_kwargs - super().__init__(**kwargs) - def _create_sparsecore_distribution( self, sparsecore_axis_name: str = "sparsecore" ) -> tuple[ @@ -438,7 +405,6 @@ def sparsecore_build( feature_specs, global_device_count, num_sc_per_device, - # **self._auto_stack_kwargs, ) else: raise ValueError( From 4bfbf9536ad48cb9f6f8347884ebf06e559792d5 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 07:34:27 +0530 Subject: [PATCH 271/279] Comment out profiling --- examples/ml_perf/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 5463670b..03970811 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -227,7 +227,7 @@ def generator(dataset, training=False): # === Training === logger.info("Training...") t0 = time.perf_counter() - jax.profiler.start_trace("/tmp/ml-perf-benchmarking/1000_steps") + # jax.profiler.start_trace("/tmp/ml-perf-benchmarking/1000_steps") model.fit( train_generator, # validation_data=eval_generator, @@ -238,7 +238,7 @@ def generator(dataset, training=False): # validation_freq=1, # verbose=0, ) - jax.profiler.stop_trace() + # jax.profiler.stop_trace() logger.info("Training finished in %s seconds", time.perf_counter() - t0) From a0578cc5876f224b1b24bc3b360df36b3650ccb9 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 07:35:41 +0530 Subject: [PATCH 272/279] Comment out stat update for now --- .../ml_perf/configs/v6e_16_full_dataset.py | 2 +- .../embedding/jax/distributed_embedding.py | 96 +++++++++---------- 2 files changed, 48 insertions(+), 50 deletions(-) diff --git a/examples/ml_perf/configs/v6e_16_full_dataset.py b/examples/ml_perf/configs/v6e_16_full_dataset.py index 8301debc..d1817c44 100644 --- a/examples/ml_perf/configs/v6e_16_full_dataset.py +++ b/examples/ml_perf/configs/v6e_16_full_dataset.py @@ -199,7 +199,7 @@ # === Training === training_config = Config() training_config.learning_rate = 0.0034 -training_config.global_batch_size = 16384 +training_config.global_batch_size = 16896 # Set `num_steps` instead of `num_epochs`, because we are using a Python # generator. training_config.num_steps = 1000 # 28000 diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index d550bd18..a9cac408 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -623,55 +623,53 @@ def _sparsecore_preprocess( num_sc_per_device, ) - if training: - # Synchronize input statistics across all devices and update the - # underlying stacked tables specs in the feature specs. - - # Aggregate stats across all processes/devices via pmax. - all_stats = multihost_utils.process_allgather(stats) - # print("### all_stats", all_stats) - # aggregated_stats = all_stats - aggregated_stats = jax.tree.map( - lambda x: np.max(x, axis=0), all_stats - ) - - # Check if stats changed enough to warrant action. - stacked_table_specs = embedding.get_stacked_table_specs( - self._config.feature_specs - ) - changed = any( - np.max(aggregated_stats.max_ids_per_partition[stack_name]) - > spec.max_ids_per_partition - or np.max( - aggregated_stats.max_unique_ids_per_partition[stack_name] - ) - > spec.max_unique_ids_per_partition - or ( - np.max( - aggregated_stats.required_buffer_size_per_sc[stack_name] - ) - * num_sc_per_device - ) - > (spec.suggested_coo_buffer_size_per_device or 0) - for stack_name, spec in stacked_table_specs.items() - ) - - # # Update configuration and repeat preprocessing if stats changed. - if changed: - embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, - ) - - # # Re-execute preprocessing with consistent input statistics. - # preprocessed, _ = embedding_utils.stack_and_shard_samples( - # self._config.feature_specs, - # samples, - # local_device_count, - # global_device_count, - # num_sc_per_device, - # ) + # if training: + # # Synchronize input statistics across all devices and update the + # # underlying stacked tables specs in the feature specs. + + # # Aggregate stats across all processes/devices via pmax. + # all_stats = multihost_utils.process_allgather(stats) + # aggregated_stats = jax.tree.map( + # lambda x: np.max(x, axis=0), all_stats + # ) + + # # Check if stats changed enough to warrant action. + # stacked_table_specs = embedding.get_stacked_table_specs( + # self._config.feature_specs + # ) + # changed = any( + # np.max(aggregated_stats.max_ids_per_partition[stack_name]) + # > spec.max_ids_per_partition + # or np.max( + # aggregated_stats.max_unique_ids_per_partition[stack_name] + # ) + # > spec.max_unique_ids_per_partition + # or ( + # np.max( + # aggregated_stats.required_buffer_size_per_sc[stack_name] + # ) + # * num_sc_per_device + # ) + # > (spec.suggested_coo_buffer_size_per_device or 0) + # for stack_name, spec in stacked_table_specs.items() + # ) + + # # # Update configuration and repeat preprocessing if stats changed. + # if changed: + # embedding.update_preprocessing_parameters( + # self._config.feature_specs, + # aggregated_stats, + # num_sc_per_device, + # ) + + # # Re-execute preprocessing with consistent input statistics. + # preprocessed, _ = embedding_utils.stack_and_shard_samples( + # self._config.feature_specs, + # samples, + # local_device_count, + # global_device_count, + # num_sc_per_device, + # ) return {"inputs": preprocessed} From 05f8905a24705cd27b9001af260ec73f6879ae01 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 08:04:49 +0530 Subject: [PATCH 273/279] Debug --- examples/ml_perf/main.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 03970811..54f93168 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -10,6 +10,7 @@ import keras_rs import jax +from jax.experimental import checkify from .dataloader import DataLoader from .model import DLRMDCNV2 @@ -228,16 +229,17 @@ def generator(dataset, training=False): logger.info("Training...") t0 = time.perf_counter() # jax.profiler.start_trace("/tmp/ml-perf-benchmarking/1000_steps") - model.fit( - train_generator, - # validation_data=eval_generator, - epochs=epochs, - steps_per_epoch=steps_per_epoch, - # callbacks=[MetricLogger()], - # validation_steps=training_cfg.num_eval_steps, - # validation_freq=1, - # verbose=0, - ) + model.predict(train_generator) + # model.fit( + # train_generator, + # # validation_data=eval_generator, + # epochs=epochs, + # steps_per_epoch=steps_per_epoch, + # # callbacks=[MetricLogger()], + # # validation_steps=training_cfg.num_eval_steps, + # # validation_freq=1, + # # verbose=0, + # ) # jax.profiler.stop_trace() logger.info("Training finished in %s seconds", time.perf_counter() - t0) From ff0625ffa68daf30778f47e902d1f4df8e03a93e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 08:16:35 +0530 Subject: [PATCH 274/279] Debug --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 54f93168..6a2070d6 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -229,7 +229,7 @@ def generator(dataset, training=False): logger.info("Training...") t0 = time.perf_counter() # jax.profiler.start_trace("/tmp/ml-perf-benchmarking/1000_steps") - model.predict(train_generator) + model.predict(train_generator, steps=steps_per_epoch) # model.fit( # train_generator, # # validation_data=eval_generator, From 36535dfdd32cf195a1cd796baa6229c94b4c4f63 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Thu, 30 Oct 2025 08:21:21 +0530 Subject: [PATCH 275/279] Debug --- examples/ml_perf/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 6a2070d6..79077c08 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -10,6 +10,7 @@ import keras_rs import jax +jax.config.update("jax_debug_nans", True) from jax.experimental import checkify from .dataloader import DataLoader From 27f9fd645dc15f7795dbe994da05a35a9d6edae9 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 31 Oct 2025 18:33:07 +0530 Subject: [PATCH 276/279] Add model.summary() --- examples/ml_perf/main.py | 215 +++++++++++------- .../embedding/jax/distributed_embedding.py | 1 - 2 files changed, 127 insertions(+), 89 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 79077c08..3989f12f 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -9,10 +9,8 @@ import keras import keras_rs -import jax -jax.config.update("jax_debug_nans", True) -from jax.experimental import checkify +# jax.config.update("jax_debug_nans", True) from .dataloader import DataLoader from .model import DLRMDCNV2 @@ -30,6 +28,83 @@ def on_train_batch_end(self, batch, logs=None): print("--->", logs["loss"]) +def _load_dataset( + model, + distribution, + ds_cfg, + training_cfg, + large_emb_features, + small_emb_features, + steps_per_epoch, + do_eval, + num_processes, +): + train_ds = DataLoader( + file_pattern=ds_cfg.file_pattern, + batch_size=training_cfg.global_batch_size, + file_batch_size=ds_cfg.get("file_batch_size", None), + dense_features=ds_cfg.dense, + large_emb_features=large_emb_features, + small_emb_features=small_emb_features, + label=ds_cfg.label, + num_steps=steps_per_epoch + 20, + training=True, + ).create_dataset( + process_id=distribution._process_id, + num_processes=num_processes, + shuffle_buffer=ds_cfg.get("shuffle_buffer", None), + ) + if do_eval: + eval_ds = DataLoader( + file_pattern=ds_cfg.val_file_pattern, + batch_size=training_cfg.global_batch_size, + file_batch_size=ds_cfg.get("file_batch_size", None), + dense_features=ds_cfg.dense, + large_emb_features=large_emb_features, + small_emb_features=small_emb_features, + label=ds_cfg.label, + num_steps=training_cfg.num_eval_steps, + repeat=True, + training=False, + ).create_dataset( + process_id=distribution._process_id, + num_processes=num_processes, + ) + # For the multi-host case, the dataset has to be distributed manually. + # See note here: + # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. + if num_processes > 1: + train_ds = distribution.distribute_dataset(train_ds) + if do_eval: + eval_ds = distribution.distribute_dataset(eval_ds) + distribution.auto_shard_dataset = False + + def generator(dataset, training=False): + """Converts tf.data Dataset to a Python generator and preprocesses + large embedding features. + """ + for features, labels in dataset: + preprocessed_large_embeddings = model.embedding_layer.preprocess( + features["large_emb_inputs"], training=training + ) + + x = { + "dense_input": features["dense_input"], + "large_emb_inputs": preprocessed_large_embeddings, + "small_emb_inputs": features["small_emb_inputs"], + } + y = labels + yield (x, y) + + logger.info("Preprocessing large embedding tables...") + train_gen = generator(train_ds, training=True) + if do_eval: + eval_gen = generator(eval_ds, training=False) + return train_gen, eval_gen + + return train_gen, None + + def main( ds_cfg, model_cfg, @@ -135,9 +210,6 @@ def main( ) logger.info("Initialised model: %s", model) - # === Load dataset === - logger.info("Loading dataset...") - # Keras does not have a straightforward way to log at a step-level instead # of epoch-level. So, we do a workaround here. if ds_cfg.val_file_pattern: @@ -148,99 +220,66 @@ def main( steps_per_epoch = training_cfg.num_steps epochs = 1 do_eval = False - logger.info(f"{steps_per_epoch=}, {epochs=}, {do_eval=}") - train_ds = DataLoader( - file_pattern=ds_cfg.file_pattern, - batch_size=training_cfg.global_batch_size, - file_batch_size=ds_cfg.get("file_batch_size", None), - dense_features=ds_cfg.dense, - large_emb_features=large_emb_features, - small_emb_features=small_emb_features, - label=ds_cfg.label, - num_steps=steps_per_epoch + 20, - training=True, - ).create_dataset( - process_id=distribution._process_id, - num_processes=num_processes, - shuffle_buffer=ds_cfg.get("shuffle_buffer", None), + # === Do one dummy forward pass on the model === + logger.info("Loading dummy dataset...") + dummy_gen, _ = _load_dataset( + model, + distribution, + ds_cfg, + training_cfg, + large_emb_features, + small_emb_features, + steps_per_epoch, + do_eval, + num_processes, ) - if do_eval: - eval_ds = DataLoader( - file_pattern=ds_cfg.val_file_pattern, - batch_size=training_cfg.global_batch_size, - file_batch_size=ds_cfg.get("file_batch_size", None), - dense_features=ds_cfg.dense, - large_emb_features=large_emb_features, - small_emb_features=small_emb_features, - label=ds_cfg.label, - num_steps=training_cfg.num_eval_steps, - repeat=True, - training=False, - ).create_dataset( - process_id=distribution._process_id, - num_processes=num_processes, - ) - # For the multi-host case, the dataset has to be distributed manually. - # See note here: - # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. - if num_processes > 1: - train_ds = distribution.distribute_dataset(train_ds) - if do_eval: - eval_ds = distribution.distribute_dataset(eval_ds) - distribution.auto_shard_dataset = False - def generator(dataset, training=False): - """Converts tf.data Dataset to a Python generator and preprocesses - large embedding features. - """ - for features, labels in dataset: - preprocessed_large_embeddings = model.embedding_layer.preprocess( - features["large_emb_inputs"], training=training - ) + logger.debug("Inspecting one batch of data...") + for first_batch in dummy_gen: + logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) + logger.debug( + "Small embedding inputs:%s", + first_batch[0]["small_emb_inputs"]["25_id"], + ) + logger.debug( + "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] + ) + break + logger.info("Successfully preprocessed one batch of data") - x = { - "dense_input": features["dense_input"], - "large_emb_inputs": preprocessed_large_embeddings, - "small_emb_inputs": features["small_emb_inputs"], - } - y = labels - yield (x, y) + logger.info("Doing one step of forward pass on the model...") + model.predict(dummy_gen, steps=steps_per_epoch) + logger.info("Model summary: %s", model.summary()) - logger.info("Preprocessing large embedding tables...") - train_generator = generator(train_ds, training=True) - if do_eval: - eval_generator = generator(eval_ds, training=False) - - # logger.debug("Inspecting one batch of data...") - # for first_batch in train_generator: - # logger.debug("Dense inputs:%s", first_batch[0]["dense_input"]) - # logger.debug( - # "Small embedding inputs:%s", - # first_batch[0]["small_emb_inputs"]["25_id"], - # ) - # logger.debug( - # "Large embedding inputs:%s", first_batch[0]["large_emb_inputs"] - # ) - # break - logger.info("Successfully preprocessed one batch of data") + # === Load dataset === + train_gen, eval_gen = _load_dataset( + model, + distribution, + ds_cfg, + training_cfg, + large_emb_features, + small_emb_features, + steps_per_epoch, + do_eval, + num_processes, + ) # === Training === logger.info("Training...") t0 = time.perf_counter() # jax.profiler.start_trace("/tmp/ml-perf-benchmarking/1000_steps") - model.predict(train_generator, steps=steps_per_epoch) - # model.fit( - # train_generator, - # # validation_data=eval_generator, - # epochs=epochs, - # steps_per_epoch=steps_per_epoch, - # # callbacks=[MetricLogger()], - # # validation_steps=training_cfg.num_eval_steps, - # # validation_freq=1, - # # verbose=0, - # ) + model.fit( + train_gen, + validation_data=eval_gen, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + # callbacks=[MetricLogger()], + # validation_steps=training_cfg.num_eval_steps, + # validation_freq=1, + # verbose=0, + ) # jax.profiler.stop_trace() logger.info("Training finished in %s seconds", time.perf_counter() - t0) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index a9cac408..80347b3d 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -6,7 +6,6 @@ from typing import Any, Mapping, Sequence, Union import jax -from jax.experimental import multihost_utils import keras import numpy as np from jax import numpy as jnp From 7c3c3e68a9f4352d25aaab7a3c9eaa70e5a12b0f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 31 Oct 2025 18:43:28 +0530 Subject: [PATCH 277/279] Add model.summary() --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 3989f12f..8cead985 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -250,7 +250,7 @@ def main( logger.info("Successfully preprocessed one batch of data") logger.info("Doing one step of forward pass on the model...") - model.predict(dummy_gen, steps=steps_per_epoch) + model.predict(dummy_gen, steps=1) logger.info("Model summary: %s", model.summary()) # === Load dataset === From e1187f0e058ed61bb54562c0652435dc0420362f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 31 Oct 2025 18:54:17 +0530 Subject: [PATCH 278/279] Workaround --- examples/ml_perf/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 8cead985..565cadd5 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -250,7 +250,10 @@ def main( logger.info("Successfully preprocessed one batch of data") logger.info("Doing one step of forward pass on the model...") - model.predict(dummy_gen, steps=1) + # TODO: Use model.predict() directly. For some reason, it is not working, + # currently. + for batch in dummy_generator: + model.predict_on_batch(batch[0]) logger.info("Model summary: %s", model.summary()) # === Load dataset === From 260e5b1338bc7633a949c4c75e3f5eb49a0faa9c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Fri, 31 Oct 2025 18:56:56 +0530 Subject: [PATCH 279/279] Workaround --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 565cadd5..eee2e6a9 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -252,7 +252,7 @@ def main( logger.info("Doing one step of forward pass on the model...") # TODO: Use model.predict() directly. For some reason, it is not working, # currently. - for batch in dummy_generator: + for batch in dummy_gen: model.predict_on_batch(batch[0]) logger.info("Model summary: %s", model.summary())