@@ -56,6 +56,7 @@
 def translate_keras_rs_configuration(
     feature_configs: types.Nested[FeatureConfig],
     table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
+    num_replicas_in_sync: int,
 ) -> tuple[
     types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
     tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig,
@@ -72,7 +73,10 @@ def translate_keras_rs_configuration(
7273 """
7374 tables : dict [TableConfig , tf .tpu .experimental .embedding .TableConfig ] = {}
7475 feature_configs = keras .tree .map_structure (
75- lambda f : translate_keras_rs_feature_config (f , tables ), feature_configs
76+ lambda f : translate_keras_rs_feature_config (
77+ f , tables , num_replicas_in_sync
78+ ),
79+ feature_configs ,
7680 )
7781
7882 # max_ids_per_chip_per_sample
@@ -107,6 +111,7 @@ def translate_keras_rs_configuration(
 def translate_keras_rs_feature_config(
     feature_config: FeatureConfig,
     tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
+    num_replicas_in_sync: int,
 ) -> tf.tpu.experimental.embedding.FeatureConfig:
     """Translates a Keras RS feature config to a TensorFlow TPU feature config.

@@ -120,18 +125,46 @@ def translate_keras_rs_feature_config(
     Returns:
         The TensorFlow TPU feature config.
     """
+    if num_replicas_in_sync <= 0:
+        raise ValueError(
+            "`num_replicas_in_sync` must be positive, "
+            f"but got {num_replicas_in_sync}."
+        )
+
     table = tables.get(feature_config.table, None)
     if table is None:
         table = translate_keras_rs_table_config(feature_config.table)
         tables[feature_config.table] = table

+    if len(feature_config.output_shape) < 2:
+        raise ValueError(
+            f"Invalid `output_shape` {feature_config.output_shape} in "
+            f"`FeatureConfig` {feature_config}. It must have at least 2 "
+            "dimensions: a batch dimension and an embedding dimension."
+        )
+
+    # Exclude the last dimension; TensorFlow's TPUEmbedding doesn't want it.
+    output_shape = list(feature_config.output_shape[0:-1])
+
+    batch_size = output_shape[0]
+    per_replica_batch_size: int | None = None
+    if batch_size is not None:
+        if batch_size % num_replicas_in_sync != 0:
+            raise ValueError(
+                f"Invalid `output_shape` {feature_config.output_shape} in "
+                f"`FeatureConfig` {feature_config}. Batch size {batch_size} is "
+                f"not a multiple of the number of TPUs {num_replicas_in_sync}."
+            )
+        per_replica_batch_size = batch_size // num_replicas_in_sync
+
+    # TensorFlow's TPUEmbedding wants the per-replica batch size.
+    output_shape = [per_replica_batch_size] + output_shape[1:]
+
     # max_sequence_length
     return tf.tpu.experimental.embedding.FeatureConfig(
         name=feature_config.name,
         table=table,
-        output_shape=feature_config.output_shape[
-            0:-1
-        ],  # exclude last dimension
+        output_shape=output_shape,
     )

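For reference, a minimal standalone sketch of the per-replica batch-size arithmetic added above, with no keras_rs or TensorFlow dependency. The helper name `split_batch_over_replicas` and the bare-tuple `output_shape` argument are illustrative only, not part of the library.

```python
from typing import Optional, Sequence


def split_batch_over_replicas(
    output_shape: Sequence[Optional[int]], num_replicas_in_sync: int
) -> list[Optional[int]]:
    """Mirrors the per-replica output_shape logic in the diff above (sketch)."""
    if num_replicas_in_sync <= 0:
        raise ValueError(
            f"`num_replicas_in_sync` must be positive, got {num_replicas_in_sync}."
        )
    if len(output_shape) < 2:
        raise ValueError(
            f"`output_shape` {output_shape} needs a batch and an embedding dim."
        )
    # Drop the trailing embedding dimension, as TPUEmbedding expects.
    shape = list(output_shape[:-1])
    batch_size = shape[0]
    if batch_size is None:
        return shape  # An unknown batch size stays unknown per replica.
    if batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            f"Batch size {batch_size} is not a multiple of {num_replicas_in_sync}."
        )
    return [batch_size // num_replicas_in_sync] + shape[1:]


# e.g. a global batch of 32 split across 4 replicas, embedding dim 16:
assert split_batch_over_replicas((32, 16), 4) == [8]
assert split_batch_over_replicas((None, 10, 16), 4) == [None, 10]
```

In calling code, `num_replicas_in_sync` would typically come from the distribution strategy, e.g. `tf.distribute.TPUStrategy(...).num_replicas_in_sync`.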