8 changes: 7 additions & 1 deletion autoparallel/apply_sharding.py
@@ -250,6 +250,12 @@ def shard_node_given_placements(node, sharding_placement, *, meta: bool):
mesh = tgt_spec.mesh
# all tensors start as replicated
curr_placement = (Replicate(),) * mesh.ndim
if "val" not in node.meta:
# for non-tensor inputs, they are considered as being
# baked in the graph, so we don't need to do anything
# and just return a dummy value
assert len(node.users) == 0
return "arbitrary value"
tensor = node.meta["val"]

ctx: Any
@@ -303,7 +309,7 @@ def _get_inductor_decomp_table():

def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
args = shard_nodes_given_placements(gm, sharding_placement)
local_args = [arg.to_local() for arg in args]
local_args = tree_map_only(DTensor, lambda x: x.to_local(), args)

decomp_table = _get_inductor_decomp_table()
# run with DTensor to apply the collectives given the graph
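Why `tree_map_only` here: the flattened args may now contain non-tensor entries, and the old list comprehension calling `.to_local()` on every element would fail on them. A minimal standalone sketch of the behavior, using plain `torch.Tensor` in place of `DTensor` so it runs without a process group (illustration only, not code from this PR):

```python
import torch
from torch.utils._pytree import tree_map_only

# tree_map_only applies the function only to leaves of the requested type,
# so non-tensor entries (e.g. a baked-in int) pass through unchanged.
args = [torch.ones(2, 2), 1]
local_args = tree_map_only(torch.Tensor, lambda t: t * 2, args)

assert torch.equal(local_args[0], torch.full((2, 2), 2.0))
assert local_args[1] == 1
```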
9 changes: 7 additions & 2 deletions autoparallel/graph_utils.py
@@ -52,16 +52,21 @@ def update_joint_with_descriptors(
"""
# TODO: should we upstream a util like this?
placeholders = [n for n in updated_gm.graph.nodes if n.op == "placeholder"]
new_local_args = [n.meta["val"] for n in placeholders]
# assume if "val" is not present in meta, then it's a non-tensor input
# and there is no sharding associated with it and we can just forward
# the original input
new_local_args = [n.meta.get("val", None) for n in placeholders]
joint_with_descriptors.graph_module = updated_gm
joint_with_descriptors._aot_graph_capture.graph_module = updated_gm

new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
if isinstance(orig, torch.nn.Parameter):
new_flat_args.append(torch.nn.Parameter(new))
else:
elif new is not None:
new_flat_args.append(new)
else:
new_flat_args.append(orig)

tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
new_local_tangents = new_local_args[tangent_idx:]
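A small standalone illustration of the fallback above, with hypothetical stand-in values (not code from this PR): when a placeholder carries no `"val"` meta, its slot in `new_local_args` is `None`, and the original flat arg is forwarded unchanged.

```python
import torch

# stand-ins: a parameter, a tensor input, and a non-tensor input whose
# placeholder carries no "val" meta (represented here by None)
flat_args = [torch.nn.Parameter(torch.randn(4)), torch.randn(4), 1]
new_local_args = [torch.randn(4), torch.randn(4), None]

new_flat_args = []
for orig, new in zip(flat_args, new_local_args):
    if isinstance(orig, torch.nn.Parameter):
        new_flat_args.append(torch.nn.Parameter(new))
    elif new is not None:
        new_flat_args.append(new)
    else:
        # non-tensor input: nothing to reshard, keep the original value
        new_flat_args.append(orig)

assert new_flat_args[2] == 1
```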
27 changes: 24 additions & 3 deletions autoparallel/optimize_sharding.py
@@ -157,6 +157,18 @@ def build_sharding_metadata(self):
strats = {}
for node in self.graph.nodes:
if node.op == "placeholder":
if node.meta.get("val", None) is None:
# Non-tensor inputs are considered as being replicated across
# all ranks. Given that those inputs seem to have been baked
# into the graph, we won't actually use this OpStrategy.
strats[node] = _create_all_options(self.mesh, ())
# for now, it seems like non-tensor inputs are baked into the graph,
# so let's assert that this is indeed the case
assert (
len(node.users) == 0
), f"{node} has {len(node.users)} users, expected 0"
continue
strats[node] = _create_all_options(
self.mesh, node.meta["val"].shape, tensor=node.meta["val"]
)
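The zero-users assertion leans on how tracing handles plain Python values: they get specialized into the graph, so the corresponding placeholder is left dangling. A rough sketch of that behavior with `make_fx` (an assumption used for illustration, not code from this PR):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, dim: int):
    return x.chunk(2, dim=dim)

gm = make_fx(f)(torch.randn(4, 8), 1)
for node in gm.graph.nodes:
    if node.op == "placeholder":
        # the `dim` placeholder is expected to have no users, because the
        # concrete value 1 was baked into the split call during tracing
        print(node.name, len(node.users))
```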
@@ -828,16 +840,25 @@ def add_sharded_input_constraint(
if input_placements is not None:
mut_ips = {i: p for i, p in enumerate(input_placements)}

for desc, (node, grad_node) in get_plain_input_and_grad_nodes(
self.graph
).items():
inputs_and_grads = get_plain_input_and_grad_nodes(self.graph)
if mut_ips is not None and len(mut_ips) != len(inputs_and_grads):
raise ValueError(
f"Expected to have {len(inputs_and_grads)} "
f"input placements, got {len(mut_ips)}"
)

for desc, (node, grad_node) in inputs_and_grads.items():
if input_placements is None:
placement = None
else:
assert isinstance(desc, PlainAOTInput)
assert mut_ips is not None
placement = mut_ips.pop(desc.idx)

if placement is None and "val" not in node.meta:
# this is a non-tensor input; we don't need to constrain it
continue

self.add_node_constraint(
node, placement, constraint_name="input_constraint"
)
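Before iterating, the constraint now checks that the caller supplied exactly one placement per plain input, with `None` allowed for inputs that should be left unconstrained (the new test passes `[x_sharding, None]` for its tensor and int inputs). A tiny standalone sketch of the check, using a hypothetical helper name and toy inputs:

```python
def check_input_placements(input_placements, inputs_and_grads):
    # mirror the validation above: one entry per plain input, indexed by position
    mut_ips = {i: p for i, p in enumerate(input_placements)}
    if len(mut_ips) != len(inputs_and_grads):
        raise ValueError(
            f"Expected to have {len(inputs_and_grads)} "
            f"input placements, got {len(mut_ips)}"
        )
    return mut_ips

# two graph inputs (a tensor and a baked-in int) -> two entries required
check_input_placements([("Shard(0)",), None], {"x": None, "dim": None})  # ok
# check_input_placements([("Shard(0)",)], {"x": None, "dim": None})      # ValueError
```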
25 changes: 0 additions & 25 deletions autoparallel/propagation_rules.py
@@ -646,31 +646,6 @@ def convert_element_type_rule(mesh, op_schema):
return out_strat


@register_opschema_rule(torch.ops.aten.split.Tensor)
def split_rule(mesh, op_schema):
Review comment: aten.split.Tensor has been implemented upstream in pytorch/pytorch#149106

strat = op_schema.args_schema
op = torch.ops.aten.split.Tensor
from torch.distributed.tensor._ops._tensor_ops import split_rule

res = []
oo = []
for i, ss in enumerate(strat[0].strategies):
ispec = ss.input_specs[0]
assert ss.output_spec == ispec
o = split_rule(OpSchema(op, (ispec, strat[1], strat[2]), {}))
# res.append(o)
oo.append(o)
if o.output_spec is not None:
s = OpSpec(o.output_spec, input_specs=(ispec,))
s.redistribute_cost = [[math.inf] * len(ss.redistribute_cost[0])]
# s.redistribute_cost = [[0.0] * len(ss.redistribute_cost[0])]
s.redistribute_cost[0][i] = 0.0
res.append(s)

out_strat = OpStrategy(res)
return out_strat


@register_opschema_rule(torch.ops.aten._unsafe_index.Tensor)
def _unsafe_index_rule(mesh, op_schema):
raise NotImplementedError()
45 changes: 45 additions & 0 deletions tests/test_api.py
@@ -117,6 +117,51 @@ def input_fn():
)


def test_non_tensor_input(device_mesh_1d):
    dim = 128

    class Model(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim
            self.linear = nn.Linear(dim, dim)

        def forward(self, x, input_dim: int):
            return self.linear(x).chunk(2, dim=input_dim)

        def init_weights(self):
            dim = self.dim
            self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0)
            with torch.no_grad():
                self.linear.bias.fill_(98.6)

    def input_fn():
        b = 512
        inputs = torch.rand(b, dim, device="cuda")
        input_dim = 1
        return (inputs, input_dim)

    with torch.device("meta"):
        model = Model(dim)
    with AutoParallel(
        model,
        input_fn,
        device_mesh_1d,
    ) as autop:
        x_sharding = (Shard(0),)
        autop.add_input_constraints([x_sharding, None])
        sharding_placement = autop.optimize_placement()

        parallel_mod = autop.apply_placement(sharding_placement)
    parallel_mod.to_empty(device="cuda")
    parallel_mod.init_weights()
    placeholders = autop.gm.graph.find_nodes(op="placeholder")
    non_tensor_input = placeholders[3]
    assert sharding_placement[non_tensor_input].output_specs.placements == (
        Replicate(),
    )


def test_fx_graph_annotate(device_mesh_1d):
dim = 128
