Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion autoparallel/apply_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,12 @@ def shard_node_given_placements(node, sharding_placement, *, meta: bool):
mesh = tgt_spec.mesh
# all tensors start as replicated
curr_placement = (Replicate(),) * mesh.ndim
if "val" not in node.meta:
# non-tensor inputs are considered to be baked into the graph,
# so we don't need to do anything and just return a dummy value
assert len(node.users) == 0
return "arbitrary value"
tensor = node.meta["val"]

ctx: Any
Expand Down Expand Up @@ -303,7 +309,7 @@ def _get_inductor_decomp_table():

def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
args = shard_nodes_given_placements(gm, sharding_placement)
local_args = [arg.to_local() for arg in args]
local_args = tree_map_only(DTensor, lambda x: x.to_local(), args)

decomp_table = _get_inductor_decomp_table()
# run with DTensor to apply the collectives given the graph
Expand Down
4 changes: 4 additions & 0 deletions autoparallel/cast_parametrization.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def apply_dtype_cast(model, mp_policy: MixedPrecisionPolicy):
class DTypeCastModule(torch.nn.Module):
def forward(self, *args, **kwargs):
def cast_fn(x):
    """Cast floating-point tensors to the policy's ``param_dtype``.

    Anything that is not a tensor, and any tensor that is not
    floating point, is returned unchanged.
    """
    is_float_tensor = isinstance(x, torch.Tensor) and torch.is_floating_point(x)
    if is_float_tensor:
        return x.to(self._mp_policy.param_dtype)
    return x
Expand All @@ -196,6 +198,8 @@ def cast_fn(x):
output = super().forward(*args, **kwargs)

def cast_out_fn(x):
    """Cast tensor outputs to the policy's ``output_dtype``.

    Non-tensor outputs are passed through untouched.
    """
    if isinstance(x, torch.Tensor):
        return x.to(self._mp_policy.output_dtype)
    return x

