
Commit 8f59580

TroyGarden authored and meta-codesync[bot] committed
model config for test_utils and benchmark (#3450)
Summary:
Pull Request resolved: #3450

# context
* move benchmark/test model constructors to test_utils.model_config.py
* add EMO (embedding offloading) yaml config for benchmark
* result [trace]() shows LRU caching {F1982639616}

Reviewed By: spmex

Differential Revision: D84325828

fbshipit-source-id: 0e1cb4e224c5e9bd68bc66604dee1dd39b3fd296
1 parent 1371bfc commit 8f59580

File tree

6 files changed, +55 −34 lines changed


torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 2 additions & 2 deletions

@@ -34,12 +34,12 @@
     CPUMemoryStats,
     GPUMemoryStats,
 )
-from torchrec.distributed.benchmark.benchmark_utils import (
+from torchrec.distributed.test_utils.input_config import ModelInputConfig
+from torchrec.distributed.test_utils.model_config import (
     BaseModelConfig,
     create_model_config,
     generate_sharded_model_and_optimizer,
 )
-from torchrec.distributed.test_utils.input_config import ModelInputConfig
 from torchrec.distributed.test_utils.model_input import ModelInput
 
 from torchrec.distributed.test_utils.multi_process import (
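For orientation, a minimal usage sketch of the relocated helpers follows. Only the import paths and names are taken from the diff above; the keyword arguments passed to create_model_config (model_name, num_float_features) are illustrative assumptions, not the verified signature.

# Sketch only: import paths come from the diff above; the create_model_config
# arguments are assumed for illustration.
from torchrec.distributed.test_utils.input_config import ModelInputConfig
from torchrec.distributed.test_utils.model_config import (
    BaseModelConfig,
    create_model_config,
    generate_sharded_model_and_optimizer,
)

model_config: BaseModelConfig = create_model_config(
    model_name="test_sparse_nn",  # assumed selector value
    num_float_features=10,        # field defined on BaseModelConfig (see model_config.py below)
)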

torchrec/distributed/benchmark/embedding_collection_wrappers.py

Lines changed: 7 additions & 8 deletions

@@ -36,6 +36,13 @@
 import torch
 from torch import multiprocessing as mp
 from torchrec.distributed import DistributedModelParallel
+
+from torchrec.distributed.benchmark.base import (
+    benchmark_model_with_warmup,
+    BenchmarkResult,
+    CompileMode,
+    multi_process_benchmark,
+)
 from torchrec.distributed.embedding_types import ShardingType
 from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

@@ -56,14 +63,6 @@
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
 
-# Import the shared types and utilities from benchmark_utils
-from .base import (
-    benchmark_model_with_warmup,
-    BenchmarkResult,
-    CompileMode,
-    multi_process_benchmark,
-)
-
 logger: logging.Logger = logging.getLogger()
 
 T = TypeVar("T", bound=torch.nn.Module)

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+# this is a very basic sparse data dist config
+# runs on 2 ranks, showing traces with reasonable workloads
+RunOptions:
+  world_size: 2
+  num_batches: 5
+  num_benchmarks: 2
+  sharding_type: table_wise
+  profile_dir: "."
+  name: "sparse_data_dist_base"
+  # export_stacks: True # enable this to export stack traces
+PipelineConfig:
+  pipeline: "sparse"
+EmbeddingTablesConfig:
+  num_unweighted_features: 100
+  num_weighted_features: 100
+  embedding_feature_dim: 256
+  additional_tables:
+    - - name: FP16_table
+        embedding_dim: 512
+        num_embeddings: 100_000
+        feature_names: ["additional_0_0"]
+        data_type: FP16
+      - name: large_table
+        embedding_dim: 256
+        num_embeddings: 1_000_000
+        feature_names: ["additional_0_1"]
+    - []
+    - - name: skipped_table
+        embedding_dim: 128
+        num_embeddings: 100_000
+        feature_names: ["additional_2_1"]
+PlannerConfig:
+  additional_constraints:
+    large_table:
+      compute_kernels: [fused_uvm_caching]
+      sharding_types: [row_wise]
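The YAML above is a new benchmark config file (its path is not shown in this view). As a generic illustration of its shape only, the snippet below reads it with PyYAML and pulls out a few values; how the benchmark itself maps these sections onto RunOptions/PipelineConfig/EmbeddingTablesConfig/PlannerConfig objects is not part of this diff, and the file path used is a placeholder.

# Generic PyYAML read of the config above; the path is a placeholder and this
# is not the benchmark's own loading code.
import yaml

with open("sparse_data_dist_base.yml") as f:  # placeholder path
    cfg = yaml.safe_load(f)

print(cfg["RunOptions"]["world_size"])                        # 2
print(cfg["EmbeddingTablesConfig"]["embedding_feature_dim"])  # 256
print(cfg["PlannerConfig"]["additional_constraints"]["large_table"])
# {'compute_kernels': ['fused_uvm_caching'], 'sharding_types': ['row_wise']}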

torchrec/distributed/test_utils/input_config.py

Lines changed: 3 additions & 6 deletions

@@ -7,8 +7,8 @@
 
 # pyre-strict
 
-from dataclasses import dataclass, fields
-from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
+from dataclasses import dataclass
+from typing import List, Optional
 
 import torch
 from torchrec.modules.embedding_configs import EmbeddingBagConfig

@@ -40,10 +40,7 @@ def generate_batches(
         Generate model input data for benchmarking.
 
         Args:
-            tables: List of unweighted embedding tables
-            weighted_tables: List of weighted embedding tables
-            model_config: Configuration for model generation
-            num_batches: Number of batches to generate
+            tables: List of embedding tables
 
         Returns:
             A list of ModelInput objects representing the generated batches
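A minimal sketch of the narrowed generate_batches call follows. Only the tables argument and the ModelInput return type come from the docstring above; ModelInputConfig's constructor fields are not shown in this diff, so the config object is taken as a parameter, and the table spec is arbitrary.

# Sketch: only generate_batches(tables=...) and its List[ModelInput] return type
# are taken from the docstring above.
from typing import List

from torchrec.distributed.test_utils.input_config import ModelInputConfig
from torchrec.distributed.test_utils.model_input import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig


def make_benchmark_batches(input_config: ModelInputConfig) -> List[ModelInput]:
    # Arbitrary single-table spec for illustration.
    tables: List[EmbeddingBagConfig] = [
        EmbeddingBagConfig(
            name="table_0",
            num_embeddings=100_000,
            embedding_dim=128,
            feature_names=["feature_0"],
        )
    ]
    return input_config.generate_batches(tables=tables)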

torchrec/distributed/benchmark/benchmark_utils.py renamed to torchrec/distributed/test_utils/model_config.py

Lines changed: 7 additions & 10 deletions

@@ -52,23 +52,15 @@ class BaseModelConfig(ABC):
     """
 
     # Common parameters for all model types
-    batch_size: int
-    batch_sizes: Optional[List[int]]
-    num_float_features: int
-    feature_pooling_avg: int
-    use_offsets: bool
-    dev_str: str
-    long_kjt_indices: bool
-    long_kjt_offsets: bool
-    long_kjt_lengths: bool
-    pin_memory: bool
+    num_float_features: int  # we assume all model arch has a single dense feature layer
 
     @abstractmethod
     def generate_model(
         self,
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         """
         Generate a model instance based on the configuration.

@@ -100,6 +92,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         return TestSparseNN(
             tables=tables,

@@ -128,6 +121,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         return TestTowerSparseNN(
             num_float_features=self.num_float_features,

@@ -152,6 +146,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         return TestTowerCollectionSparseNN(
             tables=tables,

@@ -176,6 +171,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         # DeepFM only uses unweighted tables
         ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))

@@ -201,6 +197,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         # DLRM only uses unweighted tables
         ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
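To show what the new signature expects of subclasses, here is a sketch of a custom config following the abstract generate_model signature above; the module it builds (a bare EmbeddingBagCollection) is an illustration, not one of the shipped Test*Config classes.

# Sketch of a custom config; only the abstract signature is taken from the diff above.
from dataclasses import dataclass
from typing import Any, List

import torch
import torch.nn as nn
from torchrec.distributed.test_utils.model_config import BaseModelConfig
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection


@dataclass
class EBCOnlyModelConfig(BaseModelConfig):
    """Hypothetical config that ignores weighted tables and extra kwargs."""

    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
        **kwargs: Any,
    ) -> nn.Module:
        # Mirrors the DeepFM/DLRM configs above: only unweighted tables are used,
        # and the module is created on the meta device before sharding.
        return EmbeddingBagCollection(tables=tables, device=torch.device("meta"))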

torchrec/distributed/test_utils/sharding_config.py

Lines changed: 0 additions & 8 deletions

@@ -75,15 +75,7 @@ def generate_planner(
         Generate an embedding sharding planner based on the specified configuration.
 
         Args:
-            planner_type: Type of planner to use ("embedding" or "hetero")
-            topology: Network topology for distributed training
             tables: List of unweighted embedding tables
-            weighted_tables: List of weighted embedding tables
-            sharding_type: Strategy for sharding embedding tables
-            compute_kernel: Compute kernel to use for embedding tables
-            batch_sizes: Sizes of each batch
-            pooling_factors: Pooling factors for each feature of the table
-            num_poolings: Number of poolings for each feature of the table
 
         Returns:
             An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner
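For context on what generate_planner hands back, the sketch below shows the standard torchrec planner API that typically drives either returned planner type; this is not code from the diff, and the sharder list and world size are placeholder choices.

# Standard torchrec planner usage for the objects generate_planner returns;
# not taken from this diff. The sharder list and world size are placeholders.
import torch.nn as nn
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingPlan


def plan_for(model: nn.Module, world_size: int = 2) -> ShardingPlan:
    # Describe the topology, then produce a ShardingPlan that
    # DistributedModelParallel can consume.
    planner = EmbeddingShardingPlanner(
        topology=Topology(world_size=world_size, compute_device="cuda"),
    )
    return planner.plan(model, [EmbeddingBagCollectionSharder()])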
