4 changes: 2 additions & 2 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -34,12 +34,12 @@
     CPUMemoryStats,
     GPUMemoryStats,
 )
-from torchrec.distributed.benchmark.benchmark_utils import (
+from torchrec.distributed.test_utils.input_config import ModelInputConfig
+from torchrec.distributed.test_utils.model_config import (
     BaseModelConfig,
     create_model_config,
     generate_sharded_model_and_optimizer,
 )
-from torchrec.distributed.test_utils.input_config import ModelInputConfig
 from torchrec.distributed.test_utils.model_input import ModelInput

 from torchrec.distributed.test_utils.multi_process import (
15 changes: 7 additions & 8 deletions torchrec/distributed/benchmark/embedding_collection_wrappers.py
@@ -36,6 +36,13 @@
 import torch
 from torch import multiprocessing as mp
 from torchrec.distributed import DistributedModelParallel
+
+from torchrec.distributed.benchmark.base import (
+    benchmark_model_with_warmup,
+    BenchmarkResult,
+    CompileMode,
+    multi_process_benchmark,
+)
 from torchrec.distributed.embedding_types import ShardingType
 from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
@@ -56,14 +63,6 @@
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

-# Import the shared types and utilities from benchmark_utils
-from .base import (
-    benchmark_model_with_warmup,
-    BenchmarkResult,
-    CompileMode,
-    multi_process_benchmark,
-)
-
 logger: logging.Logger = logging.getLogger()

 T = TypeVar("T", bound=torch.nn.Module)
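Both files now pull the shared benchmark types from the same absolute module path instead of a relative `.base` import. For reference, the consolidated import surface after this change (a sketch using exactly the symbols moved in this diff) is:

from torchrec.distributed.benchmark.base import (
    benchmark_model_with_warmup,
    BenchmarkResult,
    CompileMode,
    multi_process_benchmark,
)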
36 changes: 36 additions & 0 deletions torchrec/distributed/benchmark/yaml/sparse_data_dist_emo.yml
@@ -0,0 +1,36 @@
+# this is a very basic sparse data dist config
+# runs on 2 ranks, showing traces with reasonable workloads
+RunOptions:
+  world_size: 2
+  num_batches: 5
+  num_benchmarks: 2
+  sharding_type: table_wise
+  profile_dir: "."
+  name: "sparse_data_dist_base"
+  # export_stacks: True # enable this to export stack traces
+PipelineConfig:
+  pipeline: "sparse"
+EmbeddingTablesConfig:
+  num_unweighted_features: 100
+  num_weighted_features: 100
+  embedding_feature_dim: 256
+  additional_tables:
+    - - name: FP16_table
+        embedding_dim: 512
+        num_embeddings: 100_000
+        feature_names: ["additional_0_0"]
+        data_type: FP16
+      - name: large_table
+        embedding_dim: 256
+        num_embeddings: 1_000_000
+        feature_names: ["additional_0_1"]
+    - []
+    - - name: skipped_table
+        embedding_dim: 128
+        num_embeddings: 100_000
+        feature_names: ["additional_2_1"]
+PlannerConfig:
+  additional_constraints:
+    large_table:
+      compute_kernels: [fused_uvm_caching]
+      sharding_types: [row_wise]
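The new file is plain YAML, so its structure can be inspected directly. A minimal sketch, assuming PyYAML is installed and the file is read from the path added in this PR; how the benchmark entry point actually consumes it may differ:

import yaml

with open("torchrec/distributed/benchmark/yaml/sparse_data_dist_emo.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["RunOptions"]["world_size"])    # 2
print(cfg["PipelineConfig"]["pipeline"])  # sparse

# additional_tables is a list of table groups; the empty group ([]) adds
# no tables in its slot, and skipped_table sits in a group of its own.
for group in cfg["EmbeddingTablesConfig"]["additional_tables"]:
    for table in group:
        # data_type is optional per table (only FP16_table sets it here)
        print(table["name"], table.get("data_type", "<default>"))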
9 changes: 3 additions & 6 deletions torchrec/distributed/test_utils/input_config.py
@@ -7,8 +7,8 @@

 # pyre-strict

-from dataclasses import dataclass, fields
-from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
+from dataclasses import dataclass
+from typing import List, Optional

 import torch
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -40,10 +40,7 @@ def generate_batches(
         Generate model input data for benchmarking.

         Args:
-            tables: List of unweighted embedding tables
-            weighted_tables: List of weighted embedding tables
-            model_config: Configuration for model generation
-            num_batches: Number of batches to generate
+            tables: List of embedding tables

         Returns:
             A list of ModelInput objects representing the generated batches
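The `tables` argument in the trimmed docstring is a list of EmbeddingBagConfig. For illustration, a minimal, self-contained table list of the kind generate_batches consumes (the name, sizes, and feature names are arbitrary):

from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        name="table_0",
        embedding_dim=128,
        num_embeddings=100_000,
        feature_names=["feature_0"],
    )
]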
17 changes: 7 additions & 10 deletions torchrec/distributed/test_utils/model_config.py
@@ -52,23 +52,15 @@ class BaseModelConfig(ABC):
     """

     # Common parameters for all model types
-    batch_size: int
-    num_float_features: int
-    feature_pooling_avg: int
-    use_offsets: bool
-    dev_str: str
-    long_kjt_indices: bool
-    long_kjt_offsets: bool
-    long_kjt_lengths: bool
-    pin_memory: bool
+    batch_sizes: Optional[List[int]]
+    num_float_features: int  # we assume all model arch has a single dense feature layer

     @abstractmethod
     def generate_model(
         self,
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
-        **kwargs: Any,
     ) -> nn.Module:
         """
         Generate a model instance based on the configuration.
@@ -100,6 +92,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         return TestSparseNN(
             tables=tables,
@@ -128,6 +121,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         return TestTowerSparseNN(
             num_float_features=self.num_float_features,
@@ -152,6 +146,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         return TestTowerCollectionSparseNN(
             tables=tables,
@@ -176,6 +171,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         # DeepFM only uses unweighted tables
         ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
@@ -201,6 +197,7 @@ def generate_model(
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
         dense_device: torch.device,
+        **kwargs: Any,
     ) -> nn.Module:
         # DLRM only uses unweighted tables
         ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
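The interface change above replaces the fixed per-config batch_size with an optional batch_sizes list and threads **kwargs through the concrete generate_model implementations. A hypothetical subclass sketching the new shape of the contract; TinyModelConfig and its nn.Linear body are invented for this sketch, and it assumes BaseModelConfig is a dataclass, as its field list above suggests:

from dataclasses import dataclass
from typing import Any, List, Optional

import torch
from torch import nn
from torchrec.distributed.test_utils.model_config import BaseModelConfig
from torchrec.modules.embedding_configs import EmbeddingBagConfig


@dataclass
class TinyModelConfig(BaseModelConfig):
    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
        **kwargs: Any,  # extra model-specific options, ignored in this sketch
    ) -> nn.Module:
        # A stand-in dense-only model; the real configs build TestSparseNN etc.
        return nn.Linear(self.num_float_features, 1, device=dense_device)


config = TinyModelConfig(batch_sizes=None, num_float_features=13)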
8 changes: 0 additions & 8 deletions torchrec/distributed/test_utils/sharding_config.py
@@ -75,15 +75,7 @@ def generate_planner(
         Generate an embedding sharding planner based on the specified configuration.

         Args:
             planner_type: Type of planner to use ("embedding" or "hetero")
-            topology: Network topology for distributed training
-            tables: List of unweighted embedding tables
-            weighted_tables: List of weighted embedding tables
-            sharding_type: Strategy for sharding embedding tables
-            compute_kernel: Compute kernel to use for embedding tables
-            batch_sizes: Sizes of each batch
-            pooling_factors: Pooling factors for each feature of the table
-            num_poolings: Number of poolings for each feature of the table

         Returns:
             An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner
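Per the surviving docstring, planner_type selects between the two planner classes. A minimal sketch of the "embedding" branch using public torchrec classes; the exact arguments generate_planner forwards are not shown in this diff:

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

# Matches the 2-rank setup in the YAML config above
topology = Topology(world_size=2, compute_device="cuda")

# planner_type="embedding" -> EmbeddingShardingPlanner; "hetero" would
# construct a HeteroEmbeddingShardingPlanner instead.
planner = EmbeddingShardingPlanner(topology=topology)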