From 10ac9b5aa65b00c603ac5daa001750d405d44675 Mon Sep 17 00:00:00 2001
From: Aditya Venkataraman
Date: Mon, 24 Nov 2025 19:38:33 -0800
Subject: [PATCH] Don't crash if tensor_meta is not available for output spec

PyTorch DTensor strategy can legitimately populate DTensorSpec without
tensor_meta. In such cases, we attempt fake tensor propagation to
populate tensor_meta, but for some ops one or more outputs can
legitimately be None depending on the inputs (e.g., convolution.backward
with certain output_mask values). Switch the validation to emit a
warning in such cases.

If tensor_meta is legitimately not known and the output of an op is
subsequently used as an input to a downstream op, we will still fail
during input_spec validation.

Testing: Added a convolution test that revealed this issue.
---
 autoparallel/optimize_sharding.py | 18 ++++++---
 autoparallel/propagation_rules.py |  2 +
 tests/test_api.py                 | 61 +++++++++++++++++++++++++++++++
 3 files changed, 76 insertions(+), 5 deletions(-)

diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 5e1dca5..25269f2 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -77,6 +77,7 @@
 runtime cost while satisfying all constraints.
 """
 
+import logging
 import math
 import operator
 import time
@@ -105,6 +106,8 @@
 from .propagation_rules import _create_all_options
 from .utils import get_local_map_placement_option, get_placement_options
 
+logger = logging.getLogger(__name__)
+
 
 def _debug_node(node):
     def my_print(x):
@@ -924,11 +927,16 @@ def validate(self):
             ), f"{node}, {len(strat0.redistribute_cost)}, {num_input_nodes}"
             ospec = strat0.output_specs
             if isinstance(ospec, (list, tuple)):
-                for spec in ospec:
-                    if spec:
-                        assert (
-                            spec.tensor_meta is not None
-                        ), f"{node} doesn't have a tensor_meta"
+                for i, spec in enumerate(ospec):
+                    if spec is not None:
+                        if spec.tensor_meta is None:
+                            # PyTorch DTensor strategy can legitimately populate DTensorSpec without tensor_meta.
+                            # In such cases, we attempt fake tensor propagation to populate tensor_meta, but for
+                            # some ops, one or more outputs can legitimately be None depending on inputs
+                            # (e.g., convolution.backward with certain output_mask).
+                            logger.warning(
+                                f"Output #{i} spec {spec} for `{node}` has no tensor_meta"
+                            )
             elif ospec is None:
                 # e.g. getitem on scalars
                 pass
diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py
index 4a8e52c..e18b971 100644
--- a/autoparallel/propagation_rules.py
+++ b/autoparallel/propagation_rules.py
@@ -120,6 +120,8 @@ def remove_invalid_configs(out_strat, mesh):
         for spec in specs:
             if spec is None:
                 continue
+            if spec.tensor_meta is None:
+                continue
             shape = list(spec.tensor_meta.shape)
             for mesh_shape, plc in zip(mesh.shape, spec.placements):
                 if plc.is_shard():
diff --git a/tests/test_api.py b/tests/test_api.py
index 29e010f..4ca83f8 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -358,3 +358,64 @@ def input_fn():
     ]
     # Should only have 2 placeholders: weight and input (no tangents for inference)
     assert len(placeholders) == 2
+
+
+def test_convolution_forward_backward(device_mesh_1d):
+    """Test that convolution operations work with forward and backward passes.
+
+    Convolution backward has multiple outputs, some of which can be None. In such cases,
+    we don't want AutoParallel to assert on tensor_meta being populated.
+    """
+    in_channels = 3
+    out_channels = 64
+    kernel_size = 3
+
+    class SimpleConvNet(nn.Module):
+        """Simple network with a single Conv2d operation."""
+
+        def __init__(self, in_channels, out_channels, kernel_size):
+            super().__init__()
+            self.conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                padding=kernel_size // 2,
+                bias=False,
+            )
+
+        def forward(self, x):
+            return self.conv(x)
+
+    def input_fn():
+        return torch.rand(256, in_channels, 224, 224, device="cuda")
+
+    # Create model on meta device
+    with torch.device("meta"):
+        model = SimpleConvNet(in_channels, out_channels, kernel_size)
+
+    with AutoParallel(model, input_fn, device_mesh_1d, compile=True) as autop:
+        x_sharding = (Replicate(),)
+        autop.add_input_constraints([x_sharding])
+        autop.add_output_constraints([x_sharding])
+
+        sharding_placement = autop.optimize_placement()
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    # Verify the model was created successfully
+    assert parallel_mod is not None
+    assert hasattr(autop, "parallel_gm")
+
+    # Verify that both convolution.default and convolution_backward.default
+    # are in the parallel graph
+    parallel_graph_ops = {
+        n.target for n in autop.parallel_gm.graph.nodes if n.op == "call_function"
+    }
+    assert torch.ops.aten.convolution.default in parallel_graph_ops
+    assert torch.ops.aten.convolution_backward.default in parallel_graph_ops
+
+    # Verify sharding strategies were computed for all nodes
+    for node in autop.gm.graph.nodes:
+        if node.op == "call_function":
+            assert node in autop.sharding_optimizer.strats
+            strats = autop.sharding_optimizer.strats[node]
+            assert len(strats.strategies) > 0, f"No strategies for {node}"