7 changes: 6 additions & 1 deletion thunder/core/symbol.py
@@ -340,7 +340,6 @@ def tag_tensorproxy_output_as_detached(proxy):
return proxy

result = tree_map(tag_tensorproxy_output_as_detached, result)

bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)
symbols_list = trace.peek_scope()

@@ -350,6 +349,12 @@ def tag_tensorproxy_output_as_detached(proxy):
exception_type=AssertionError,
)

# When using symbolic values, there may be duplicate prims.eq and prims.shape subsymbols that can be removed.
from thunder.core.transform_common import dce

subsymbols = dce(subsymbols, output=result)
bsym = bsym.from_bsym(subsymbols=subsymbols)

symbols_list.append(bsym)
return result

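For context, a quick way to observe the effect of this bind-time subsymbol cleanup. The toy function below is an assumption for illustration, not part of the patch, and the exact subsymbols printed depend on the op and on whether symbolic values are enabled via thunder's cache option:

import torch
import thunder

def f(x):
    return (x + x).reshape(-1)

# Optionally pass cache="symbolic values" (cache-option string assumed) to
# exercise the prims.shape / prims.eq dedup path this change targets.
jf = thunder.jit(f)
jf(torch.randn(2, 3))

# With this change, each bound symbol's subsymbols have already been DCE'd at
# bind time, so duplicate shape/eq checks do not accumulate inside them.
for bsym in thunder.last_traces(jf)[-1].bound_symbols:
    print(bsym.sym.name, [sub.sym.name for sub in bsym.subsymbols])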
42 changes: 34 additions & 8 deletions thunder/core/transform_common.py
@@ -142,20 +142,41 @@ def keep_or_swap(p):
# that only produce non-proxy objects
# NOTE needed_proxies is an in/out argument: it takes an initial set of Variables you want to keep and returns
# all the needed proxies of the input trace
def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
def dce(
trace_or_bsyms: Trace | list[BoundSymbolInterface],
needed_proxies: None | set[Variable] = None,
output: Any = None,
) -> Trace | list[BoundSymbolInterface]:
"""Runs a Dead Code Elimination (DCE) pass

Args:
trace_or_bsyms: The trace or list of bound symbols to run the DCE pass on.
needed_proxies: An in/out set of Variables to keep; it is updated with all proxies the surviving symbols need.
output: The output of the list of bound symbols. Only used when the input is a list of bound
symbols, and required in that case.

Returns:
The DCE'd trace (if the input was a trace) or the DCE'd list of bound symbols (if the input was a list).
"""
start_time_ns = time.perf_counter_ns()

producer_map: ProxyDict = producers(trace)
producer_map: ProxyDict = producers(trace_or_bsyms)

flat_trace_outputs, _ = tree_flatten(trace.output)
if isinstance(trace_or_bsyms, Trace):
bound_symbols = trace_or_bsyms.bound_symbols
output = trace_or_bsyms.output
else:
bound_symbols = trace_or_bsyms

flat_trace_outputs, _ = tree_flatten(output)
if needed_proxies is None:
needed_proxies: set[Variable] = set(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy)))
else:
needed_proxies.update(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy)))
dced = []

bsym: BoundSymbol
for bsym in reversed(trace.bound_symbols):
for bsym in reversed(bound_symbols):
# Preserves symbols that should never be collected
if has_tags(bsym, {prims.OpTags.DONT_DCE}):
needed = True
Expand All @@ -182,19 +203,24 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
for x in nbsym.flat_proxy_args:
needed_proxies.add(variableify(x))

dcetrace = from_trace(trace)
dced_bound_symbols = list(reversed(dced))
# duplicate number proxies happen with symbolic shapes and are
# not covered by the pass above (possibly because they sit inside tuples).
dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols)
dcetrace.bound_symbols = dced_bound_symbols

if isinstance(trace_or_bsyms, Trace):
result = from_trace(trace_or_bsyms)
result.bound_symbols = dced_bound_symbols
else:
result = dced_bound_symbols
end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
elapsed_time_millis = elapsed_time_ns // 1000000
dcetrace.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)"))

return dcetrace
if isinstance(trace_or_bsyms, Trace):
result.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)"))

return result


#
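For orientation, a minimal usage sketch of the two calling conventions the updated dce now supports. The toy function and the trace index are assumptions for illustration:

import torch
import thunder
from thunder.core.transform_common import dce

def fn(x):
    unused = x + 1  # dead code for DCE to drop
    return x * x

jfn = thunder.jit(fn)
jfn(torch.randn(2, 2))
trace = thunder.last_traces(jfn)[0]  # an early trace; later ones are already DCE'd

# Trace in, trace out: provenance is recorded on the returned trace.
dced_trace = dce(trace)

# Bound-symbol list in, list out: `output` marks which proxies must survive,
# and `needed_proxies` is filled in as a side effect (it is an in/out set).
needed_proxies = set()
dced_bsyms = dce(trace.bound_symbols, needed_proxies, output=trace.output)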
15 changes: 9 additions & 6 deletions thunder/dynamo/utils.py
@@ -244,20 +244,20 @@ def get_backed_value(s):
return tuple(map(get_backed_value, vals))


def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
def get_proxy_inputs_from_node(node: torch.fx.Node, tracectx) -> tuple[tuple, dict]:
"""Creates proxy inputs from a torch.fx.Node for use with Thunder.

This function generates proxy inputs for a given torch.fx.Node

Args:
node (torch.fx.Node): The FX graph node to create proxy inputs for.
tracectx (TraceCtx): The trace context to use to generate proxy inputs.
"""
import thunder
from thunder.core.trace import TraceCtx
from thunder.core.proxies import proxy

# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):
with thunder.core.trace.tracectx(tracectx):

def make_input_proxy(arg_node):
# This is a Node in the graph representing a Tensor or tuple of Tensors or
@@ -365,8 +365,10 @@ def _run_with_cache_info():
cache_info["default_dtype"] = torch.get_default_dtype()
cache_info["default_device"] = torch.get_default_device()

tracectx = TraceCtx()

try:
proxy_args, proxy_kwargs = get_proxy_inputs_from_node(node)
proxy_args, proxy_kwargs = get_proxy_inputs_from_node(node, tracectx)
except Exception as e:
return False, SplitReason(
SplitReasonType.EXCEPTION_PROXY_THUNDER_OP,
Expand All @@ -380,7 +382,7 @@ def _run_with_cache_info():
else thunder_symbol
)
# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):
with thunder.core.trace.tracectx(tracectx):
try:
function_to_run(*proxy_args, **proxy_kwargs)
Author comment on lines -399 to 402:

This was the source of all of my woes. This pattern implicitly binds the provided proxy_args to the symbols being generated, as opposed to mapping the provided inputs onto the input args an established trace expects. But the proxy_args were being created in a distinct TraceCtx, and this resulted in name collisions. That was revealed by an application of DCE deep down in the call stack when executing function_to_run: DCE builds a producer map, and a lookup in that map raised a KeyError, triggering a graph split. (A sketch of the corrected pattern follows this file's diff.)

except Exception as e:
@@ -463,6 +465,7 @@ def is_node_supported_by_thunder(
"""
Determine whether thunder can execute the operation described by this node.
"""
from thunder.core.trace import TraceCtx
# Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node
# Each Node has a function specified by its op property
# Below are the details for the ones this function is interested in -
@@ -534,7 +537,7 @@ def is_node_supported_by_thunder(
if torchctx.has_method(node.target):
# `torchctx.get_method` requires args and kwargs to resolve which overload of the method is picked.
try:
args, kwargs = get_proxy_inputs_from_node(node)
args, kwargs = get_proxy_inputs_from_node(node, TraceCtx())
except Exception as e:
return False, SplitReason(
SplitReasonType.EXCEPTION_PROXY_THUNDER_OP,
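A minimal sketch of the corrected pattern the author's comment describes, reduced to its essentials. The proxied tensors and the ltorch.add call are illustrative assumptions, not the splitter code itself; the point is that input proxies and the candidate operation are traced under one shared TraceCtx, so the two steps cannot hand out colliding proxy names and DCE's producer-map lookup no longer hits a KeyError:

import torch
import thunder.torch as ltorch
from thunder.core.proxies import proxy
from thunder.core.trace import TraceCtx, tracectx

shared_trace = TraceCtx()

# Create the input proxies under the shared trace context...
with tracectx(shared_trace):
    a = proxy(torch.randn(2, 2), name="a")
    b = proxy(torch.randn(2, 2), name="b")

# ...and trace the candidate operation under the *same* context, so any
# intermediate proxies it names cannot collide with `a` and `b`.
with tracectx(shared_trace):
    ltorch.add(a, b)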
7 changes: 2 additions & 5 deletions thunder/executors/nvfuserex_impl.py
@@ -706,14 +706,11 @@ def has_cuda_input_or_output(self, bsym: BoundSymbol) -> bool:
return False

def _dce_bsyms(self, input_list, output, bsyms: list[BoundSymbol]) -> list[BoundSymbol]:
trace = TraceCtx(None)
trace.bound_symbols = bsyms
bsyms.append(prims.python_return.bind(output, output=None))
needed_proxies: set[Variable] = set()
trace = dce(trace, needed_proxies)
bsyms = dce(bsyms, needed_proxies, output)
# update the input_list by removing the unused inputs
input_list[:] = [x for x in input_list if variableify(x) in needed_proxies]
return list(filter(lambda x: x.sym != prims.python_return, trace.bound_symbols))
return bsyms

def fuse(self, region: Region, fusion_counter: int) -> BoundSymbol:
sorted_unique_inputs: list[Proxy] = [unvariableify(x) for x in region.inputs]
50 changes: 43 additions & 7 deletions thunder/tests/test_core.py
@@ -2170,6 +2170,36 @@ def func(x, y, device):
assert [t.name for t in tree_flatten(flatten_cse_trace.output)[0]] == ["t4", "t4", "t6", "t14", "t15", "t16", "t17"]


@instantiate(
dtypes=NOTHING,
)
def test_dce(executor, device, _):
def func(x):
dead_code = x + 1 # noqa: F841
y = x * x
return y

x = make_tensor((2, 2), device=device, dtype=torch.float32)
compiled = thunder.jit(func, executors=executor.executors_list())
compiled(x)
traces = thunder.last_traces(compiled)

# find the trace immediately before the first DCE pass
for i, trace in enumerate(traces):
provenance = trace.get_provenance().pss if trace.get_provenance() else ""
if "Dead Code Elimination" in provenance:
break
i -= 1
trace = traces[i]

from thunder.core.transform_common import dce

dced_trace = dce(trace)
dced_bsyms = dce(trace.bound_symbols)
assert len(dced_trace.bound_symbols) == len(trace.bound_symbols) - 1
assert len(dced_trace.bound_symbols) == len(dced_bsyms)


def test_symbol_flat_args():
from thunder.core.symbol import Symbol, BoundSymbol

@@ -3295,22 +3325,28 @@ def clean(tr):


def test_prims_pack_list():
def foo():
pass

trace = TraceCtx(foo)
def foo(x):
a, b = x
return [a, b]

a = torch.randn(2, 2)
b = torch.randn(2, 2)

jfoo = thunder.jit(foo)
jfoo((a, b))

trace = thunder.last_traces(jfoo)[-1]

return_bsym = trace.bound_symbols[-1]
trace.bound_symbols = trace.bound_symbols[:-1]

with tracectx(trace):
x = prims.unpack_trivial(a, name="x")
y = prims.unpack_trivial(b, name="y")
x, y = return_bsym.flat_args
packed_list = prims.pack_list(x, y)
prims.python_return(packed_list)

func = trace.python_callable()
actual = func()
actual = func(a, b)
expected = [a, b]

assert isinstance(actual, list) and actual == expected