Skip to content
6 changes: 3 additions & 3 deletions torch_sim/models/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ def __init__(

# Load model if provided as path
if isinstance(model, str | Path):
self.model = torch.load(model, map_location=self._device, weights_only=False)
self.model = torch.load(model, map_location=self.device, weights_only=False)
elif isinstance(model, torch.nn.Module):
self.model = model.to(self._device)
self.model = model.to(self.device)
else:
raise TypeError("Model must be a path or torch.nn.Module")

Expand All @@ -170,7 +170,7 @@ def __init__(

if enable_cueq:
print("Converting models to CuEq for acceleration") # noqa: T201
self.model = run_e3nn_to_cueq(self.model)
self.model = run_e3nn_to_cueq(self.model, device=self.device.type)

# Set model properties
self.r_max = self.model.r_max
Expand Down
Loading