diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py
index ffbb59e632..3a679bcf13 100644
--- a/thunder/core/update_aliases.py
+++ b/thunder/core/update_aliases.py
@@ -1,5 +1,6 @@
 from functools import reduce, partial
 
+from thunder.core.compile_data import using_symbolic_values
 import thunder.core.prims as prims
 from thunder.core.proxies import TensorProxy, variableify, unvariableify
 from thunder.core.pytree import tree_flatten
@@ -55,23 +56,36 @@ def _involves_viewed_args(bsym, viewed):
     return any(isinstance(p, TensorProxy) and variableify(p) in viewed for p in bsym.flat_proxy_args)
 
 
+def _can_be_reshaped(arg, arg_to_replace):
+    # TODO: Fix this once numel for symbolic values is implemented
+    if using_symbolic_values():
+        arg_numel = arg._numel()
+        arg_to_replace_numel = arg_to_replace._numel()
+    else:
+        arg_numel = arg.numel
+        arg_to_replace_numel = arg_to_replace.numel
+    return arg_numel == arg_to_replace_numel
+
+
 def replace_args_with_alias_map(
     computation_trace: Trace,
     alias_tensor_indices: list[list[int]],
-) -> tuple[Trace, dict[VariableInterface, TensorProxy]]:
+) -> tuple[Trace, list[set[VariableInterface]]]:
     if not alias_tensor_indices:
-        return computation_trace, {}
+        return computation_trace, []
     bsyms: list[BoundSymbol] = []
     flat_args, _ = tree_flatten((computation_trace.args, computation_trace.kwargs))
     swap_map_for_aliases: dict[VariableInterface, TensorProxy] = {}
     arg_to_optional_bsyms: dict[VariableInterface, BoundSymbol] = {}
+    view_groups = {}
     for indices in alias_tensor_indices:
         arg = flat_args[indices[0]]
         for idx in filter(lambda idx: idx < len(flat_args), indices[1:]):
             arg_to_replace = flat_args[idx]
-            # Skip aliases with different numel (e.g., complex tensor and its real view)
+            # Track aliases with different numel (e.g., complex tensor and its real view)
             # These share storage but have incompatible element counts
-            if arg.numel != arg_to_replace.numel:
+            if not _can_be_reshaped(arg, arg_to_replace):
+                view_groups.setdefault(variableify(arg), []).append(variableify(arg_to_replace))
                 continue
             reshaped_arg = arg
             if arg_to_replace.shape != arg.shape:
@@ -111,7 +125,8 @@ def replace_args_with_alias_map(
     no_implicit_alias_trace.bound_symbols = bsyms
     str_map = {unvariableify(k).name: v.name for k, v in swap_map_for_aliases.items()}
     no_implicit_alias_trace.set_provenance(TraceProvenance(f"Duplicate alias args using {str_map}"))
-    return no_implicit_alias_trace, swap_map_for_aliases
+    view_groups = [{k}.union(set(v)) for k, v in view_groups.items() if len(v) != 0]
+    return no_implicit_alias_trace, view_groups
 
 
 def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace:
@@ -123,10 +138,10 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
 
     # First pass: identify inputs which are views of each other and swap them out with a default,
     # reshaping if necessary.
-    computation_trace, _ = replace_args_with_alias_map(computation_trace, alias_tensor_indices)
+    computation_trace, view_groups = replace_args_with_alias_map(computation_trace, alias_tensor_indices)
 
     # Second pass: identify views, their originals, and operands involved in inplace ops
-    view_groups = []
+    encountered = set().union(*view_groups)
     inplace_inputs = set()
     for bsym in computation_trace.bound_symbols:
         if _is_inplace_op(bsym) or _is_view_creation_op(bsym):
@@ -146,7 +161,6 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
     # filter out view groups that don't have any tensors involved in inplace ops
     view_groups = [group for group in view_groups if len(group.intersection(inplace_inputs)) != 0]
     viewed = set(reduce(set.union, view_groups, set()))
-    encountered = set()
 
     # Third pass: insert alias updates
     for bsym in computation_trace.bound_symbols:
diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py
index 238da88a4f..e90a463a1a 100644
--- a/thunder/tests/test_update_aliases.py
+++ b/thunder/tests/test_update_aliases.py
@@ -517,3 +517,25 @@ def foo(x):
     expected_grad = torch.autograd.grad(expected, c, g)
     torch.testing.assert_close(actual_grad_fx, expected_grad)
     torch.testing.assert_close(actual_grad_jit, expected_grad)
+
+
+@instantiate(
+    dtypes=(dtypes.float32,),
+)
+def test_aliasing_for_viewed_input_of_different_shapes(executor, device, dtype):
+    def f(x, y, z):
+        return x + 2, y.add_(z)
+
+    a = make_tensor((2, 3), dtype=dtypes.to_torch_dtype(dtype), device=device)
+    b = a[0, :]
+    c = a[1, :]
+    a_ = a.clone().detach()
+    b_ = a_[0, :]
+    c_ = a_[1, :]
+    jfn = executor.make_callable(f)
+    actual = jfn(a, b, c)
+    expected = f(a_, b_, c_)
+    torch.testing.assert_close(actual, expected)
+    torch.testing.assert_close(a, a_)
+    torch.testing.assert_close(b, b_)
+    torch.testing.assert_close(c, c_)
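For context, a minimal PyTorch-only sketch (not part of the patch) of the aliasing case that _can_be_reshaped guards against: a complex tensor and its real view share storage but have different element counts, so neither can be reshaped into the other, and the pair is now recorded as a view group instead of being skipped.

# Illustrative sketch only; plain PyTorch, independent of thunder's proxies.
import torch

t = torch.randn(4, dtype=torch.cfloat)   # complex tensor, numel == 4
r = torch.view_as_real(t)                # real view, numel == 8, same storage
assert r.untyped_storage().data_ptr() == t.untyped_storage().data_ptr()
assert r.numel() != t.numel()            # reshape between the two is impossible,
                                         # so such aliases are tracked in view_groups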