Fix JIT trace/export training corruption #1012
base: main
Conversation
Force-pushed d1b2c03 to 5f8a0b7
Codecov Report: ✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff           @@
##             main    #1012   +/-  ##
=======================================
  Coverage   99.95%   99.95%
=======================================
  Files         179      180     +1
  Lines        8497     8518    +21
=======================================
+ Hits         8493     8514    +21
  Misses          4        4
Force-pushed 2f006f2 to 743c32d
Are memory leaks confirmed to be still an issue?

The original issue is still open and their example still shows the same behavior.
Force-pushed 743c32d to 9a2f6ee
LGTM, but maybe the default should be false for both exports, or enable ONNX and disable JIT. We don't need both, and I just tested the ONNX export with the code from the torch GitHub issue: it doesn't leak memory (JIT trace still does with torch 2.6).

I have no strong opinion on the defaults. For this PR I just replicated the current behavior.

The current behavior is to always export ONNX? Edit: yes, it does. I think there's no point in exporting both, especially JIT. It's just slowing down training unnecessarily.

By the way, I took a second look at the linked warning and I don't think it means what you described (modifying the model during tracing).
Force-pushed a6974eb to d5e970d
OK, after updating drivers and being forced to reinstall cudatoolkit, I can no longer replicate the exploding loss caused by JIT trace and/or export... so I reverted the deepcopy of the model, including CUDA tensors. But I at least added a bunch of tests to ensure the checkpoint methods are correctly called / not called.

The model state was changing unexpectedly?
def test_save_checkpoint_calls_exports_when_enabled(
    trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
):
    """Test that save_checkpoint calls export functions when exports are enabled."""
Is this test necessary? Isn't it simpler to just add an assert in the code that checks that at least one export function is enabled?
That's different. I am OK with the user disabling both exports, so that assert would be too restrictive anyway.
Here I am just testing that the parameters got passed down the chain and didn't get lost while unpacking/modifying **kwargs somewhere.
> That's different. I am OK with the user disabling both exports, so that assert would be too restrictive anyway.

Right, I misread the test.
But yeah, it still seems trivial? Is it really doing much more than a simple line coverage test (i.e., call the save function and check that the saved file exists, etc.)? Making sure the functions are called seems only marginally better.
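For illustration, here is a minimal, self-contained sketch of the call-assertion style being debated. All names (Checkpointer, save_checkpoint, jit_trace_export, onnx_export, the enable_* flags, the opset_version kwarg) are placeholders, not the repository's actual API; the point is that the assertions check not just that something ran, but that the parameters survived the **kwargs chain:

```python
# Hypothetical sketch: verify export helpers are called with forwarded parameters.
from unittest.mock import MagicMock


class Checkpointer:
    def __init__(self, jit_trace_export, onnx_export,
                 enable_jit_export=True, enable_onnx_export=True):
        self.jit_trace_export = jit_trace_export
        self.onnx_export = onnx_export
        self.enable_jit_export = enable_jit_export
        self.enable_onnx_export = enable_onnx_export

    def save_checkpoint(self, model, trace_input, **kwargs):
        # Flags decide which export helpers run; kwargs are forwarded unchanged.
        if self.enable_jit_export:
            self.jit_trace_export(model, trace_input, **kwargs)
        if self.enable_onnx_export:
            self.onnx_export(model, trace_input, **kwargs)


def test_exports_called_with_forwarded_kwargs():
    jit_mock, onnx_mock = MagicMock(), MagicMock()
    ckpt = Checkpointer(jit_mock, onnx_mock)

    ckpt.save_checkpoint("model", "trace_input", opset_version=17)

    # Not just "did it run": the parameters must arrive unmodified.
    jit_mock.assert_called_once_with("model", "trace_input", opset_version=17)
    onnx_mock.assert_called_once_with("model", "trace_input", opset_version=17)
```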
def test_save_checkpoint_skips_jit_when_disabled(
    trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
):
    """Test that save_checkpoint skips JIT export when disabled."""
This seems unnecessary. If jit.trace is disabled via the input parameter, it's clear that it's disabled.
The test verifies that the export logic actually respects the set parameter.
Here is a copy-and-paste error that future-us could introduce while refactoring:

if self.enable_onnx_export:
    jit_trace_export(...)
A simple line coverage test would have caught this?
OK, but onnx_export might get called from a different location (either correctly or incorrectly), too. My point is: 100% line coverage is nice, but it's easy to achieve while cheating ourselves. Ideally I would like to ensure the method is called from all the expected code paths (and only those). 100% code coverage ensures there is no dead code and no broken code, but it does not ensure correct results.
But also, something is wrong here. I should not have been able to achieve 100% code coverage without these tests in the first place. The error-handling paths in particular can't possibly have been covered before...
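To make the coverage-versus-correctness argument concrete, here is a self-contained sketch that reproduces the hypothetical copy-and-paste mistake quoted above. All names are again placeholders. Every line of the buggy method executes, so line coverage reports 100%, yet the assert_not_called check below fails on it, which is exactly the kind of error these tests are meant to catch:

```python
# Hypothetical sketch: a coverage-friendly bug caught only by call assertions.
from unittest.mock import MagicMock


class BuggyCheckpointer:
    """Reproduces the copy-and-paste mistake: the ONNX flag guards the JIT call."""

    def __init__(self, jit_trace_export, onnx_export,
                 enable_jit_export, enable_onnx_export):
        self.jit_trace_export = jit_trace_export
        self.onnx_export = onnx_export
        self.enable_jit_export = enable_jit_export
        self.enable_onnx_export = enable_onnx_export

    def save_checkpoint(self, model, trace_input):
        if self.enable_onnx_export:
            self.jit_trace_export(model, trace_input)  # wrong helper behind the flag
        if self.enable_onnx_export:
            self.onnx_export(model, trace_input)


def test_jit_export_skipped_when_disabled():
    jit_mock, onnx_mock = MagicMock(), MagicMock()
    ckpt = BuggyCheckpointer(jit_mock, onnx_mock,
                             enable_jit_export=False, enable_onnx_export=True)

    ckpt.save_checkpoint("model", "trace_input")

    # Every line above executed, so coverage is satisfied; this assertion is what
    # fails on the buggy implementation and flags the mistake.
    jit_mock.assert_not_called()
    onnx_mock.assert_called_once_with("model", "trace_input")
```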
def test_save_checkpoint_skips_exports_non_global_zero(trainer_mocks):
    """Test that save_checkpoint skips exports when not global zero rank."""
This seems useful, but it would miss additional exports added in the future.

Yeah, the CUDA tensors are shared between processes (even after deepcopy), and somehow just the serialization/deserialization step was modifying the model itself, causing the training loss to suddenly explode in my main process...

Have you actually confirmed that the changing training state is the root cause of the issue, though? It seems unlikely to me that some external libraries/drivers would be specific enough to torch to change that particular variable and cause corruption. I would actually guess that on a configuration with exploding losses, the test would pass and you'd still get that error.

@nkemnitz Can you also disable jit.trace export by default?
Force-pushed 4132b0a to 82607e8
Passing the model to the JIT export subprocess will still share CUDA tensors. The trace can silently fail (see Warning) and modify the behavior of the model during training.

This PR:
- completely recreates the model inside the JIT export subprocess to avoid any side effects on the main-process model, at the cost of a slightly larger (GPU) memory footprint (a rough sketch of the idea is shown below)
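A rough, hypothetical sketch of that approach, assuming the model can be rebuilt from its class and constructor arguments. The names (export_jit_in_subprocess, _jit_export_worker, model_kwargs, example_input) are illustrative, not the PR's actual code:

```python
# Sketch: rebuild the model in a spawned subprocess so the JIT trace cannot
# touch the training process's CUDA tensors.
import torch
import torch.multiprocessing as mp


def _jit_export_worker(model_cls, model_kwargs, cpu_state_dict, example_input, out_path):
    # Build a fresh model in the child process; nothing here aliases the training
    # process's CUDA tensors, so tracing cannot perturb the ongoing training.
    model = model_cls(**model_kwargs)
    model.load_state_dict(cpu_state_dict)
    model.eval()
    # Optionally move this rebuilt copy to its own GPU before tracing, which is
    # where the slightly larger GPU memory footprint would come from.
    traced = torch.jit.trace(model, example_input)
    traced.save(out_path)


def export_jit_in_subprocess(model, model_cls, model_kwargs, example_input, out_path):
    # Detach and clone everything to CPU so no CUDA storage is shared with the child.
    cpu_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    ctx = mp.get_context("spawn")
    proc = ctx.Process(
        target=_jit_export_worker,
        args=(model_cls, model_kwargs, cpu_state, example_input.detach().cpu(), out_path),
    )
    proc.start()
    proc.join()
```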