Skip to content

Commit d57e334

Browse files
nipung90 and facebook-github-bot
authored and committed
Enable logging for the plan() function, ShardEstimators and TrainingPipeline class constructors
Differential Revision: D87488015
1 parent 32e5431 commit d57e334

File tree

6 files changed

+16
-0
lines changed

6 files changed

+16
-0
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@
5555
FUSED_PARAM_IS_SSD_TABLE,
5656
FUSED_PARAM_SSD_TABLE_LIST,
5757
)
58+
from torchrec.distributed.logger import _torchrec_method_logger
5859
from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
5960
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
6061
from torchrec.distributed.sharding.dynamic_sharding import (
@@ -466,6 +467,7 @@ class ShardedEmbeddingBagCollection(
466467
This is part of the public API to allow for manual data dist pipelining.
467468
"""
468469

470+
@_torchrec_method_logger()
469471
def __init__(
470472
self,
471473
module: EmbeddingBagCollectionInterface,
@@ -2021,6 +2023,7 @@ class ShardedEmbeddingBag(
20212023
This is part of the public API to allow for manual data dist pipelining.
20222024
"""
20232025

2026+
@_torchrec_method_logger()
20242027
def __init__(
20252028
self,
20262029
module: nn.EmbeddingBag,

torchrec/distributed/planner/planners.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from torch import nn
1919
from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result
2020
from torchrec.distributed.comm import get_local_size
21+
from torchrec.distributed.logger import _torchrec_method_logger
2122
from torchrec.distributed.planner.constants import BATCH_SIZE, MAX_SIZE
2223
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
2324
from torchrec.distributed.planner.partitioners import (
@@ -498,6 +499,7 @@ def collective_plan(
498499
sharders,
499500
)
500501

502+
@_torchrec_method_logger()
501503
def plan(
502504
self,
503505
module: nn.Module,

torchrec/distributed/planner/shard_estimators.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
import torchrec.optim as trec_optim
1717
from torch import nn
1818
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
19+
from torchrec.distributed.logger import _torchrec_method_logger
1920
from torchrec.distributed.planner.constants import (
2021
BATCHED_COPY_PERF_FACTOR,
2122
BIGINT_DTYPE,
@@ -955,6 +956,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
955956
is_inference (bool): If the model is inference model. Default to False.
956957
"""
957958

959+
@_torchrec_method_logger()
958960
def __init__(
959961
self,
960962
topology: Topology,

torchrec/distributed/shard.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from torch.distributed._composable.contract import contract
1616
from torchrec.distributed.comm import get_local_size
1717
from torchrec.distributed.global_settings import get_propogate_device
18+
from torchrec.distributed.logger import _torchrec_method_logger
1819
from torchrec.distributed.model_parallel import get_default_sharders
1920
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
2021
from torchrec.distributed.sharding_plan import (
@@ -146,6 +147,7 @@ def _shard(
146147

147148
# pyre-ignore
148149
@contract()
150+
@_torchrec_method_logger()
149151
def shard_modules(
150152
module: nn.Module,
151153
env: Optional[ShardingEnv] = None,
@@ -194,6 +196,7 @@ def init_weights(m):
194196
return _shard_modules(module, env, device, plan, sharders, init_params)
195197

196198

199+
@_torchrec_method_logger()
197200
def _shard_modules( # noqa: C901
198201
module: nn.Module,
199202
# TODO: Consolidate to using Dict[str, ShardingEnv]

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
import torch
3333
from torch.autograd.profiler import record_function
3434
from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
35+
from torchrec.distributed.logger import _torchrec_method_logger
3536
from torchrec.distributed.model_parallel import ShardedModule
3637
from torchrec.distributed.train_pipeline.pipeline_context import (
3738
EmbeddingTrainPipelineContext,
@@ -106,6 +107,8 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
106107
def progress(self, dataloader_iter: Iterator[In]) -> Out:
107108
pass
108109

110+
# pyre-ignore [56]
111+
@_torchrec_method_logger()
109112
def __init__(self) -> None:
110113
# pipeline state such as in forward, in backward etc, used in training recover scenarios
111114
self._state: PipelineState = PipelineState.IDLE

torchrec/modules/mc_embedding_modules.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212

1313
import torch
1414
import torch.nn as nn
15+
from torchrec.distributed.logger import _torchrec_method_logger
1516

1617
from torchrec.modules.embedding_modules import (
1718
EmbeddingBagCollection,
@@ -125,6 +126,7 @@ class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollectio
125126
126127
"""
127128

129+
@_torchrec_method_logger()
128130
def __init__(
129131
self,
130132
embedding_collection: EmbeddingCollection,
@@ -164,6 +166,7 @@ class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollec
164166
165167
"""
166168

169+
@_torchrec_method_logger()
167170
def __init__(
168171
self,
169172
embedding_bag_collection: EmbeddingBagCollection,

0 commit comments

Comments
 (0)