From 4157a20cec17069de2da1a46d1ca6f25ef890a70 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 18 Sep 2025 13:57:07 +0000 Subject: [PATCH 1/4] fix:orb squeeze incorrect energy shape --- torch_sim/models/orb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index fd65b23f..132f6d5c 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -416,7 +416,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: if not model_has_direct_heads and prop == "stress": continue _property = "energy" if prop == "free_energy" else prop - results[prop] = predictions[_property].squeeze() + results[prop] = predictions[_property] if self.conservative: results["forces"] = results[self.model.grad_forces_name] From 51b6bd01a66142e2c78a09b79feb979d2ea34ec1 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 30 Dec 2025 11:35:22 +0100 Subject: [PATCH 2/4] remove use of private var --- torch_sim/models/mace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 2d5fe6c3..37f1a0ec 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -155,9 +155,9 @@ def __init__( # Load model if provided as path if isinstance(model, str | Path): - self.model = torch.load(model, map_location=self._device) + self.model = torch.load(model, map_location=self.device) elif isinstance(model, torch.nn.Module): - self.model = model.to(self._device) + self.model = model.to(self.device) else: raise TypeError("Model must be a path or torch.nn.Module") From 9bdf868eb774254936ae58dc0092611683a175cb Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 30 Dec 2025 11:35:40 +0100 Subject: [PATCH 3/4] add device with cueq for mace --- torch_sim/models/mace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 37f1a0ec..fd0fde34 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -168,7 +168,7 @@ def __init__( if enable_cueq: print("Converting models to CuEq for acceleration") # noqa: T201 - self.model = run_e3nn_to_cueq(self.model) + self.model = run_e3nn_to_cueq(self.model, device=self.device.type) # Set model properties self.r_max = self.model.r_max From 3f3d3fe2cc919196dd404ea20ac4ff096d385235 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 30 Dec 2025 12:14:38 +0100 Subject: [PATCH 4/4] resolved conflict --- torch_sim/models/mace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index fd0fde34..117b4d96 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -155,7 +155,7 @@ def __init__( # Load model if provided as path if isinstance(model, str | Path): - self.model = torch.load(model, map_location=self.device) + self.model = torch.load(model, map_location=self.device, weights_only=False) elif isinstance(model, torch.nn.Module): self.model = model.to(self.device) else: