Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def rematerialize(trace: TraceCtx) -> TraceCtx:
computed_cuts_for_producers[producer] += cut

rematerialized_trace = from_trace(trace)
rematerialized_trace.bound_symbols = tuple(new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols)
rematerialized_trace.bound_symbols = list(new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols)

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
Expand Down
6 changes: 6 additions & 0 deletions thunder/core/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ def tag_tensorproxy_output_as_detached(proxy):
exception_type=AssertionError,
)

# When using symbolic values, there may be duplicate prims.eq and prims.shape subsymbols that can be removed.
from thunder.core.transform_common import dce_bsyms

subsymbols = dce_bsyms(subsymbols, result)
bsym = bsym.from_bsym(subsymbols=subsymbols)

symbols_list.append(bsym)
return result

Expand Down
39 changes: 30 additions & 9 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,32 @@ def keep_or_swap(p):
# that only produce non-proxy objects
# NOTE needed_proxies is an in/out argument, it takes an initial set of Variables you want to keep, and return
# all the needed proxies of the input trace
def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
start_time_ns = time.perf_counter_ns()
def dce_bsyms(
bsyms: list[BoundSymbolInterface],
output: Any,
needed_proxies: None | set[Variable] = None,
) -> Trace | list[BoundSymbolInterface]:
"""Runs a Dead Code Elimination (DCE) pass

Args:
bsyms: The list of bound symbols to run the DCE pass on.
needed_proxies: The set of variables to keep.
output: The output of the list of bound symbols.

producer_map: ProxyDict = producers(trace)
Returns:
The list of bound symbols after the DCE pass.
"""
producer_map: ProxyDict = producers(bsyms)

flat_trace_outputs, _ = tree_flatten(trace.output)
flat_trace_outputs, _ = tree_flatten(output)
if needed_proxies is None:
needed_proxies: set[Variable] = set(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy)))
else:
needed_proxies.update(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy)))
dced = []

bsym: BoundSymbol
for bsym in reversed(trace.bound_symbols):
for bsym in reversed(bsyms):
# Preserves symbols that should never be collected
if has_tags(bsym, {prims.OpTags.DONT_DCE}):
needed = True
Expand All @@ -182,19 +194,28 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
for x in nbsym.flat_proxy_args:
needed_proxies.add(variableify(x))

dcetrace = from_trace(trace)
dced_bound_symbols = list(reversed(dced))
# duplicate number proxies happen with the symbolic shapes and are
# not covered by the above (due to being in tuples?).
dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols)
dcetrace.bound_symbols = dced_bound_symbols

return dced_bound_symbols


def dce(trace: Trace, needed_proxies: set[Variable] = None) -> Trace:
start_time_ns = time.perf_counter_ns()

bsyms = trace.bound_symbols
dced_bsyms = dce_bsyms(bsyms, trace.output, needed_proxies)
result = from_trace(trace)
result.bound_symbols = dced_bsyms

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
elapsed_time_millis = elapsed_time_ns // 1000000
dcetrace.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)"))

return dcetrace
result.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)"))
return result


#
Expand Down
15 changes: 9 additions & 6 deletions thunder/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,20 +259,20 @@ def get_backed_value(s):
return tuple(map(get_backed_value, vals))


def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
def get_proxy_inputs_from_node(node: torch.fx.Node, tracectx) -> tuple[tuple, dict]:
"""Creates proxy inputs from a torch.fx.Node for use with Thunder.

This function generates proxy inputs for a given torch.fx.Node

Args:
node (torch.fx.Node): The FX graph node to create proxy inputs for.
tracectx (TraceCtx): The trace context to use to generate proxy inputs.
"""
import thunder
from thunder.core.trace import TraceCtx
from thunder.core.proxies import proxy

# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):
with thunder.core.trace.tracectx(tracectx):

def make_input_proxy(arg_node):
# This is a Node in the graph representing a Tensor or tuple of Tensors or
Expand Down Expand Up @@ -380,8 +380,10 @@ def _run_with_cache_info():
cache_info["default_dtype"] = torch.get_default_dtype()
cache_info["default_device"] = torch.get_default_device()

tracectx = TraceCtx()

