@@ -177,6 +177,10 @@ def compute(
177177 ctx : ITEPEmbeddingBagCollectionContext ,
178178 dist_input : KJTList ,
179179 ) -> List [torch .Tensor ]:
180+ # Explicitly move _iter to CPU, since it might have been moved to GPU after
181+ # __init__. Runs on every call; presumably a no-op once already on CPU (confirm).
182+ self ._iter = self ._iter .cpu ()
183+
180184 if not ctx .is_reindexed :
181185 dist_input = self ._reindex (dist_input )
182186 ctx .is_reindexed = True
@@ -196,6 +200,10 @@ def output_dist(
196200 def compute_and_output_dist (
197201 self , ctx : ITEPEmbeddingBagCollectionContext , input : KJTList
198202 ) -> LazyAwaitable [KeyedTensor ]:
203+ # Explicitly move _iter to CPU, since it might have been moved to GPU after
204+ # __init__. Runs on every call; presumably a no-op once already on CPU (confirm).
205+ self ._iter = self ._iter .cpu ()
206+
199207 # Insert forward() function of GenericITEPModule into compute_and_output_dist()
200208 for i , (sharding , features ) in enumerate (
201209 zip (
@@ -424,6 +432,10 @@ def compute(
424432 ctx : ITEPEmbeddingCollectionContext ,
425433 dist_input : KJTList ,
426434 ) -> List [torch .Tensor ]:
435+ # Explicitly move _iter to CPU, since it might have been moved to GPU after
436+ # __init__. Runs on every call; presumably a no-op once already on CPU (confirm).
437+ self ._iter = self ._iter .cpu ()
438+
427439 for i , (sharding , features ) in enumerate (
428440 zip (
429441 self ._embedding_collection ._sharding_type_to_sharding .keys (),
@@ -450,6 +462,10 @@ def output_dist(
450462 def compute_and_output_dist (
451463 self , ctx : ITEPEmbeddingCollectionContext , input : KJTList
452464 ) -> LazyAwaitable [Dict [str , JaggedTensor ]]:
465+ # Explicitly move _iter to CPU, since it might have been moved to GPU after
466+ # __init__. Runs on every call; presumably a no-op once already on CPU (confirm).
467+ self ._iter = self ._iter .cpu ()
468+
453469 # Insert forward() function of GenericITEPModule into compute_and_output_dist()
454470 """ """
455471 for i , (sharding , features ) in enumerate (
0 commit comments