From e7c8bc98234f7e0adf9da0d44cf07ce7c953c96a Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Thu, 16 Oct 2025 14:09:54 +0300 Subject: [PATCH 1/6] Apply DCE to subsymbols --- thunder/core/transform_common.py | 34 +++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index bc14a671f2..cbdc1064e9 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -131,7 +131,9 @@ def keep_or_swap(p): if all(map(lambda x: isinstance(x, NumberProxyInterface) and x.name in seen, bsym.flat_outs)): continue output = tree_map(keep_or_swap, bsym.output) - new_bsyms.append(bsym.from_bsym(output=output)) + subsymbols = dce(bsym.subsymbols, output=bsym.output) + new_bsym = bsym.from_bsym(output=output, subsymbols=subsymbols) + new_bsyms.append(new_bsym) return new_bsyms @@ -142,12 +144,21 @@ def keep_or_swap(p): # that only produce non-proxy objects # NOTE needed_proxies is an in/out argument, it takes an initial set of Variables you want to keep, and return # all the needed proxies of the input trace -def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace: +def dce( + trace_or_bsyms: Trace | list[BoundSymbolInterface], needed_proxies: None | set[Variable] = None, output=None +) -> Trace | list[BoundSymbolInterface]: start_time_ns = time.perf_counter_ns() - producer_map: ProxyDict = producers(trace) + producer_map: ProxyDict = producers(trace_or_bsyms) + + if isinstance(trace_or_bsyms, Trace): + bound_symbols = trace_or_bsyms.bound_symbols + output = trace_or_bsyms.output + else: + bound_symbols = trace_or_bsyms + output = output - flat_trace_outputs, _ = tree_flatten(trace.output) + flat_trace_outputs, _ = tree_flatten(output) if needed_proxies is None: needed_proxies: set[Variable] = set(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) else: @@ -155,7 +166,7 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace: dced = [] bsym: BoundSymbol - for bsym in reversed(trace.bound_symbols): + for bsym in reversed(bound_symbols): # Preserves symbols that should never be collected if has_tags(bsym, {prims.OpTags.DONT_DCE}): needed = True @@ -182,19 +193,24 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace: for x in nbsym.flat_proxy_args: needed_proxies.add(variableify(x)) - dcetrace = from_trace(trace) dced_bound_symbols = list(reversed(dced)) # duplicate number proxies happen with the symbolic shapes and are # not covered by the above (due to being in tuples?). dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols) - dcetrace.bound_symbols = dced_bound_symbols + if isinstance(trace_or_bsyms, Trace): + result = from_trace(trace_or_bsyms) + result.bound_symbols = dced_bound_symbols + else: + result = dced_bound_symbols end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - dcetrace.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)")) - return dcetrace + if isinstance(trace_or_bsyms, Trace): + result.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)")) + + return result # From cb31f7fb6ea0c60ef65ca694ae9365b37addc008 Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Wed, 12 Nov 2025 14:41:33 +0200 Subject: [PATCH 2/6] try in symbol.__call__ instead --- thunder/core/symbol.py | 6 +++++- thunder/core/transform_common.py | 4 +--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index f5184b5770..30b25cff8d 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -323,6 +323,11 @@ def __call__(self, *args, **kwargs): result = tree_unflatten(flat_results, spec) + # When using symbolic values, there may be duplicate prims.eq and prims.shape subsymbols that can be removed. + from thunder.core.transform_common import dce + + subsymbols = dce(subsymbols, output=result) + trace.pop_scope() cd = get_compile_data() @@ -340,7 +345,6 @@ def tag_tensorproxy_output_as_detached(proxy): return proxy result = tree_map(tag_tensorproxy_output_as_detached, result) - bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols) symbols_list = trace.peek_scope() diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index cbdc1064e9..2ecdda3d19 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -131,9 +131,7 @@ def keep_or_swap(p): if all(map(lambda x: isinstance(x, NumberProxyInterface) and x.name in seen, bsym.flat_outs)): continue output = tree_map(keep_or_swap, bsym.output) - subsymbols = dce(bsym.subsymbols, output=bsym.output) - new_bsym = bsym.from_bsym(output=output, subsymbols=subsymbols) - new_bsyms.append(new_bsym) + new_bsyms.append(bsym.from_bsym(output=output)) return new_bsyms From db83930c36bfd77e27c5e4df2f4fe9c824595995 Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Fri, 14 Nov 2025 16:31:15 +0200 Subject: [PATCH 3/6] come on, ruff, it's a test --- thunder/core/symbol.py | 11 ++++--- thunder/core/transform_common.py | 16 +++++++-- thunder/dynamo/utils.py | 15 +++++---- thunder/executors/nvfuserex_impl.py | 7 ++-- thunder/tests/test_core.py | 51 +++++++++++++++++++++++++---- 5 files changed, 75 insertions(+), 25 deletions(-) diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 30b25cff8d..6dd65cfaa5 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -323,11 +323,6 @@ def __call__(self, *args, **kwargs): result = tree_unflatten(flat_results, spec) - # When using symbolic values, there may be duplicate prims.eq and prims.shape subsymbols that can be removed. - from thunder.core.transform_common import dce - - subsymbols = dce(subsymbols, output=result) - trace.pop_scope() cd = get_compile_data() @@ -354,6 +349,12 @@ def tag_tensorproxy_output_as_detached(proxy): exception_type=AssertionError, ) + # When using symbolic values, there may be duplicate prims.eq and prims.shape subsymbols that can be removed. + from thunder.core.transform_common import dce + + subsymbols = dce(subsymbols, output=result) + bsym = bsym.from_bsym(subsymbols=subsymbols) + symbols_list.append(bsym) return result diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 2ecdda3d19..928e94eeb0 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -143,8 +143,21 @@ def keep_or_swap(p): # NOTE needed_proxies is an in/out argument, it takes an initial set of Variables you want to keep, and return # all the needed proxies of the input trace def dce( - trace_or_bsyms: Trace | list[BoundSymbolInterface], needed_proxies: None | set[Variable] = None, output=None + trace_or_bsyms: Trace | list[BoundSymbolInterface], + needed_proxies: None | set[Variable] = None, + output: Any = None, ) -> Trace | list[BoundSymbolInterface]: + """Runs a Dead Code Elimination (DCE) pass + + Args: + trace_or_bsyms: The trace or list of bound symbols to run the DCE pass on. + needed_proxies: The set of variables to keep. + output: The output of the list of bound symbols. This is only used if the input is a list of bound + symbols, and is required in that case + + Returns: + The trace (if the input is a trace) or list of bound symbols (if the input is a list of bound symbols) after the DCE pass. + """ start_time_ns = time.perf_counter_ns() producer_map: ProxyDict = producers(trace_or_bsyms) @@ -154,7 +167,6 @@ def dce( output = trace_or_bsyms.output else: bound_symbols = trace_or_bsyms - output = output flat_trace_outputs, _ = tree_flatten(output) if needed_proxies is None: diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index ad929b5adf..2c5a5fa900 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -238,20 +238,20 @@ def get_backed_value(s): return tuple(map(get_backed_value, vals)) -def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]: +def get_proxy_inputs_from_node(node: torch.fx.Node, tracectx) -> tuple[tuple, dict]: """Creates proxy inputs from a torch.fx.Node for use with Thunder. This function generates proxy inputs for a given torch.fx.Node Args: node (torch.fx.Node): The FX graph node to create proxy inputs for. + tracectx (TraceCtx): The trace context to use to generate proxy inputs. """ import thunder - from thunder.core.trace import TraceCtx from thunder.core.proxies import proxy # We need to be under trace context to generate proxies. - with thunder.core.trace.tracectx(TraceCtx()): + with thunder.core.trace.tracectx(tracectx): def make_input_proxy(arg_node): # This is a Node in the graph representing a Tensor or tuple of Tensors or @@ -359,8 +359,10 @@ def _run_with_cache_info(): cache_info["default_dtype"] = torch.get_default_dtype() cache_info["default_device"] = torch.get_default_device() + tracectx = TraceCtx() + try: - proxy_args, proxy_kwargs = get_proxy_inputs_from_node(node) + proxy_args, proxy_kwargs = get_proxy_inputs_from_node(node, tracectx) except Exception as e: return False, SplitReason( SplitReasonType.EXCEPTION_PROXY_THUNDER_OP, @@ -374,7 +376,7 @@ def _run_with_cache_info(): else thunder_symbol ) # We need to be under trace context to generate proxies. - with thunder.core.trace.tracectx(TraceCtx()): + with thunder.core.trace.tracectx(tracectx): try: function_to_run(*proxy_args, **proxy_kwargs) except Exception as e: @@ -442,6 +444,7 @@ def is_node_supported_by_thunder( """ Determine whether thunder can execute the operation described by this node. """ + from thunder.core.trace import TraceCtx # Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node # Each Node has a function specified by its op property # Below are the details for the ones this function is interested in - @@ -513,7 +516,7 @@ def is_node_supported_by_thunder( if torchctx.has_method(node.target): # `torchctx.get_method` requires args and kwargs to resolve which overload of the method is picked. try: - args, kwargs = get_proxy_inputs_from_node(node) + args, kwargs = get_proxy_inputs_from_node(node, TraceCtx()) except Exception as e: return False, SplitReason( SplitReasonType.EXCEPTION_PROXY_THUNDER_OP, diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index b65d6944e0..1c326d7b3f 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -706,14 +706,11 @@ def has_cuda_input_or_output(self, bsym: BoundSymbol) -> bool: return False def _dce_bsyms(self, input_list, output, bsyms: list[BoundSymbol]) -> list[BoundSymbol]: - trace = TraceCtx(None) - trace.bound_symbols = bsyms - bsyms.append(prims.python_return.bind(output, output=None)) needed_proxies: set[Variable] = set() - trace = dce(trace, needed_proxies) + bsyms = dce(bsyms, needed_proxies, output) # update the input_list by removing the unused inputs input_list[:] = [x for x in input_list if variableify(x) in needed_proxies] - return list(filter(lambda x: x.sym != prims.python_return, trace.bound_symbols)) + return bsyms def fuse(self, region: Region, fusion_counter: int) -> BoundSymbol: sorted_unique_inputs: list[Proxy] = [unvariableify(x) for x in region.inputs] diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 4210fdca67..53806c1303 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2170,6 +2170,36 @@ def func(x, y, device): assert [t.name for t in tree_flatten(flatten_cse_trace.output)[0]] == ["t4", "t4", "t6", "t14", "t15", "t16", "t17"] +@instantiate( + dtypes=NOTHING, +) +def test_dce(executor, device, _): + def func(x): + dead_code = x + 1 # noqa: F841 + y = x * x + return y + + x = make_tensor((2, 2), device=device, dtype=torch.float32) + compiled = thunder.jit(func, executors=executor.executors_list()) + compiled(x) + traces = thunder.last_traces(compiled) + + # find last trace before DCE is applied + for i, trace in enumerate(traces): + provenance = trace.get_provenance().pss if trace.get_provenance() else "" + if "Dead Code Elimination" in provenance: + break + i -= 1 + trace = traces[i] + + from thunder.core.transform_common import dce + + dced_trace = dce(trace) + dced_bsyms = dce(trace.bound_symbols) + assert len(dced_trace.bound_symbols) == len(trace.bound_symbols) - 1 + assert len(dced_trace.bound_symbols) == len(dced_bsyms) + + def test_symbol_flat_args(): from thunder.core.symbol import Symbol, BoundSymbol @@ -3295,22 +3325,29 @@ def clean(tr): def test_prims_pack_list(): - def foo(): - pass - - trace = TraceCtx(foo) + def foo(x): + a, b = x + return [a, b] a = torch.randn(2, 2) b = torch.randn(2, 2) + jfoo = thunder.jit(foo) + jfoo((a, b)) + + trace = thunder.last_traces(jfoo)[-1] + + return_bsym = trace.bound_symbols[-1] + trace.bound_symbols = trace.bound_symbols[:-1] + with tracectx(trace): - x = prims.unpack_trivial(a, name="x") - y = prims.unpack_trivial(b, name="y") + x, y = return_bsym.flat_args packed_list = prims.pack_list(x, y) prims.python_return(packed_list) func = trace.python_callable() - actual = func() + print(trace) + actual = func(a, b) expected = [a, b] assert isinstance(actual, list) and actual == expected From 53630676a71b69ce54211e0001a07f4142153804 Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Tue, 18 Nov 2025 14:41:51 +0200 Subject: [PATCH 4/6] remove print --- thunder/tests/test_core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 4fe1c5ba88..eeea8f492d 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -3346,7 +3346,6 @@ def foo(x): prims.python_return(packed_list) func = trace.python_callable() - print(trace) actual = func(a, b) expected = [a, b] From 129e9d03e6e9862f72e419469321c3cc14460961 Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Fri, 21 Nov 2025 11:49:21 +0200 Subject: [PATCH 5/6] respond to comments --- thunder/core/symbol.py | 5 ++-- thunder/core/transform_common.py | 45 +++++++++++++---------------- thunder/executors/nvfuserex_impl.py | 4 +-- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 6dd65cfaa5..02d4533b54 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -340,6 +340,7 @@ def tag_tensorproxy_output_as_detached(proxy): return proxy result = tree_map(tag_tensorproxy_output_as_detached, result) + bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols) symbols_list = trace.peek_scope() @@ -350,9 +351,9 @@ def tag_tensorproxy_output_as_detached(proxy): ) # When using symbolic values, there may be duplicate prims.eq and prims.shape subsymbols that can be removed. - from thunder.core.transform_common import dce + from thunder.core.transform_common import dce_bsyms - subsymbols = dce(subsymbols, output=result) + subsymbols = dce_bsyms(subsymbols, result) bsym = bsym.from_bsym(subsymbols=subsymbols) symbols_list.append(bsym) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 928e94eeb0..32916cfded 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -142,31 +142,22 @@ def keep_or_swap(p): # that only produce non-proxy objects # NOTE needed_proxies is an in/out argument, it takes an initial set of Variables you want to keep, and return # all the needed proxies of the input trace -def dce( - trace_or_bsyms: Trace | list[BoundSymbolInterface], +def dce_bsyms( + bsyms: list[BoundSymbolInterface], + output: Any, needed_proxies: None | set[Variable] = None, - output: Any = None, ) -> Trace | list[BoundSymbolInterface]: """Runs a Dead Code Elimination (DCE) pass Args: - trace_or_bsyms: The trace or list of bound symbols to run the DCE pass on. + bsyms: The list of bound symbols to run the DCE pass on. needed_proxies: The set of variables to keep. - output: The output of the list of bound symbols. This is only used if the input is a list of bound - symbols, and is required in that case + output: The output of the list of bound symbols. Returns: - The trace (if the input is a trace) or list of bound symbols (if the input is a list of bound symbols) after the DCE pass. + The list of bound symbols after the DCE pass. """ - start_time_ns = time.perf_counter_ns() - - producer_map: ProxyDict = producers(trace_or_bsyms) - - if isinstance(trace_or_bsyms, Trace): - bound_symbols = trace_or_bsyms.bound_symbols - output = trace_or_bsyms.output - else: - bound_symbols = trace_or_bsyms + producer_map: ProxyDict = producers(bsyms) flat_trace_outputs, _ = tree_flatten(output) if needed_proxies is None: @@ -176,7 +167,7 @@ def dce( dced = [] bsym: BoundSymbol - for bsym in reversed(bound_symbols): + for bsym in reversed(bsyms): # Preserves symbols that should never be collected if has_tags(bsym, {prims.OpTags.DONT_DCE}): needed = True @@ -208,18 +199,22 @@ def dce( # not covered by the above (due to being in tuples?). dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols) - if isinstance(trace_or_bsyms, Trace): - result = from_trace(trace_or_bsyms) - result.bound_symbols = dced_bound_symbols - else: - result = dced_bound_symbols + return dced_bound_symbols + + +def dce(trace: Trace, needed_proxies: set[Variable]) -> Trace: + start_time_ns = time.perf_counter_ns() + + bsyms = trace.bound_symbol + dced_bsyms = dce_bsyms(bsyms, trace.output, needed_proxies) + result = from_trace(trace) + result.bound_symbols = dced_bsyms + end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - if isinstance(trace_or_bsyms, Trace): - result.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)")) - + result.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)")) return result diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 1c326d7b3f..b3fe70f7ca 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -42,7 +42,7 @@ from thunder.core.trace import TraceCtx, from_trace, TraceProvenance from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, Symbol, has_tags from thunder.core.devices import Device, DeviceType, cpu -from thunder.core.transform_common import dce, cse_single_bsym, replace_redundant_inputs +from thunder.core.transform_common import dce, dce_bsyms, cse_single_bsym, replace_redundant_inputs from thunder.core.profile import annotate_for_profile from thunder.core.compile_data import get_compile_option from thunder.torch.experimental.dtensor_torch_and_prims import DTensorPrimIDs @@ -707,7 +707,7 @@ def has_cuda_input_or_output(self, bsym: BoundSymbol) -> bool: def _dce_bsyms(self, input_list, output, bsyms: list[BoundSymbol]) -> list[BoundSymbol]: needed_proxies: set[Variable] = set() - bsyms = dce(bsyms, needed_proxies, output) + bsyms = dce_bsyms(bsyms, output, needed_proxies) # update the input_list by removing the unused inputs input_list[:] = [x for x in input_list if variableify(x) in needed_proxies] return bsyms From 06ca4df4f220c2cdc198355c3805aba9099f772d Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Fri, 21 Nov 2025 12:01:41 +0200 Subject: [PATCH 6/6] where's my coffee --- thunder/core/rematerialization.py | 2 +- thunder/core/transform_common.py | 4 ++-- thunder/tests/test_core.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index a3b3b83504..44a8a9cd67 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -458,7 +458,7 @@ def rematerialize(trace: TraceCtx) -> TraceCtx: computed_cuts_for_producers[producer] += cut rematerialized_trace = from_trace(trace) - rematerialized_trace.bound_symbols = tuple(new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols) + rematerialized_trace.bound_symbols = list(new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols) end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 32916cfded..b2f6ca8140 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -202,10 +202,10 @@ def dce_bsyms( return dced_bound_symbols -def dce(trace: Trace, needed_proxies: set[Variable]) -> Trace: +def dce(trace: Trace, needed_proxies: set[Variable] = None) -> Trace: start_time_ns = time.perf_counter_ns() - bsyms = trace.bound_symbol + bsyms = trace.bound_symbols dced_bsyms = dce_bsyms(bsyms, trace.output, needed_proxies) result = from_trace(trace) result.bound_symbols = dced_bsyms diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index eeea8f492d..d2cfb938ce 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2192,10 +2192,10 @@ def func(x): i -= 1 trace = traces[i] - from thunder.core.transform_common import dce + from thunder.core.transform_common import dce, dce_bsyms dced_trace = dce(trace) - dced_bsyms = dce(trace.bound_symbols) + dced_bsyms = dce_bsyms(trace.bound_symbols, trace.output) assert len(dced_trace.bound_symbols) == len(trace.bound_symbols) - 1 assert len(dced_trace.bound_symbols) == len(dced_bsyms)