try:
proxy_args, proxy_kwargs = get_proxy_inputs_from_node(node)
proxy_args, proxy_kwargs = get_proxy_inputs_from_node(node, tracectx)
except Exception as e:
return False, SplitReason(
SplitReasonType.EXCEPTION_PROXY_THUNDER_OP,
Expand All @@ -395,7 +397,7 @@ def _run_with_cache_info():
else thunder_symbol
)
# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):
with thunder.core.trace.tracectx(tracectx):
try:
function_to_run(*proxy_args, **proxy_kwargs)
except Exception as e:
Expand Down Expand Up @@ -478,6 +480,7 @@ def is_node_supported_by_thunder(
"""
Determine whether thunder can execute the operation described by this node.
"""
from thunder.core.trace import TraceCtx
# Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node
# Each Node has a function specified by its op property
# Below are the details for the ones this function is interested in -
Expand Down Expand Up @@ -555,7 +558,7 @@ def is_node_supported_by_thunder(
if torchctx.has_method(node.target):
# `torchctx.get_method` requires args and kwargs to resolve which overload of the method is picked.
try:
args, kwargs = get_proxy_inputs_from_node(node)
args, kwargs = get_proxy_inputs_from_node(node, TraceCtx())
except Exception as e:
return False, SplitReason(
SplitReasonType.EXCEPTION_PROXY_THUNDER_OP,
Expand Down
9 changes: 3 additions & 6 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from thunder.core.trace import TraceCtx, from_trace, TraceProvenance
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, Symbol, has_tags
from thunder.core.devices import Device, DeviceType, cpu
from thunder.core.transform_common import dce, cse_single_bsym, replace_redundant_inputs
from thunder.core.transform_common import dce, dce_bsyms, cse_single_bsym, replace_redundant_inputs
from thunder.core.profile import annotate_for_profile
from thunder.core.compile_data import get_compile_option
from thunder.torch.experimental.dtensor_torch_and_prims import DTensorPrimIDs
Expand Down Expand Up @@ -706,14 +706,11 @@ def has_cuda_input_or_output(self, bsym: BoundSymbol) -> bool:
return False

def _dce_bsyms(self, input_list, output, bsyms: list[BoundSymbol]) -> list[BoundSymbol]:
trace = TraceCtx(None)
trace.bound_symbols = bsyms
bsyms.append(prims.python_return.bind(output, output=None))
needed_proxies: set[Variable] = set()
trace = dce(trace, needed_proxies)
bsyms = dce_bsyms(bsyms, output, needed_proxies)
# update the input_list by removing the unused inputs
input_list[:] = [x for x in input_list if variableify(x) in needed_proxies]
return list(filter(lambda x: x.sym != prims.python_return, trace.bound_symbols))
return bsyms

def fuse(self, region: Region, fusion_counter: int) -> BoundSymbol:
sorted_unique_inputs: list[Proxy] = [unvariableify(x) for x in region.inputs]
Expand Down
50 changes: 43 additions & 7 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2170,6 +2170,36 @@ def func(x, y, device):
assert [t.name for t in tree_flatten(flatten_cse_trace.output)[0]] == ["t4", "t4", "t6", "t14", "t15", "t16", "t17"]


@instantiate(
dtypes=NOTHING,
)
def test_dce(executor, device, _):
def func(x):
dead_code = x + 1 # noqa: F841
y = x * x
return y

x = make_tensor((2, 2), device=device, dtype=torch.float32)
compiled = thunder.jit(func, executors=executor.executors_list())
compiled(x)
traces = thunder.last_traces(compiled)

# find last trace before DCE is applied
for i, trace in enumerate(traces):
provenance = trace.get_provenance().pss if trace.get_provenance() else ""
if "Dead Code Elimination" in provenance:
break
i -= 1
trace = traces[i]

from thunder.core.transform_common import dce, dce_bsyms

dced_trace = dce(trace)
dced_bsyms = dce_bsyms(trace.bound_symbols, trace.output)
assert len(dced_trace.bound_symbols) == len(trace.bound_symbols) - 1
assert len(dced_trace.bound_symbols) == len(dced_bsyms)


def test_symbol_flat_args():
from thunder.core.symbol import Symbol, BoundSymbol

Expand Down Expand Up @@ -3295,22 +3325,28 @@ def clean(tr):


def test_prims_pack_list():
def foo():
pass

trace = TraceCtx(foo)
def foo(x):
a, b = x
return [a, b]

a = torch.randn(2, 2)
b = torch.randn(2, 2)

jfoo = thunder.jit(foo)
jfoo((a, b))

trace = thunder.last_traces(jfoo)[-1]

return_bsym = trace.bound_symbols[-1]
trace.bound_symbols = trace.bound_symbols[:-1]

with tracectx(trace):
x = prims.unpack_trivial(a, name="x")
y = prims.unpack_trivial(b, name="y")
x, y = return_bsym.flat_args
packed_list = prims.pack_list(x, y)
prims.python_return(packed_list)

func = trace.python_callable()
actual = func()
actual = func(a, b)
expected = [a, b]

assert isinstance(actual, list) and actual == expected
Expand Down
Loading