From a84bd23ef941fe98c91bf6d0ae7d20c1db997837 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Thu, 25 Sep 2025 07:55:08 -0700
Subject: [PATCH 1/8] Add test

---
 thunder/tests/test_dynamo.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 1a90e2fe89..05280dd856 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -844,6 +844,37 @@ def find_target_module(model, target_module_name):
         assert isinstance(n.target, Symbol) or callable(n.target)
 
 
+@requiresCUDA
+@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
+def test_checkpoint_memory_use(op):
+    import torch.utils.checkpoint as checkpoint
+
+    def fn(x):
+        return op(op(op(op(x))))
+
+    def checkpoint_fn(x):
+        return checkpoint.checkpoint(fn, x, use_reentrant=False)
+
+    initial_mem = torch.cuda.memory_allocated()
+
+    x = torch.randn((1024 // 4, 1024, 1024), device="cuda", requires_grad=True)
+    jfn = thunderfx(checkpoint_fn)
+    y = jfn(x)
+
+    peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem
+
+    y_ref = fn(x)
+    torch.testing.assert_close(y, y_ref)
+
+    if op == torch.sin:
+        assert peak_mem_usage == x.nbytes * 2
+    else:
+        assert peak_mem_usage == x.nbytes * 3
+        # Make sure the checkpointed region falled back to PyTorch
+        sinfo = jfn._backend.subgraph_infos[-1]
+        assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],

From 52d06aabf79c0b7c6c00b2af41e621db80d757d5 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Thu, 25 Sep 2025 08:56:13 -0700
Subject: [PATCH 2/8] Improved comments

---
 thunder/dynamo/compiler.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index 2ee582bba0..1cf6b07041 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -135,10 +135,6 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
 
         remove_empty_autocast(gm)
 
-        # Dynamo uses lazy generation of the underlying Python code, so we need to
-        # force recompilation of the GraphModule before passing it to Thunder.
-        recompile_graph(gm)
-
         # The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections
         # and unsupported sections which are passed to `torch.compile(backend='inductor')`
         thunder_options = _with_prologue_pruning_transform(

From add4688d33ad884d88be95c5cd67cd9472e40a45 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Mon, 29 Sep 2025 04:23:30 -0700
Subject: [PATCH 3/8] Resolve review: small tensor for tests

---
 thunder/tests/test_dynamo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 05280dd856..134648e09c 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -857,7 +857,7 @@ def checkpoint_fn(x):
 
     initial_mem = torch.cuda.memory_allocated()
 
-    x = torch.randn((1024 // 4, 1024, 1024), device="cuda", requires_grad=True)
+    x = torch.randn((128, 128), device="cuda", requires_grad=True)
     jfn = thunderfx(checkpoint_fn)
     y = jfn(x)
 

From 128cc865924abacc55c0ed140ce891083c91d437 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Mon, 29 Sep 2025 04:51:45 -0700
Subject: [PATCH 4/8] Wrap in submodule

---
 thunder/dynamo/splitter.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index 7949937ca5..51f1bcf28a 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -198,7 +198,18 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
             )
         elif node.name.startswith("submod"):  # For inductor
             graph_module = getattr(split_gm, node.name)
-            jit_fn = torch_inductor(graph_module)
+
+            class Wrapped(torch.nn.Module):
+                def __init__(self, gm):
+                    super().__init__()
+                    self.gm = gm
+
+                def forward(self, *a):
+                    return self.gm(*a)
+
+            # Make sure Inductor does not skip graph_module's compilation by wrapping it
+            # See https://github.com/Lightning-AI/lightning-thunder/issues/2527#issuecomment-3345877210
+            jit_fn = torch_inductor(Wrapped(graph_module))
             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
             submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(

From 5bcd895f15cfbc038624f9462e61e64d9c2bd8c0 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Mon, 29 Sep 2025 04:51:51 -0700
Subject: [PATCH 5/8] Update test

---
 thunder/tests/test_dynamo.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 134648e09c..ea2cdf4de5 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -866,10 +866,8 @@ def checkpoint_fn(x):
     y_ref = fn(x)
     torch.testing.assert_close(y, y_ref)
 
-    if op == torch.sin:
-        assert peak_mem_usage == x.nbytes * 2
-    else:
-        assert peak_mem_usage == x.nbytes * 3
+    assert peak_mem_usage == x.nbytes * 2
+    if op == torch.sinc:
         # Make sure the checkpointed region falled back to PyTorch
         sinfo = jfn._backend.subgraph_infos[-1]
         assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)

From 8778e40894d7d98c328ff35bcc4c41654536c1d1 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Mon, 29 Sep 2025 07:13:15 -0700
Subject: [PATCH 6/8] Add test for Inductor fallback

---
 thunder/tests/test_dynamo.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index ea2cdf4de5..2e2fc9843c 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -155,6 +155,24 @@ def func(x):
     assert any(target.startswith("thunder_") for target in targets)  # Verify that the submodules have name `thunder_*`
 
 
+@instantiate(dtypes=NOTHING)
+def test_inductor_fallback(executor, device, dtype):
+    x = torch.randn(3, 3, device=device, dtype=dtype)
+
+    def func(x):
+        return x.sinc().cos().sinc().sinc()
+
+    def trivial_compile(model, *args, **kwargs):
+        return model
+
+    cfunc = thunderfx(func)
+    with patch("torch._inductor.compile_fx.compile_fx", side_effect=trivial_compile) as mock_call:
+        cfunc(x)
+
+    # Once for sinc() and once for sinc().sinc()
+    assert mock_call.call_count == 2
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],

From ec67c1874c2d0f21d7fa957d050d5156b9581577 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Mon, 29 Sep 2025 08:22:48 -0700
Subject: [PATCH 7/8] Revert an unrelated change

---
 thunder/dynamo/compiler.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index 1cf6b07041..2ee582bba0 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -135,6 +135,10 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
 
         remove_empty_autocast(gm)
 
+        # Dynamo uses lazy generation of the underlying Python code, so we need to
+        # force recompilation of the GraphModule before passing it to Thunder.
+        recompile_graph(gm)
+
         # The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections
         # and unsupported sections which are passed to `torch.compile(backend='inductor')`
         thunder_options = _with_prologue_pruning_transform(

From 4cdaec71683e024b4df8eb4f556e7e3233fc4ca8 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Mon, 29 Sep 2025 08:25:22 -0700
Subject: [PATCH 8/8] Fix typo

---
 thunder/tests/test_dynamo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 2e2fc9843c..5cc191b54a 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -886,7 +886,7 @@ def checkpoint_fn(x):
 
     assert peak_mem_usage == x.nbytes * 2
     if op == torch.sinc:
-        # Make sure the checkpointed region falled back to PyTorch
+        # Make sure the checkpointed region fell back to PyTorch
        sinfo = jfn._backend.subgraph_infos[-1]
         assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)
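
Note on [PATCH 4/8]: per the comment added in splitter.py, Inductor may skip compiling a bare fx.GraphModule handed to it directly, so the patch wraps it in a plain nn.Module first. Below is a minimal, self-contained sketch of that idea outside the splitter; the toy function fn and the direct use of torch.compile(backend="inductor") are illustrative stand-ins for the split-off "submod_*" graph module and the torch_inductor callable, not the actual thunder code path.

    import torch

    # Toy function standing in for a split-off "submod_*" GraphModule.
    def fn(x):
        return torch.sinc(x).cos()

    gm = torch.fx.symbolic_trace(fn)

    # Wrap the GraphModule in a plain nn.Module before handing it to the
    # Inductor backend, mirroring the Wrapped class from the patch, so that
    # Inductor does not short-circuit on an already-traced GraphModule.
    class Wrapped(torch.nn.Module):
        def __init__(self, gm):
            super().__init__()
            self.gm = gm

        def forward(self, *args):
            return self.gm(*args)

    compiled = torch.compile(Wrapped(gm), backend="inductor")
    print(compiled(torch.randn(8)))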