output = tree_map(cast_out_fn, output)
Expand Down
9 changes: 7 additions & 2 deletions autoparallel/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,21 @@ def update_joint_with_descriptors(
"""
# TODO: should we upstream a util like this?
placeholders = [n for n in updated_gm.graph.nodes if n.op == "placeholder"]
new_local_args = [n.meta["val"] for n in placeholders]
# if "val" is not present in meta, assume the placeholder is a
# non-tensor input: there is no sharding associated with it, so
# we can just forward the original input
new_local_args = [n.meta.get("val", None) for n in placeholders]
joint_with_descriptors.graph_module = updated_gm
joint_with_descriptors._aot_graph_capture.graph_module = updated_gm

new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
if isinstance(orig, torch.nn.Parameter):
new_flat_args.append(torch.nn.Parameter(new))
else:
elif new is not None:
new_flat_args.append(new)
else:
new_flat_args.append(orig)

tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
new_local_tangents = new_local_args[tangent_idx:]
Expand Down
27 changes: 24 additions & 3 deletions autoparallel/optimize_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,18 @@ def build_sharding_metadata(self):
strats = {}
for node in self.graph.nodes:
if node.op == "placeholder":
if node.meta.get("val", None) is None:
# Non-tensor inputs are considered to be replicated across
# all ranks. Given that those inputs seem to have been baked
# into the graph, this OpStrategy won't actually be used
strats[node] = _create_all_options(self.mesh, ())
# for now, it seems that non-tensor inputs are baked into the graph,
# so let's assert that this is indeed the case
assert (
len(node.users) == 0
), f"{node} nas {len(node.users)}, expected 0"
continue
strats[node] = _create_all_options(
self.mesh, node.meta["val"].shape, tensor=node.meta["val"]
)
Expand Down Expand Up @@ -828,16 +840,25 @@ def add_sharded_input_constraint(
if input_placements is not None:
mut_ips = {i: p for i, p in enumerate(input_placements)}

for desc, (node, grad_node) in get_plain_input_and_grad_nodes(
self.graph
).items():
inputs_and_grads = get_plain_input_and_grad_nodes(self.graph)
if mut_ips is not None and len(mut_ips) != len(inputs_and_grads):
raise ValueError(
f"Expected to have {len(inputs_and_grads)} "
f"input placements, got {len(mut_ips)}"
)

for desc, (node, grad_node) in inputs_and_grads.items():
if input_placements is None:
placement = None
else:
assert isinstance(desc, PlainAOTInput)
assert mut_ips is not None
placement = mut_ips.pop(desc.idx)

if placement is None and "val" not in node.meta:
# this is a non-tensor input, we don't do anything about it
continue

self.add_node_constraint(
node, placement, constraint_name="input_constraint"
)
Expand Down
25 changes: 0 additions & 25 deletions autoparallel/propagation_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,31 +646,6 @@ def convert_element_type_rule(mesh, op_schema):
return out_strat


@register_opschema_rule(torch.ops.aten.split.Tensor)
def split_rule(mesh, op_schema):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aten.split.Tensor has been implemented upstream in pytorch/pytorch#149106

strat = op_schema.args_schema
op = torch.ops.aten.split.Tensor
from torch.distributed.tensor._ops._tensor_ops import split_rule

res = []
oo = []
for i, ss in enumerate(strat[0].strategies):
ispec = ss.input_specs[0]
assert ss.output_spec == ispec
o = split_rule(OpSchema(op, (ispec, strat[1], strat[2]), {}))
# res.append(o)
oo.append(o)
if o.output_spec is not None:
s = OpSpec(o.output_spec, input_specs=(ispec,))
s.redistribute_cost = [[math.inf] * len(ss.redistribute_cost[0])]
# s.redistribute_cost = [[0.0] * len(ss.redistribute_cost[0])]
s.redistribute_cost[0][i] = 0.0
res.append(s)

out_strat = OpStrategy(res)
return out_strat


@register_opschema_rule(torch.ops.aten._unsafe_index.Tensor)
def _unsafe_index_rule(mesh, op_schema):
    """Stub rule registered for ``aten._unsafe_index.Tensor``.

    No sharding strategy is implemented here; calling the rule always
    raises ``NotImplementedError``.
    """
    raise NotImplementedError
Expand Down
45 changes: 45 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,51 @@ def input_fn():
)


def test_non_tensor_input(device_mesh_1d):
    """Check that a non-tensor (int) model input is handled end to end.

    The model takes a tensor plus an integer ``input_dim``; the tensor
    input is constrained to ``Shard(0)`` while the integer input gets no
    constraint (``None``). After optimization, the placeholder for the
    integer input is expected to carry a ``Replicate()`` placement.
    """
    dim = 128

    class Model(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim
            self.linear = nn.Linear(dim, dim)

        def forward(self, x, input_dim: int):
            # input_dim is the non-tensor input under test
            projected = self.linear(x)
            return projected.chunk(2, dim=input_dim)

        def init_weights(self):
            size = self.dim
            self.linear.weight = torch.nn.Parameter(torch.ones(size, size) * 9.0)
            with torch.no_grad():
                self.linear.bias.fill_(98.6)

    def input_fn():
        batch = 512
        return (torch.rand(batch, dim, device="cuda"), 1)

    with torch.device("meta"):
        model = Model(dim)

    with AutoParallel(model, input_fn, device_mesh_1d) as autop:
        # one entry per model input: shard the tensor on dim 0 and leave
        # the integer input unconstrained
        autop.add_input_constraints([(Shard(0),), None])
        sharding_placement = autop.optimize_placement()

        parallel_mod = autop.apply_placement(sharding_placement)

    parallel_mod.to_empty(device="cuda")
    parallel_mod.init_weights()

    placeholder_nodes = autop.gm.graph.find_nodes(op="placeholder")
    # the fourth placeholder is taken to be the non-tensor input
    int_input_node = placeholder_nodes[3]
    expected_placements = (Replicate(),)
    assert (
        sharding_placement[int_input_node].output_specs.placements
        == expected_placements
    )


def test_fx_graph_annotate(device_mesh_1d):
dim = 128

Expand Down