diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 7d3a5d25..889adfae 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -155,9 +155,9 @@ def __init__( # Load model if provided as path if isinstance(model, str | Path): - self.model = torch.load(model, map_location=self._device, weights_only=False) + self.model = torch.load(model, map_location=self.device, weights_only=False) elif isinstance(model, torch.nn.Module): - self.model = model.to(self._device) + self.model = model.to(self.device) else: raise TypeError("Model must be a path or torch.nn.Module") @@ -170,7 +170,7 @@ def __init__( if enable_cueq: print("Converting models to CuEq for acceleration") # noqa: T201 - self.model = run_e3nn_to_cueq(self.model) + self.model = run_e3nn_to_cueq(self.model, device=self.device.type) # Set model properties self.r_max = self.model.r_max