Skip to content

Commit fc91f77

Browse files
authored
Updated documentation regarding DistributedEmbedding batch size. (#138)
The batch size represented by the first dimension of `input_shape` and `output_shape` in `FeatureConfig` is the global batch size.
1 parent 73331a6 commit fc91f77

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,14 @@ class DistributedEmbedding(keras.layers.Layer):
146146
feature1 = keras_rs.layers.FeatureConfig(
147147
name="feature1",
148148
table=table1,
149-
input_shape=(PER_REPLICA_BATCH_SIZE,),
150-
output_shape=(PER_REPLICA_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
149+
input_shape=(GLOBAL_BATCH_SIZE,),
150+
output_shape=(GLOBAL_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
151151
)
152152
feature2 = keras_rs.layers.FeatureConfig(
153153
name="feature2",
154154
table=table2,
155-
input_shape=(PER_REPLICA_BATCH_SIZE,),
156-
output_shape=(PER_REPLICA_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
155+
input_shape=(GLOBAL_BATCH_SIZE,),
156+
output_shape=(GLOBAL_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
157157
)
158158
159159
feature_configs = {

keras_rs/src/layers/embedding/distributed_embedding_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ class FeatureConfig:
102102
input_shape: The input shape of the feature. The feature fed into the
103103
layer has to match the shape. Note that for ragged dimensions in the
104104
input, the dimension provided here presents the maximum value;
105-
anything larger will be truncated.
105+
anything larger will be truncated. Also note that the first
106+
dimension represents the global batch size. For example, on TPU,
107+
this represents the total number of samples that are dispatched to
108+
all the TPUs connected to the current host.
106109
output_shape: The output shape of the feature activation. What is
107110
returned by the embedding layer has to match this shape.
108111
"""

0 commit comments

Comments
 (0)