@@ -47,6 +47,12 @@ def experimental_distribute_dataset(self, dataset, options=None):
         return dataset


+class JaxDummyStrategy(DummyStrategy):
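+    """Dummy strategy for JAX on TPU: one replica per TPU device."""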
+    @property
+    def num_replicas_in_sync(self):
+        return len(jax.devices("tpu"))
+
+
 class DistributedEmbeddingTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
         super().setUp()
@@ -80,9 +86,15 @@ def setUp(self):
             )
             print("### num_replicas", self._strategy.num_replicas_in_sync)
             self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver)
+        elif keras.backend.backend() == "jax" and self.on_tpu:
+            self._strategy = JaxDummyStrategy()
         else:
             self._strategy = DummyStrategy()

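+        # Global batch size: per-core batch size times the replica count.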
+        self.batch_size = (
+            BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync
+        )
+
     def run_with_strategy(self, fn, *args, jit_compile=False):
         """Wrapper for running a function under a strategy."""

@@ -120,31 +132,31 @@ def get_embedding_config(self, input_type, placement):
         feature_group["feature1"] = config.FeatureConfig(
             name="feature1",
             table=feature1_table,
-            input_shape=(BATCH_SIZE_PER_CORE, sequence_length),
-            output_shape=(BATCH_SIZE_PER_CORE, FEATURE1_EMBEDDING_OUTPUT_DIM),
+            input_shape=(self.batch_size, sequence_length),
+            output_shape=(self.batch_size, FEATURE1_EMBEDDING_OUTPUT_DIM),
         )
         feature_group["feature2"] = config.FeatureConfig(
             name="feature2",
             table=feature2_table,
-            input_shape=(BATCH_SIZE_PER_CORE, sequence_length),
-            output_shape=(BATCH_SIZE_PER_CORE, FEATURE2_EMBEDDING_OUTPUT_DIM),
+            input_shape=(self.batch_size, sequence_length),
+            output_shape=(self.batch_size, FEATURE2_EMBEDDING_OUTPUT_DIM),
         )
         return {"feature_group": feature_group}

     def create_inputs_weights_and_labels(
-        self, batch_size, input_type, feature_configs, backend=None
+        self, input_type, feature_configs, backend=None
     ):
         backend = backend or keras.backend.backend()

         if input_type == "dense":

             def create_tensor(feature_config, op):
-                sequence_length = feature_config.input_shape[-1]
-                return op((batch_size, sequence_length))
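+                # input_shape already includes the global batch size.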
+                return op(feature_config.input_shape)

         elif input_type == "ragged":

             def create_tensor(feature_config, op):
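+                # Derive the batch size from the configured input shape.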
+                batch_size = feature_config.input_shape[0]
                 sequence_length = feature_config.input_shape[-1]
                 row_lengths = [
                     1 + (i % sequence_length) for i in range(batch_size)
@@ -157,6 +169,7 @@ def create_tensor(feature_config, op):
         elif input_type == "sparse" and backend == "tensorflow":

             def create_tensor(feature_config, op):
+                batch_size = feature_config.input_shape[0]
                 sequence_length = feature_config.input_shape[-1]
                 indices = [[i, i % sequence_length] for i in range(batch_size)]
                 return tf.sparse.reorder(
@@ -170,6 +183,7 @@ def create_tensor(feature_config, op):
         elif input_type == "sparse" and backend == "jax":

             def create_tensor(feature_config, op):
+                batch_size = feature_config.input_shape[0]
                 sequence_length = feature_config.input_shape[-1]
                 indices = [[i, i % sequence_length] for i in range(batch_size)]
                 return jax_sparse.BCOO(
@@ -197,9 +211,7 @@ def create_tensor(feature_config, op):
             feature_configs,
         )
         labels = keras.tree.map_structure(
-            lambda fc: np.ones(
-                (batch_size,) + fc.output_shape[1:], dtype=np.float32
-            ),
+            lambda fc: np.ones(fc.output_shape, dtype=np.float32),
             feature_configs,
         )
         return inputs, weights, labels
@@ -228,10 +240,9 @@ def test_basics(self, input_type, placement):
         ):
             self.skipTest("Ragged and sparse are not compilable on TPU.")

-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         feature_configs = self.get_embedding_config(input_type, placement)
         inputs, weights, _ = self.create_inputs_weights_and_labels(
-            batch_size, input_type, feature_configs
+            input_type, feature_configs
         )

         if placement == "sparsecore" and not self.on_tpu:
@@ -264,11 +275,11 @@ def test_basics(self, input_type, placement):

         self.assertEqual(
             res["feature_group"]["feature1"].shape,
-            (batch_size, FEATURE1_EMBEDDING_OUTPUT_DIM),
+            (self.batch_size, FEATURE1_EMBEDDING_OUTPUT_DIM),
         )
         self.assertEqual(
             res["feature_group"]["feature2"].shape,
-            (batch_size, FEATURE2_EMBEDDING_OUTPUT_DIM),
+            (self.batch_size, FEATURE2_EMBEDDING_OUTPUT_DIM),
         )

     @parameterized.named_parameters(
@@ -289,16 +300,15 @@ def test_model_fit(self, input_type, use_weights):
                 f"{input_type} not supported on {keras.backend.backend()}"
             )

-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         feature_configs = self.get_embedding_config(input_type, self.placement)
         train_inputs, train_weights, train_labels = (
             self.create_inputs_weights_and_labels(
-                batch_size, input_type, feature_configs, backend="tensorflow"
+                input_type, feature_configs, backend="tensorflow"
             )
         )
         test_inputs, test_weights, test_labels = (
             self.create_inputs_weights_and_labels(
-                batch_size, input_type, feature_configs, backend="tensorflow"
+                input_type, feature_configs, backend="tensorflow"
             )
         )

@@ -482,12 +492,11 @@ def test_correctness(
         feature_config = config.FeatureConfig(
             name="feature",
             table=table,
-            input_shape=(BATCH_SIZE_PER_CORE, sequence_length),
-            output_shape=(BATCH_SIZE_PER_CORE, EMBEDDING_OUTPUT_DIM),
+            input_shape=(self.batch_size, sequence_length),
+            output_shape=(self.batch_size, EMBEDDING_OUTPUT_DIM),
         )

-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
-        num_repeats = batch_size // 2
+        num_repeats = self.batch_size // 2
         if input_type == "dense" and input_rank == 1:
             inputs = keras.ops.convert_to_tensor([2, 3] * num_repeats)
             weights = keras.ops.convert_to_tensor([1.0, 2.0] * num_repeats)
@@ -512,14 +521,14 @@ def test_correctness(
                     tf.SparseTensor(
                         indices,
                         [1, 2, 3, 4, 5] * num_repeats,
-                        dense_shape=(batch_size, 4),
+                        dense_shape=(self.batch_size, 4),
                     )
                 )
                 weights = tf.sparse.reorder(
                     tf.SparseTensor(
                         indices,
                         [1.0, 1.0, 2.0, 3.0, 4.0] * num_repeats,
-                        dense_shape=(batch_size, 4),
+                        dense_shape=(self.batch_size, 4),
                     )
                 )
             elif keras.backend.backend() == "jax":
@@ -528,15 +537,15 @@ def test_correctness(
                         jnp.asarray([1, 2, 3, 4, 5] * num_repeats),
                         jnp.asarray(indices),
                     ),
-                    shape=(batch_size, 4),
+                    shape=(self.batch_size, 4),
                     unique_indices=True,
                 )
                 weights = jax_sparse.BCOO(
                     (
                         jnp.asarray([1.0, 1.0, 2.0, 3.0, 4.0] * num_repeats),
                         jnp.asarray(indices),
                     ),
-                    shape=(batch_size, 4),
+                    shape=(self.batch_size, 4),
                     unique_indices=True,
                 )
             else:
@@ -600,7 +609,7 @@ def test_correctness(
             layer.__call__, inputs, weights, jit_compile=jit_compile
         )

-        self.assertEqual(res.shape, (batch_size, EMBEDDING_OUTPUT_DIM))
+        self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))

         tables = layer.get_embedding_tables()
         emb = tables["table"]
@@ -644,26 +653,25 @@ def test_shared_table(self):
             "feature1": config.FeatureConfig(
                 name="feature1",
                 table=table1,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, EMBEDDING_OUTPUT_DIM),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, EMBEDDING_OUTPUT_DIM),
             ),
             "feature2": config.FeatureConfig(
                 name="feature2",
                 table=table1,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, EMBEDDING_OUTPUT_DIM),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, EMBEDDING_OUTPUT_DIM),
             ),
             "feature3": config.FeatureConfig(
                 name="feature3",
                 table=table1,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, EMBEDDING_OUTPUT_DIM),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, EMBEDDING_OUTPUT_DIM),
             ),
         }

-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         inputs, _, _ = self.create_inputs_weights_and_labels(
-            batch_size, "dense", embedding_config
+            "dense", embedding_config
         )

         with self._strategy.scope():
@@ -676,13 +684,13 @@ def test_shared_table(self):

         self.assertLen(layer.trainable_variables, 1)
         self.assertEqual(
-            res["feature1"].shape, (batch_size, EMBEDDING_OUTPUT_DIM)
+            res["feature1"].shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)
         )
         self.assertEqual(
-            res["feature2"].shape, (batch_size, EMBEDDING_OUTPUT_DIM)
+            res["feature2"].shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)
         )
         self.assertEqual(
-            res["feature3"].shape, (batch_size, EMBEDDING_OUTPUT_DIM)
+            res["feature3"].shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)
         )

     def test_mixed_placement(self):
@@ -719,26 +727,25 @@ def test_mixed_placement(self):
             "feature1": config.FeatureConfig(
                 name="feature1",
                 table=table1,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, embedding_output_dim1),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, embedding_output_dim1),
             ),
             "feature2": config.FeatureConfig(
                 name="feature2",
                 table=table2,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, embedding_output_dim2),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, embedding_output_dim2),
             ),
             "feature3": config.FeatureConfig(
                 name="feature3",
                 table=table3,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, embedding_output_dim3),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, embedding_output_dim3),
             ),
         }

-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         inputs, _, _ = self.create_inputs_weights_and_labels(
-            batch_size, "dense", embedding_config
+            "dense", embedding_config
         )

         with self._strategy.scope():
@@ -747,20 +754,19 @@ def test_mixed_placement(self):

             res = self.run_with_strategy(layer.__call__, inputs)
         self.assertEqual(
-            res["feature1"].shape, (batch_size, embedding_output_dim1)
+            res["feature1"].shape, (self.batch_size, embedding_output_dim1)
         )
         self.assertEqual(
-            res["feature2"].shape, (batch_size, embedding_output_dim2)
+            res["feature2"].shape, (self.batch_size, embedding_output_dim2)
         )
         self.assertEqual(
-            res["feature3"].shape, (batch_size, embedding_output_dim3)
+            res["feature3"].shape, (self.batch_size, embedding_output_dim3)
         )

     def test_save_load_model(self):
-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         feature_configs = self.get_embedding_config("dense", self.placement)
         inputs, _, _ = self.create_inputs_weights_and_labels(
-            batch_size, "dense", feature_configs
+            "dense", feature_configs
         )

         keras_inputs = keras.tree.map_structure(