18 changes: 13 additions & 5 deletions autoparallel/optimize_sharding.py
@@ -77,6 +77,7 @@
runtime cost while satisfying all constraints.
"""

import logging
import math
import operator
import time
@@ -105,6 +106,8 @@
from .propagation_rules import _create_all_options
from .utils import get_local_map_placement_option, get_placement_options

logger = logging.getLogger(__name__)


def _debug_node(node):
    def my_print(x):
@@ -924,11 +927,16 @@ def validate(self):
), f"{node}, {len(strat0.redistribute_cost)}, {num_input_nodes}"
ospec = strat0.output_specs
if isinstance(ospec, (list, tuple)):
for spec in ospec:
if spec:
assert (
spec.tensor_meta is not None
), f"{node} doesn't have a tensor_meta"
for i, spec in enumerate(ospec):
if spec is not None:
if spec.tensor_meta is None:
# PyTorch DTensor strategy can legitimately populate DTensorSpec without tensor_meta.
# In such cases, we attempt to do fake tensor propagation to populate tensor_meta, but
# for some ops, one or more outputs can legtimiately be None depending on inputs
# (e.g., convolution.backward with certain output_mask).
logger.warning(
f"Output #{i} spec {spec} for `{node}` has no tensor_meta"
)
elif ospec is None:
# e.g. getitem on scalars
pass
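For readers unfamiliar with the case the new warning tolerates, below is a minimal sketch (illustration only, not part of this diff; shapes are arbitrary) of how aten.convolution_backward legitimately produces a None output: when an entry of output_mask is False -- for example the bias gradient of a bias-free conv -- that gradient is never computed, so there is nothing to attach a tensor_meta to.

# Standalone illustration, not part of the PR: a masked-out output of
# convolution_backward comes back as None.
import torch

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
grad_out = torch.randn(1, 4, 8, 8)  # conv output shape for stride=1, padding=1

grad_x, grad_w, grad_b = torch.ops.aten.convolution_backward(
    grad_out, x, w,
    None,                    # bias_sizes: the conv has no bias
    [1, 1], [1, 1], [1, 1],  # stride, padding, dilation
    False, [0, 0], 1,        # transposed, output_padding, groups
    [True, True, False],     # output_mask: skip the bias gradient
)
print(grad_b)  # None -- exactly the case the warning above now tolerates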
2 changes: 2 additions & 0 deletions autoparallel/propagation_rules.py
@@ -120,6 +120,8 @@ def remove_invalid_configs(out_strat, mesh):
for spec in specs:
    if spec is None:
        continue
    if spec.tensor_meta is None:
        continue
    shape = list(spec.tensor_meta.shape)
    for mesh_shape, plc in zip(mesh.shape, spec.placements):
        if plc.is_shard():
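The two-line guard above simply makes the shape-based filtering tolerant of specs that never received shape metadata. A self-contained stand-in (hypothetical names; the real validity rule in remove_invalid_configs is more involved) showing the intended behavior -- a missing shape means the config is skipped rather than judged invalid:

# Hypothetical stand-in for the skip-on-missing-metadata behavior; not the
# actual AutoParallel logic.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeSpec:
    # shape=None plays the role of spec.tensor_meta is None in the real code
    shape: Optional[Tuple[int, ...]]


def keeps_config(spec: Optional[FakeSpec], mesh_shape: Tuple[int, ...]) -> bool:
    if spec is None or spec.shape is None:
        return True  # nothing to validate against, so don't discard the config
    # toy validity rule: every dim must be divisible by the matching mesh dim
    return all(dim % m == 0 for dim, m in zip(spec.shape, mesh_shape))


print(keeps_config(FakeSpec((8, 16)), (2, 2)))  # True
print(keeps_config(FakeSpec(None), (2, 2)))     # True: skipped, as in the diff
print(keeps_config(FakeSpec((7, 16)), (2, 2)))  # False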
61 changes: 61 additions & 0 deletions tests/test_api.py
@@ -358,3 +358,64 @@ def input_fn():
    ]
    # Should only have 2 placeholders: weight and input (no tangents for inference)
    assert len(placeholders) == 2


def test_convolution_forward_backward(device_mesh_1d):
    """Test that convolution operations work with forward and backward passes.

    Convolution backward has multiple outputs, some of which can be None. In such cases,
    we don't want AutoParallel to assert on tensor_meta being populated.
    """
    in_channels = 3
    out_channels = 64
    kernel_size = 3

    class SimpleConvNet(nn.Module):
        """Simple network with a single Conv2d operation."""

        def __init__(self, in_channels, out_channels, kernel_size):
            super().__init__()
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=kernel_size // 2,
                bias=False,
            )

        def forward(self, x):
            return self.conv(x)

    def input_fn():
        return torch.rand(256, in_channels, 224, 224, device="cuda")

    # Create model on meta device
    with torch.device("meta"):
        model = SimpleConvNet(in_channels, out_channels, kernel_size)

    with AutoParallel(model, input_fn, device_mesh_1d, compile=True) as autop:
        x_sharding = (Replicate(),)
        autop.add_input_constraints([x_sharding])
        autop.add_output_constraints([x_sharding])

        sharding_placement = autop.optimize_placement()
        parallel_mod = autop.apply_placement(sharding_placement)

    # Verify the model was created successfully
    assert parallel_mod is not None
    assert hasattr(autop, "parallel_gm")

    # Verify that both convolution.default and convolution_backward.default
    # are in the parallel graph
    parallel_graph_ops = {
        n.target for n in autop.parallel_gm.graph.nodes if n.op == "call_function"
    }
    assert torch.ops.aten.convolution.default in parallel_graph_ops
    assert torch.ops.aten.convolution_backward.default in parallel_graph_ops

    # Verify sharding strategies were computed for all nodes
    for node in autop.gm.graph.nodes:
        if node.op == "call_function":
            assert node in autop.sharding_optimizer.strats
            strats = autop.sharding_optimizer.strats[node]
            assert len(strats.strategies) > 0, f"No strategies for {node}"
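Assuming the repository's usual pytest setup (the device_mesh_1d fixture implies pytest), the new test can be run on a CUDA machine with something like pytest tests/test_api.py -k test_convolution_forward_backward; note that input_fn allocates the input on "cuda", so it will not run on a CPU-only host.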