
Conversation

@nkemnitz (Collaborator) commented Aug 11, 2025

Passing the model to the JIT export subprocess still shares its CUDA tensors with the main process. The trace can silently fail (see Warning) and modify the behavior of the model during training.

This PR:

  • Allows users to independently disable JIT and ONNX export
  • Fixes a few warnings related to (unnecessary) branching in our crop method (though that didn't solve the issues I encountered)
  • Completely recreates the model inside the JIT export subprocess to avoid any side effects on the main-process model, at the cost of a slightly larger (GPU) memory footprint (see the sketch after this list)
  • Still keeps the spawned process to avoid memory leaks
  • Fixes what I think is a bug: our exports were done in training mode
  • Adds tests to verify the export functions are called when expected
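
A minimal sketch of that approach (not the actual zetta_utils code; build_model, model_spec, and export_jit_in_subprocess are placeholder names): the subprocess rebuilds the model from a spec plus a CPU state dict instead of receiving the live CUDA model, and exports in eval mode.

import torch
import torch.multiprocessing as mp
import torch.nn as nn

def build_model(model_spec: dict) -> nn.Module:
    # Placeholder factory -- stands in for however the training config builds the model.
    return nn.Sequential(nn.Linear(model_spec["in"], model_spec["out"]))

def _jit_export_worker(model_spec, state_dict_cpu, trace_input, out_path):
    model = build_model(model_spec)   # fresh instance: nothing aliases the parent's CUDA tensors
    model.load_state_dict(state_dict_cpu)
    model.eval()                      # export in eval mode rather than training mode
    traced = torch.jit.trace(model, trace_input)
    traced.save(out_path)

def export_jit_in_subprocess(model_spec, model, trace_input, out_path):
    # Move the weights to CPU so the child never receives shared CUDA storage.
    state_dict_cpu = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    ctx = mp.get_context("spawn")     # keep the spawned process to contain tracing memory leaks
    proc = ctx.Process(target=_jit_export_worker,
                       args=(model_spec, state_dict_cpu, trace_input, out_path))
    proc.start()
    proc.join()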

@nkemnitz force-pushed the nkem/fix-jit-trace branch from d1b2c03 to 5f8a0b7 on August 11, 2025 at 11:38
codecov bot commented Aug 11, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 99.95%. Comparing base (05256cf) to head (82607e8).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1012   +/-   ##
=======================================
  Coverage   99.95%   99.95%           
=======================================
  Files         179      180    +1     
  Lines        8497     8518   +21     
=======================================
+ Hits         8493     8514   +21     
  Misses          4        4           


@nkemnitz force-pushed the nkem/fix-jit-trace branch 3 times, most recently from 2f006f2 to 743c32d, on August 12, 2025 at 22:18
@supersergiy (Member):

Are memory leaks confirmed to be still an issue?

@nkemnitz (Collaborator, Author):

The original issue is still open, and the example there still shows the same behavior.

@nkemnitz force-pushed the nkem/fix-jit-trace branch from 743c32d to 9a2f6ee on August 18, 2025 at 18:28
@trivoldus28 (Contributor):

LGTM, but maybe the default should be false for both exports, or enable ONNX and disable JIT. We don't need both, and I just tested the ONNX export with the code from the torch GitHub issue and it doesn't leak memory (JIT trace still does with torch 2.6).

@nkemnitz (Collaborator, Author):

I have no strong opinion on the defaults. For this PR I just replicated the current behavior.

@trivoldus28 (Contributor) commented Aug 19, 2025

The current behavior is to always export ONNX? Edit: yeah, it does. I think there's no point exporting both, especially JIT. It's just slowing down training unnecessarily.

@trivoldus28 (Contributor):

BTW, I took a second look at the linked warning, and I don't think it means what you described (modifying the model during tracing).

@nkemnitz force-pushed the nkem/fix-jit-trace branch 2 times, most recently from a6974eb to d5e970d, on August 20, 2025 at 21:22
@nkemnitz (Collaborator, Author):

OK, after updating drivers and being forced to reinstall cudatoolkit, I can no longer replicate the exploding loss caused by JIT trace and/or export... so I reverted the deepcopy of the model including CUDA tensors. But I at least added a bunch of tests to ensure the checkpoint methods are correctly called / not called.

@trivoldus28 (Contributor):

The model state was changing unexpectedly?

def test_save_checkpoint_calls_exports_when_enabled(
    trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
):
    """Test that save_checkpoint calls export functions when exports are enabled."""
@trivoldus28 (Contributor):

Is this test necessary? Isn't it simpler to just add an assert in the code that checks that at least one export function is enabled?

@nkemnitz (Collaborator, Author):

That's different. I am OK with the user disabling both exports, so that assert would be too restrictive anyway.

Here I am just testing that the parameters get passed down the chain and don't get lost while unpacking/modifying **kwargs somewhere.
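
A minimal sketch of that kind of check (the module path, fixtures, and flag name are hypothetical, not the actual test code):

from unittest import mock

def test_flag_reaches_jit_export(trainer, checkpoint_path):
    # Patch the export helper and assert the enable flag survives the **kwargs plumbing.
    with mock.patch("my_project.checkpoint.jit_trace_export") as jit_export:
        trainer.save_checkpoint(checkpoint_path, enable_jit_export=True)
        jit_export.assert_called_once()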

@trivoldus28 (Contributor):

"That's different. I am OK with the user disabling both exports, so that assert would be too restrictive anyway."

Right, I misread the test.

But yeah, it still seems trivial? Is it really doing much more than a simple line-coverage test (i.e., call the save function and check that the saved file exists)? Making sure the functions are called seems only marginally better.

def test_save_checkpoint_skips_jit_when_disabled(
    trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
):
    """Test that save_checkpoint skips JIT export when disabled."""
@trivoldus28 (Contributor):

Seems unnecessary. If jit.trace is disabled in the input param, it's clear that it's disabled.

@nkemnitz (Collaborator, Author):

The test verifies that the export logic actually respects the parameter that was set.

Here is the kind of copy & paste error future-us could introduce during refactoring:

if self.enable_onnx_export:
    jit_trace_export(...)

@trivoldus28 (Contributor):

A simple line coverage test would have caught this?

@nkemnitz (Collaborator, Author):

OK, but onnx_export might also get called from a different location (either correctly or incorrectly). My point is: 100% line coverage is nice, but it's easy to achieve while fooling ourselves. Ideally I would like to ensure the method is called from all the expected code paths (and only those). 100% code coverage ensures no dead code and no broken code, but it does not ensure correct results.

But also, something is wrong here. I should not have been able to achieve 100% code coverage without these tests in the first place. The error-handling paths in particular can't possibly have been covered before...



def test_save_checkpoint_skips_exports_non_global_zero(trainer_mocks):
    """Test that save_checkpoint skips exports when not global zero rank."""
@trivoldus28 (Contributor):

This seems useful, but it would miss additional exports added in the future.

@nkemnitz (Collaborator, Author):

"The model state was changing unexpectedly?"

Yeah, the CUDA tensors are shared between processes (even after deepcopy), and somehow just the serialization/deserialization step of the model was modifying the model itself, causing the training loss to suddenly explode in my main process...
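
For illustration, a minimal standalone sketch of the sharing behavior (assumes a CUDA device is available): a CUDA tensor handed to a spawned process is shared via CUDA IPC, so in-place writes in the child are visible to the parent.

import torch
import torch.multiprocessing as mp

def child(t):
    t.add_(1.0)  # in-place write in the subprocess

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    t = torch.zeros(4, device="cuda")
    proc = mp.Process(target=child, args=(t,))
    proc.start()
    proc.join()
    print(t)  # the parent sees tensor([1., 1., 1., 1.], device='cuda:0')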

@trivoldus28 (Contributor):

Have you actually confirmed that the training state changing is the root cause of the issue, though? It seems unlikely to me that some external libraries/drivers would be specific enough to torch to change that particular variable and cause corruption.

I would actually guess that on a configuration with exploding losses, the test would pass and you'd still get that error.

@trivoldus28 (Contributor):

@nkemnitz Can you also disable jit.trace export by default?
