diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py index 8248e4225..a130778fb 100644 --- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py +++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py @@ -34,12 +34,12 @@ CPUMemoryStats, GPUMemoryStats, ) -from torchrec.distributed.benchmark.benchmark_utils import ( +from torchrec.distributed.test_utils.input_config import ModelInputConfig +from torchrec.distributed.test_utils.model_config import ( BaseModelConfig, create_model_config, generate_sharded_model_and_optimizer, ) -from torchrec.distributed.test_utils.input_config import ModelInputConfig from torchrec.distributed.test_utils.model_input import ModelInput from torchrec.distributed.test_utils.multi_process import ( diff --git a/torchrec/distributed/benchmark/embedding_collection_wrappers.py b/torchrec/distributed/benchmark/embedding_collection_wrappers.py index 2036c1b07..9549c4652 100644 --- a/torchrec/distributed/benchmark/embedding_collection_wrappers.py +++ b/torchrec/distributed/benchmark/embedding_collection_wrappers.py @@ -36,6 +36,13 @@ import torch from torch import multiprocessing as mp from torchrec.distributed import DistributedModelParallel + +from torchrec.distributed.benchmark.base import ( + benchmark_model_with_warmup, + BenchmarkResult, + CompileMode, + multi_process_benchmark, +) from torchrec.distributed.embedding_types import ShardingType from torchrec.distributed.global_settings import set_propogate_device from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology @@ -56,14 +63,6 @@ ) from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor -# Import the shared types and utilities from benchmark_utils -from .base import ( - benchmark_model_with_warmup, - BenchmarkResult, - CompileMode, - multi_process_benchmark, -) - logger: logging.Logger = logging.getLogger() T = TypeVar("T", bound=torch.nn.Module) diff --git a/torchrec/distributed/benchmark/yaml/sparse_data_dist_emo.yml b/torchrec/distributed/benchmark/yaml/sparse_data_dist_emo.yml new file mode 100644 index 000000000..412f8dc3f --- /dev/null +++ b/torchrec/distributed/benchmark/yaml/sparse_data_dist_emo.yml @@ -0,0 +1,36 @@ +# this is a very basic sparse data dist config +# runs on 2 ranks, showing traces with reasonable workloads +RunOptions: + world_size: 2 + num_batches: 5 + num_benchmarks: 2 + sharding_type: table_wise + profile_dir: "." + name: "sparse_data_dist_base" + # export_stacks: True # enable this to export stack traces +PipelineConfig: + pipeline: "sparse" +EmbeddingTablesConfig: + num_unweighted_features: 100 + num_weighted_features: 100 + embedding_feature_dim: 256 + additional_tables: + - - name: FP16_table + embedding_dim: 512 + num_embeddings: 100_000 + feature_names: ["additional_0_0"] + data_type: FP16 + - name: large_table + embedding_dim: 256 + num_embeddings: 1_000_000 + feature_names: ["additional_0_1"] + - [] + - - name: skipped_table + embedding_dim: 128 + num_embeddings: 100_000 + feature_names: ["additional_2_1"] +PlannerConfig: + additional_constraints: + large_table: + compute_kernels: [fused_uvm_caching] + sharding_types: [row_wise] diff --git a/torchrec/distributed/test_utils/input_config.py b/torchrec/distributed/test_utils/input_config.py index 5beb5e21a..a0683f81f 100644 --- a/torchrec/distributed/test_utils/input_config.py +++ b/torchrec/distributed/test_utils/input_config.py @@ -7,8 +7,8 @@ # pyre-strict -from dataclasses import dataclass, fields -from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union +from dataclasses import dataclass +from typing import List, Optional import torch from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -40,10 +40,7 @@ def generate_batches( Generate model input data for benchmarking. Args: - tables: List of unweighted embedding tables - weighted_tables: List of weighted embedding tables - model_config: Configuration for model generation - num_batches: Number of batches to generate + tables: List of embedding tables Returns: A list of ModelInput objects representing the generated batches diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/test_utils/model_config.py similarity index 97% rename from torchrec/distributed/benchmark/benchmark_utils.py rename to torchrec/distributed/test_utils/model_config.py index dee9a9263..f72abcd94 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/test_utils/model_config.py @@ -52,16 +52,7 @@ class BaseModelConfig(ABC): """ # Common parameters for all model types - batch_size: int - batch_sizes: Optional[List[int]] - num_float_features: int - feature_pooling_avg: int - use_offsets: bool - dev_str: str - long_kjt_indices: bool - long_kjt_offsets: bool - long_kjt_lengths: bool - pin_memory: bool + num_float_features: int # we assume all model arch has a single dense feature layer @abstractmethod def generate_model( @@ -69,6 +60,7 @@ def generate_model( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], dense_device: torch.device, + **kwargs: Any, ) -> nn.Module: """ Generate a model instance based on the configuration. @@ -100,6 +92,7 @@ def generate_model( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], dense_device: torch.device, + **kwargs: Any, ) -> nn.Module: return TestSparseNN( tables=tables, @@ -128,6 +121,7 @@ def generate_model( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], dense_device: torch.device, + **kwargs: Any, ) -> nn.Module: return TestTowerSparseNN( num_float_features=self.num_float_features, @@ -152,6 +146,7 @@ def generate_model( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], dense_device: torch.device, + **kwargs: Any, ) -> nn.Module: return TestTowerCollectionSparseNN( tables=tables, @@ -176,6 +171,7 @@ def generate_model( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], dense_device: torch.device, + **kwargs: Any, ) -> nn.Module: # DeepFM only uses unweighted tables ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta")) @@ -201,6 +197,7 @@ def generate_model( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], dense_device: torch.device, + **kwargs: Any, ) -> nn.Module: # DLRM only uses unweighted tables ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta")) diff --git a/torchrec/distributed/test_utils/sharding_config.py b/torchrec/distributed/test_utils/sharding_config.py index c7ac11df0..af6c563d6 100644 --- a/torchrec/distributed/test_utils/sharding_config.py +++ b/torchrec/distributed/test_utils/sharding_config.py @@ -75,15 +75,7 @@ def generate_planner( Generate an embedding sharding planner based on the specified configuration. Args: - planner_type: Type of planner to use ("embedding" or "hetero") - topology: Network topology for distributed training tables: List of unweighted embedding tables - weighted_tables: List of weighted embedding tables - sharding_type: Strategy for sharding embedding tables - compute_kernel: Compute kernel to use for embedding tables - batch_sizes: Sizes of each batch - pooling_factors: Pooling factors for each feature of the table - num_poolings: Number of poolings for each feature of the table Returns: An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner