Optimize GPU checkpoint loading by ensuring model transfer before load_state_dict in the build method #325
While reviewing the code in this repository, I noticed that the models were only moved to the GPU after their checkpoints had been loaded. This PR reorders the loading path so the models are placed on the target device before applying their checkpoints, which should improve checkpoint-loading performance.

Thanks to everyone who has contributed to this repo; I really appreciate all the hard work that went into it.
## Summary

This PR ensures that both `prefill_model` and `decode_model` are moved to the target device (e.g., GPU) before invoking `load_state_dict`.

## Motivation
Previously, the models could still be on CPU when their checkpoints were applied: `load_state_dict` would copy the weights into CPU parameters, and the subsequent move to the GPU would then transfer every tensor a second time. Moving the models to the correct device first lets the checkpoint weights land directly in device memory, avoiding the redundant transfer and improving checkpoint-loading efficiency.
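For illustration, here is a minimal sketch of the intended ordering. The module shapes, checkpoint paths, and file names are hypothetical; only `prefill_model`, `decode_model`, and the device-first ordering come from this PR:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real prefill/decode modules.
prefill_model = nn.Linear(1024, 1024)
decode_model = nn.Linear(1024, 1024)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the models to the target device *before* loading their checkpoints,
# so load_state_dict copies weights straight into device memory.
prefill_model.to(device)
decode_model.to(device)

# Hypothetical checkpoint paths; map_location=device also keeps the loaded
# tensors off the CPU staging path.
prefill_state = torch.load("prefill.pt", map_location=device)
decode_state = torch.load("decode.pt", map_location=device)

prefill_model.load_state_dict(prefill_state)
decode_model.load_state_dict(decode_state)
```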
## Changes

- Call `prefill_model.to(device)` before loading its state dict.
- Call `decode_model.to(device)` before loading its state dict.

## Impact
This reduces unnecessary GPU/CPU transfers during checkpoint loading, which
should result in faster and more efficient model initialization.
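To sanity-check the impact claim, here is a rough micro-benchmark sketch. Everything in it is made up for illustration (module size, names), and it assumes the checkpoint tensors are already on the GPU (e.g., loaded with `map_location=device`), which is where the reordering should help most; actual savings will vary by model size and hardware.

```python
import time

import torch
import torch.nn as nn


def timed_load(move_first: bool) -> float:
    """Time one checkpoint load with the given ordering (illustrative only)."""
    device = torch.device("cuda")
    model = nn.Linear(8192, 8192)
    # A GPU-resident "checkpoint", as if loaded with map_location=device.
    state = {k: v.to(device) for k, v in model.state_dict().items()}

    torch.cuda.synchronize()
    start = time.perf_counter()
    if move_first:
        model.to(device)              # device-first ordering (this PR)
        model.load_state_dict(state)  # device-to-device copies
    else:
        model.load_state_dict(state)  # copies GPU checkpoint into CPU params...
        model.to(device)              # ...then uploads everything to the GPU again
    torch.cuda.synchronize()
    return time.perf_counter() - start


if torch.cuda.is_available():
    print(f"move first: {timed_load(True):.3f}s")
    print(f"load first: {timed_load(False):.3f}s")
```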