16 changes: 16 additions & 0 deletions torchrec/distributed/itep_embeddingbag.py
@@ -177,6 +177,10 @@ def compute(
         ctx: ITEPEmbeddingBagCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
+        # We need to explicitly move iter to CPU since it might be moved to GPU
+        # after __init__. This should be done once.
+        self._iter = self._iter.cpu()
+
         if not ctx.is_reindexed:
             dist_input = self._reindex(dist_input)
             ctx.is_reindexed = True
@@ -196,6 +200,10 @@ def output_dist(
     def compute_and_output_dist(
         self, ctx: ITEPEmbeddingBagCollectionContext, input: KJTList
     ) -> LazyAwaitable[KeyedTensor]:
+        # We need to explicitly move iter to CPU since it might be moved to GPU
+        # after __init__. This should be done once.
+        self._iter = self._iter.cpu()
+
         # Insert forward() function of GenericITEPModule into compute_and_output_dist()
         for i, (sharding, features) in enumerate(
             zip(
@@ -424,6 +432,10 @@ def compute(
         ctx: ITEPEmbeddingCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
+        # We need to explicitly move iter to CPU since it might be moved to GPU
+        # after __init__. This should be done once.
+        self._iter = self._iter.cpu()
+
         for i, (sharding, features) in enumerate(
             zip(
                 self._embedding_collection._sharding_type_to_sharding.keys(),
@@ -450,6 +462,10 @@ def output_dist(
     def compute_and_output_dist(
         self, ctx: ITEPEmbeddingCollectionContext, input: KJTList
     ) -> LazyAwaitable[Dict[str, JaggedTensor]]:
+        # We need to explicitly move iter to CPU since it might be moved to GPU
+        # after __init__. This should be done once.
+        self._iter = self._iter.cpu()
+
         # Insert forward() function of GenericITEPModule into compute_and_output_dist()
         """ """
         for i, (sharding, features) in enumerate(
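All four hunks above repeat the same pattern: `_iter` is a counter held as a registered buffer, and a buffer registered on CPU in `__init__` can be moved to GPU when the wrapping module is later moved or sharded, so each compute path pins it back to CPU before using it. A minimal, self-contained sketch of that pattern (the `Counter` module below is a hypothetical stand-in, not torchrec code):

import torch

class Counter(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Registered buffers follow the module across .to()/.cuda() calls.
        self.register_buffer("_iter", torch.tensor(0, dtype=torch.int64))

    def forward(self) -> int:
        # Pin the counter back to CPU; effectively a no-op when it is already there.
        self._iter = self._iter.cpu()
        self._iter += 1
        return int(self._iter.item())

m = Counter()
if torch.cuda.is_available():
    m = m.cuda()  # moves the _iter buffer to GPU ...
print(m())        # ... but forward() brings it back to CPU before incrementing

Since `.cpu()` returns the same tensor when it already lives on CPU, repeating the call on every compute() is cheap even though the intent, per the comment, is a one-time fix-up.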
1 change: 1 addition & 0 deletions torchrec/modules/itep_embedding_modules.py
@@ -79,6 +79,7 @@ def forward(
 
         features = self._itep_module(features, self._iter.item())
         pooled_embeddings = self._embedding_bag_collection(features)
+
         self._iter += 1
 
         return pooled_embeddings
7 changes: 4 additions & 3 deletions torchrec/modules/itep_modules.py
@@ -514,13 +514,13 @@ def forward(
             feature_offsets,
         ) = self.get_remap_info(sparse_features)
 
-        update_utils: bool = (
+        update_util: bool = (
             (cur_iter < 10)
             or (cur_iter < 100 and (cur_iter + 1) % 19 == 0)
             or ((cur_iter + 1) % 39 == 0)
         )
         full_values_list = None
-        if update_utils and sparse_features.variable_stride_per_key():
+        if update_util and sparse_features.variable_stride_per_key():
             if sparse_features.inverse_indices_or_none() is not None:
                 # full util update mode require reconstructing original input indicies from VBE input
                 full_values_list = self.get_full_values_list(sparse_features)
@@ -531,7 +531,7 @@
                )
 
         remapped_values = torch.ops.fbgemm.remap_indices_update_utils(
-            cur_iter,
+            int(cur_iter),
             buffer_idx,
             feature_lengths,
             feature_offsets,
@@ -540,6 +540,7 @@
             self.row_util,
             self.buffer_offsets,
             full_values_list=full_values_list,
+            update_util=update_util,
         )
 
         sparse_features._values = remapped_values
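The renamed `update_util` flag gates how often row utilization is refreshed: on every one of the first 10 iterations, then on iterations where `(cur_iter + 1)` is a multiple of 19 (while `cur_iter < 100`), and afterwards where it is a multiple of 39; the same flag is now also forwarded to the fbgemm remap op. A standalone sketch of that cadence (plain Python, independent of torchrec and fbgemm; the helper name is illustrative only):

def should_update_util(cur_iter: int) -> bool:
    # Same predicate as in the diff above, lifted into a helper for illustration.
    return (
        (cur_iter < 10)
        or (cur_iter < 100 and (cur_iter + 1) % 19 == 0)
        or ((cur_iter + 1) % 39 == 0)
    )

# Iterations 0-199 that would trigger a util update:
# 0-9, then 18, 37, 38, 56, 75, 77, 94, 116, 155, 194
print([i for i in range(200) if should_update_util(i)])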