[loading] Fix device when source and target are different #42246
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
The device_map specifies the target keys when loading. The PR updates the loading accordingly, otherwise we have issues when using a device_map with any model using a
_conversion_mapping(the VLMs) for example, where source and targets are different.Currently, the only thing saving us is the fact that
accelerate_dispatchwill move the parameters if the device is not correct during post processing, which is why it was not detected before! But of course this is muuuuch more costly than our smart loading.I checked very carefully (by running benchmarks AND checking source code), and performances are the same if using this PR, or if opening safetensors directly on device! This can also be verified by looking at the
safetensorsrust bindings here: it actually simply callstensor.to(device)internally when callingget_slice, so this PR has no impact on performances (may even be slightly better due to avoiding to opening the files again and again)To understand the issue, consider the following snippet:
On main, it currently outputs:
Basically, all params are on
cpuinstead of0, due to the mismatch between targets and sources in the_checkpoint_conversion_mapping.On this PR, everything is fine again, and params are loaded immediately on the correct device.
This is also needed for my other offloading PR #42242