nimanikoo

While reviewing the code in this repository, I noticed a few places that could be made more efficient. In particular, I changed how the models are loaded onto the GPU before applying their checkpoints, which should improve the performance and overall behavior of the code.

Thanks to everyone who contributed to this repo—really appreciate all the hard work that went into it.

Summary

This PR ensures that both prefill_model and decode_model are moved to the
target device (e.g., GPU) before invoking load_state_dict.

Motivation

Previously, if the models were still on CPU when loading checkpoints,
load_state_dict would first copy the checkpoint tensors into the CPU-resident
parameters, and the subsequent move to the GPU would then transfer everything
a second time, causing unnecessary overhead. By explicitly moving the models
to the correct device first, the checkpoint tensors are copied straight into
device memory, avoiding the redundant transfer and improving checkpoint
loading efficiency.

Changes

  • Added `prefill_model.to(device)` before loading its state dict.
  • Added `decode_model.to(device)` before loading its state dict (a minimal sketch of the new ordering follows this list).
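
For reference, here is a minimal runnable sketch of the new ordering. The `nn.Linear` modules and in-memory checkpoints are illustrative stand-ins for the repo's actual models and `torch.load(...)`'d state dicts:

```python
import torch
import torch.nn as nn

# Stand-in models and CPU-resident "checkpoints"; in the repo these would be
# the real prefill/decode models and state dicts loaded from disk.
prefill_model = nn.Linear(8, 8)
decode_model = nn.Linear(8, 8)
prefill_ckpt = nn.Linear(8, 8).state_dict()
decode_ckpt = nn.Linear(8, 8).state_dict()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The change: move each model to the target device *first*, so that
# load_state_dict copies the checkpoint tensors directly into device memory
# instead of into CPU parameters that are transferred again afterwards.
prefill_model.to(device)
prefill_model.load_state_dict(prefill_ckpt)

decode_model.to(device)
decode_model.load_state_dict(decode_ckpt)
```

Relatedly, when checkpoints are read from disk, `torch.load(path, map_location=device)` places the checkpoint tensors on the target device at load time, which pairs naturally with this change.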

Impact

This reduces unnecessary GPU/CPU transfers during checkpoint loading, which
should result in faster and more efficient model initialization.
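
To sanity-check the impact, a hypothetical micro-benchmark along these lines measures the two orderings end to end; the function name, model sizes, and checkpoint setup are illustrative and not part of the PR:

```python
import time
import torch
import torch.nn as nn

def timed_load(move_first: bool) -> float:
    """Time load_state_dict with the model moved to GPU before vs. after."""
    model = nn.Linear(4096, 4096)
    state = nn.Linear(4096, 4096).state_dict()  # CPU-resident "checkpoint"
    torch.cuda.synchronize()
    start = time.perf_counter()
    if move_first:
        model.to("cuda")              # new ordering: move, then load
        model.load_state_dict(state)
    else:
        model.load_state_dict(state)  # old ordering: load on CPU, then move
        model.to("cuda")
    torch.cuda.synchronize()
    return time.perf_counter() - start

if torch.cuda.is_available():
    print(f"move first: {timed_load(True):.4f}s")
    print(f"move after: {timed_load(False):.4f}s")
```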

…ng state_dict

Moved `prefill_model` and `decode_model` to the target device before calling
`load_state_dict` to avoid redundant tensor transfers by PyTorch.

farazaaa commented Oct 5, 2025

Let's go, it's really good 👍
