From 02a54e1f465d1718501be121f780b50f4211ae03 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 25 Nov 2025 14:31:16 +0000 Subject: [PATCH 01/29] append to trajectory --- torch_sim/runners.py | 24 +++++++++++--- torch_sim/trajectory.py | 73 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 84 insertions(+), 13 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index b2059aac..f6b72ea5 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -175,7 +175,12 @@ def integrate[T: SimState]( # noqa: C901 trajectory_reporter, properties=["kinetic_energy", "potential_energy", "temperature"], ) - + # Auto-detect initial step from trajectory files for resuming integration + initial_step: int = 1 + if trajectory_reporter is not None and trajectory_reporter.mode == "a": + last_step = trajectory_reporter.last_step + if last_step > 0: + initial_step = last_step + 1 final_states: list[T] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None @@ -199,7 +204,7 @@ def integrate[T: SimState]( # noqa: C901 ) # run the simulation - for step in range(1, n_steps + 1): + for step in range(initial_step, initial_step + n_steps): state = step_func( state=state, model=model, dt=dt, kT=kTs[step - 1], **integrator_kwargs ) @@ -457,7 +462,14 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 trajectory_reporter, properties=["potential_energy"] ) - step: int = 1 + # Auto-detect initial step from trajectory files for resuming optimizations + initial_step: int = 1 + if trajectory_reporter is not None and trajectory_reporter.mode == "a": + last_step = trajectory_reporter.last_step + if last_step > 0: + initial_step = last_step + 1 + step: int = initial_step + last_energy = None all_converged_states: list[T] = [] convergence_tensor = None @@ -485,9 +497,13 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 and og_filenames is not None and (step == 1 or len(converged_states) > 0) ): + mode_before = trajectory_reporter.trajectory_kwargs["mode"] + # temporarily set to "append" mode to avoid overwriting existing files + trajectory_reporter.trajectory_kwargs["mode"] = "a" trajectory_reporter.load_new_trajectories( filenames=[og_filenames[i] for i in autobatcher.current_idx] ) + trajectory_reporter.trajectory_kwargs["mode"] = mode_before for _step in range(steps_between_swaps): if hasattr(state, "energy"): @@ -498,7 +514,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 if trajectory_reporter: trajectory_reporter.report(state, step, model=model) step += 1 - if step > max_steps: + if step > max_steps + initial_step - 1: # TODO: max steps should be tracked for each structure in the batch warnings.warn(f"Optimize has reached max steps: {step}", stacklevel=2) break diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index ab841a85..15f47289 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -41,7 +41,6 @@ from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState - if TYPE_CHECKING: from ase import Atoms from ase.io.trajectory import TrajectoryReader @@ -261,7 +260,8 @@ def report( and step % self.state_frequency == 0 and self.filenames is not None ): - self.trajectories[idx].write_state(substate, step, **self.state_kwargs) + traj = self.trajectories[idx] + traj.write_state(substate, step, **self.state_kwargs) all_state_props = {} # Process property calculators for this system @@ -302,6 +302,44 @@ def close(self) -> None: for trajectory in self.trajectories: trajectory.close() + @property + 
def mode(self) -> Literal["r", "w", "a"]: + """Get the mode of the first trajectory file. + + Returns: + "r" | "w" | "a": Mode from the trajectory_kwargs used during initialization. + """ + if not self.trajectories: + raise ValueError("No trajectories loaded.") + # Key is guaranteed to exist because we set it during initialization. + return self.trajectory_kwargs["mode"] + + + @property + def last_step(self) -> int: + """Get the maximum last step across all trajectory files. + + Returns the highest step number found across all trajectory files. + This is useful for resuming optimizations from where they left off. + + Returns: + int: The maximum last step number across all trajectories, or 0 if + no trajectories exist or all are empty + """ + if not self.trajectories: + return 0 + + max_step = 0 + for trajectory in self.trajectories: + if trajectory._file.isopen: + last_step = trajectory.last_step + else: + with TorchSimTrajectory(trajectory._file.filename, mode="r") as traj: + last_step = traj.last_step + max_step = max(max_step, last_step) + + return max_step + def __enter__(self) -> "TrajectoryReporter": """Support the context manager protocol. @@ -594,7 +632,7 @@ def _validate_array(self, name: str, data: np.ndarray, steps: list[int]) -> None ) # Validate step is monotonically increasing by checking HDF5 file directly - steps_node = self._file.get_node("/steps/", name=name) + steps_node = self.get_steps(name) if len(steps_node) > 0: last_step = steps_node[-1] # Get the last recorded step if steps[0] <= last_step: @@ -658,9 +696,6 @@ def get_array( def get_steps( self, name: str, - start: int | None = None, - stop: int | None = None, - step: int = 1, ) -> np.ndarray: """Get the steps for an array. @@ -675,9 +710,29 @@ def get_steps( Returns: np.ndarray: Array of step numbers with shape [n_selected_frames] """ - return self._file.root.steps.__getitem__(name).read( - start=start, stop=stop, step=step - ) + return self._file.get_node("/steps/", name=name).read() + + @property + def last_step(self) -> int: + """Get the last step number from the trajectory. + + Retrieves the maximum step number across all arrays in the trajectory. + If the trajectory is empty or has no arrays, returns 0. + + Returns: + int: The last (maximum) step number in the trajectory, or 0 if empty + """ + if not self.array_registry: + return 0 + + max_step = 0 + for name in self.array_registry: + steps_node = self.get_steps(name) + if len(steps_node) > 0: + last_step = int(steps_node[-1]) + max_step = max(max_step, last_step) + + return max_step def __str__(self) -> str: """Get a string representation of the trajectory. 
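
Usage sketch (editor's illustration, not part of the patch): the snippet below shows how the append mode added in this commit is meant to be driven from user code. It assumes a prepared SimState (`initial_state`) and a ModelInterface implementation (`model`) already exist, and the trajectory file name is a placeholder; everything else uses only the ts.TrajectoryReporter / ts.integrate API exercised by the tests later in this series.

import torch_sim as ts

# First run: the default "w" mode creates a fresh trajectory file.
reporter = ts.TrajectoryReporter(["system_0.h5"], state_frequency=1)
state = ts.integrate(
    system=initial_state,   # placeholder: a SimState built beforehand
    model=model,            # placeholder: any ModelInterface implementation
    integrator=ts.Integrator.nvt_langevin,
    n_steps=500,
    temperature=300.0,
    timestep=0.001,
    trajectory_reporter=reporter,
)

# Resume: reopen the same file in append mode. integrate() checks
# reporter.mode == "a", reads reporter.last_step, and continues reporting
# from last_step + 1 instead of restarting at step 1.
reporter = ts.TrajectoryReporter(
    ["system_0.h5"], state_frequency=1, trajectory_kwargs=dict(mode="a")
)
state = ts.integrate(
    system=state,
    model=model,
    integrator=ts.Integrator.nvt_langevin,
    n_steps=500,
    temperature=300.0,
    timestep=0.001,
    trajectory_reporter=reporter,
)

Note that reopening with the default "w" mode would overwrite the existing file, so resuming only takes effect when the reporter is constructed with trajectory_kwargs=dict(mode="a").
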
From db4cfa2c1aed33ca3f78d712392b00a31c8cb7b9 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 25 Nov 2025 15:02:23 +0000 Subject: [PATCH 02/29] test append to trajectory --- tests/test_trajectory.py | 96 ++++++++++++++++++++++++++++++++++++++++ torch_sim/runners.py | 2 +- 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 7f7252cf..6faf7344 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -1,5 +1,6 @@ import os import sys +import tempfile from collections.abc import Callable, Generator from pathlib import Path @@ -834,3 +835,98 @@ def test_write_ase_trajectory_importerror( with pytest.raises(ImportError, match="ASE is required to convert to ASE trajectory"): traj.write_ase_trajectory(tmp_path / "dummy.traj") traj.close() + + +def test_optimize_append_to_trajectory(si_double_sim_state: SimState, lj_model: LennardJonesModel) -> None: + """Test appending to an existing trajectory when running ts.optimize.""" + + # Create a temporary trajectory file + with tempfile.TemporaryDirectory() as temp_dir: + traj_files = [f"{temp_dir}/optimize_trajectory_{idx}.h5" for idx in range(2)] + + # Initialize model and state + trajectory_reporter = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + ) + + # First optimization run + opt_state = ts.optimize( + system=si_double_sim_state, + model=lj_model, + max_steps=5, + optimizer=ts.Optimizer.fire, + trajectory_reporter=trajectory_reporter, + steps_between_swaps=100 + ) + + for traj in trajectory_reporter.trajectories: + with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + # Check that the trajectory file has 5 frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6)) + + trajectory_reporter_2 = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + trajectory_kwargs=dict(mode="a") + ) + _ = ts.optimize( + system=opt_state, + model=lj_model, + max_steps=7, + optimizer=ts.Optimizer.fire, + trajectory_reporter=trajectory_reporter_2, + steps_between_swaps=100 + ) + for traj in trajectory_reporter_2.trajectories: + with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + # Check that the trajectory file now has 12 frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13)) + +def test_integrate_append_to_trajectory(si_double_sim_state: SimState, lj_model: LennardJonesModel) -> None: + """Test appending to an existing trajectory when running ts.integrate.""" + + # Create a temporary trajectory file + with tempfile.TemporaryDirectory() as temp_dir: + traj_files = [f"{temp_dir}/integrate_trajectory_{idx}.h5" for idx in range(2)] + + # Initialize model and state + trajectory_reporter = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + ) + + # First integration run + int_state = ts.integrate( + system=si_double_sim_state, + model=lj_model, + timestep=0.1, + n_steps=5, + temperature=300.0, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter, + ) + + for traj in trajectory_reporter.trajectories: + with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + # Check that the trajectory file has 5 frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6)) + + trajectory_reporter_2 = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + trajectory_kwargs=dict(mode="a") + ) + _ = ts.integrate( + system=int_state, + model=lj_model, + timestep=0.1, + temperature=300.0, + n_steps=7, + integrator=ts.Integrator.nvt_langevin, + 
trajectory_reporter=trajectory_reporter_2, + ) + for traj in trajectory_reporter_2.trajectories: + with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + # Check that the trajectory file now has 12 frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13)) \ No newline at end of file diff --git a/torch_sim/runners.py b/torch_sim/runners.py index f6b72ea5..58e23bd5 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -206,7 +206,7 @@ def integrate[T: SimState]( # noqa: C901 # run the simulation for step in range(initial_step, initial_step + n_steps): state = step_func( - state=state, model=model, dt=dt, kT=kTs[step - 1], **integrator_kwargs + state=state, model=model, dt=dt, kT=kTs[step - initial_step], **integrator_kwargs ) if trajectory_reporter: From 6f78f542b7810197f96ec3a56cf52fe8bbbd4b08 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 25 Nov 2025 15:11:59 +0000 Subject: [PATCH 03/29] style --- tests/test_trajectory.py | 23 ++++++++++++----------- torch_sim/runners.py | 6 +++++- torch_sim/trajectory.py | 5 ++--- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 6faf7344..ded414d4 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -837,7 +837,9 @@ def test_write_ase_trajectory_importerror( traj.close() -def test_optimize_append_to_trajectory(si_double_sim_state: SimState, lj_model: LennardJonesModel) -> None: +def test_optimize_append_to_trajectory( + si_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: """Test appending to an existing trajectory when running ts.optimize.""" # Create a temporary trajectory file @@ -857,7 +859,7 @@ def test_optimize_append_to_trajectory(si_double_sim_state: SimState, lj_model: max_steps=5, optimizer=ts.Optimizer.fire, trajectory_reporter=trajectory_reporter, - steps_between_swaps=100 + steps_between_swaps=100, ) for traj in trajectory_reporter.trajectories: @@ -866,9 +868,7 @@ def test_optimize_append_to_trajectory(si_double_sim_state: SimState, lj_model: np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6)) trajectory_reporter_2 = ts.TrajectoryReporter( - traj_files, - state_frequency=1, - trajectory_kwargs=dict(mode="a") + traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a") ) _ = ts.optimize( system=opt_state, @@ -876,14 +876,17 @@ def test_optimize_append_to_trajectory(si_double_sim_state: SimState, lj_model: max_steps=7, optimizer=ts.Optimizer.fire, trajectory_reporter=trajectory_reporter_2, - steps_between_swaps=100 + steps_between_swaps=100, ) for traj in trajectory_reporter_2.trajectories: with TorchSimTrajectory(traj._file.filename, mode="r") as traj: # Check that the trajectory file now has 12 frames np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13)) -def test_integrate_append_to_trajectory(si_double_sim_state: SimState, lj_model: LennardJonesModel) -> None: + +def test_integrate_append_to_trajectory( + si_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: """Test appending to an existing trajectory when running ts.integrate.""" # Create a temporary trajectory file @@ -913,9 +916,7 @@ def test_integrate_append_to_trajectory(si_double_sim_state: SimState, lj_model: np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6)) trajectory_reporter_2 = ts.TrajectoryReporter( - traj_files, - state_frequency=1, - trajectory_kwargs=dict(mode="a") + traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a") ) _ = 
ts.integrate( system=int_state, @@ -929,4 +930,4 @@ def test_integrate_append_to_trajectory(si_double_sim_state: SimState, lj_model: for traj in trajectory_reporter_2.trajectories: with TorchSimTrajectory(traj._file.filename, mode="r") as traj: # Check that the trajectory file now has 12 frames - np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13)) \ No newline at end of file + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13)) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 58e23bd5..96af9bfd 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -206,7 +206,11 @@ def integrate[T: SimState]( # noqa: C901 # run the simulation for step in range(initial_step, initial_step + n_steps): state = step_func( - state=state, model=model, dt=dt, kT=kTs[step - initial_step], **integrator_kwargs + state=state, + model=model, + dt=dt, + kT=kTs[step - initial_step], + **integrator_kwargs, ) if trajectory_reporter: diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 15f47289..69fe9b88 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -312,8 +312,7 @@ def mode(self) -> Literal["r", "w", "a"]: if not self.trajectories: raise ValueError("No trajectories loaded.") # Key is guaranteed to exist because we set it during initialization. - return self.trajectory_kwargs["mode"] - + return self.trajectory_kwargs["mode"] @property def last_step(self) -> int: @@ -711,7 +710,7 @@ def get_steps( np.ndarray: Array of step numbers with shape [n_selected_frames] """ return self._file.get_node("/steps/", name=name).read() - + @property def last_step(self) -> int: """Get the last step number from the trajectory. From b378c6e31bfe30dabc026509595536b271f603a1 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 25 Nov 2025 15:21:04 +0000 Subject: [PATCH 04/29] revert small change --- torch_sim/trajectory.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 69fe9b88..9636e5db 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -260,8 +260,7 @@ def report( and step % self.state_frequency == 0 and self.filenames is not None ): - traj = self.trajectories[idx] - traj.write_state(substate, step, **self.state_kwargs) + self.trajectories[idx].write_state(substate, step, **self.state_kwargs) all_state_props = {} # Process property calculators for this system From 1feff96131471e222313085e03be64df7991c9a7 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 26 Nov 2025 10:00:33 +0000 Subject: [PATCH 05/29] maintain step per system --- tests/test_trajectory.py | 4 ++-- torch_sim/runners.py | 36 +++++++++++++++++++----------------- torch_sim/trajectory.py | 31 ++++++++++++++----------------- 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index ded414d4..87b1c22a 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -903,7 +903,7 @@ def test_integrate_append_to_trajectory( int_state = ts.integrate( system=si_double_sim_state, model=lj_model, - timestep=0.1, + timestep=0.001, n_steps=5, temperature=300.0, integrator=ts.Integrator.nvt_langevin, @@ -921,7 +921,7 @@ def test_integrate_append_to_trajectory( _ = ts.integrate( system=int_state, model=lj_model, - timestep=0.1, + timestep=0.001, temperature=300.0, n_steps=7, integrator=ts.Integrator.nvt_langevin, diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 96af9bfd..7f60e2e7 100644 --- 
a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -176,11 +176,10 @@ def integrate[T: SimState]( # noqa: C901 properties=["kinetic_energy", "potential_energy", "temperature"], ) # Auto-detect initial step from trajectory files for resuming integration - initial_step: int = 1 + initial_step: torch.LongTensor = torch.full(size=(initial_state.n_systems,), fill_value=1, dtype=torch.long, device=initial_state.device) if trajectory_reporter is not None and trajectory_reporter.mode == "a": - last_step = trajectory_reporter.last_step - if last_step > 0: - initial_step = last_step + 1 + last_logged_steps = torch.tensor(trajectory_reporter.last_step, dtype=torch.long, device=initial_state.device) + initial_step = initial_step + last_logged_steps final_states: list[T] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None @@ -197,24 +196,26 @@ def integrate[T: SimState]( # noqa: C901 state = init_func(state=state, model=model, kT=kTs[0], dt=dt, **integrator_kwargs) # set up trajectory reporters + _initial_step = initial_step if autobatcher and trajectory_reporter is not None and og_filenames is not None: # we must remake the trajectory reporter for each system trajectory_reporter.load_new_trajectories( filenames=[og_filenames[i] for i in system_indices] ) + _initial_step = initial_step[system_indices] # run the simulation - for step in range(initial_step, initial_step + n_steps): + for steps_so_far in range(n_steps): state = step_func( state=state, model=model, dt=dt, - kT=kTs[step - initial_step], + kT=kTs[steps_so_far], **integrator_kwargs, ) if trajectory_reporter: - trajectory_reporter.report(state, step, model=model) + trajectory_reporter.report(state, _initial_step + steps_so_far, model=model) # finish the trajectory reporter final_states.append(state) @@ -467,12 +468,11 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 ) # Auto-detect initial step from trajectory files for resuming optimizations - initial_step: int = 1 + initial_step: torch.LongTensor = torch.full(size=(state.n_systems,), fill_value=1, dtype=torch.long, device=state.device) if trajectory_reporter is not None and trajectory_reporter.mode == "a": - last_step = trajectory_reporter.last_step - if last_step > 0: - initial_step = last_step + 1 - step: int = initial_step + last_logged_steps = torch.tensor(trajectory_reporter.last_step, dtype=torch.long, device=state.device) + initial_step = initial_step + last_logged_steps + steps_so_far = 0 last_energy = None all_converged_states: list[T] = [] @@ -499,7 +499,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 if ( trajectory_reporter is not None and og_filenames is not None - and (step == 1 or len(converged_states) > 0) + and (steps_so_far == 0 or len(converged_states) > 0) ): mode_before = trajectory_reporter.trajectory_kwargs["mode"] # temporarily set to "append" mode to avoid overwriting existing files @@ -508,6 +508,8 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 filenames=[og_filenames[i] for i in autobatcher.current_idx] ) trajectory_reporter.trajectory_kwargs["mode"] = mode_before + # Remove initial_step entries for converged states + initial_step = initial_step[autobatcher.current_idx] for _step in range(steps_between_swaps): if hasattr(state, "energy"): @@ -516,11 +518,11 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 state = step_fn(state=state, model=model, **optimizer_kwargs) if trajectory_reporter: - trajectory_reporter.report(state, step, model=model) - step += 1 - if step > max_steps + initial_step - 1: + 
trajectory_reporter.report(state, initial_step + steps_so_far, model=model) + steps_so_far += 1 + if steps_so_far >= max_steps: # TODO: max steps should be tracked for each structure in the batch - warnings.warn(f"Optimize has reached max steps: {step}", stacklevel=2) + warnings.warn(f"Optimize has reached max steps: {steps_so_far}", stacklevel=2) break convergence_tensor = convergence_fn(state, last_energy) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 9636e5db..4cf39aaf 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -212,7 +212,7 @@ def _add_model_arg_to_prop_calculators(self) -> None: self.prop_calculators[frequency][name] = new_fn def report( - self, state: SimState, step: int, model: ModelInterface | None = None + self, state: SimState, step: torch.LongTensor, model: ModelInterface | None = None ) -> list[dict[str, torch.Tensor]]: """Report a state and step to the trajectory files. @@ -223,7 +223,7 @@ def report( Args: state (SimState): Current system state with n_systems equal to len(filenames) - step (int): Current simulation step, setting step to 0 will write + step (torch.LongTensor): Current simulation step per system, setting step to 0 will write the state and all properties. model (ModelInterface, optional): Model used for simulation. Defaults to None. Must be provided if any prop_calculators @@ -254,18 +254,19 @@ def report( all_props: list[dict[str, torch.Tensor]] = [] # Process each system separately for idx, substate in enumerate(split_states): + _step = step[idx].item() # Write state to trajectory if it's time if ( self.state_frequency - and step % self.state_frequency == 0 + and _step % self.state_frequency == 0 and self.filenames is not None ): - self.trajectories[idx].write_state(substate, step, **self.state_kwargs) + self.trajectories[idx].write_state(substate, _step, **self.state_kwargs) all_state_props = {} # Process property calculators for this system for report_frequency, calculators in self.prop_calculators.items(): - if step % report_frequency != 0 or report_frequency == 0: + if _step % report_frequency != 0 or report_frequency == 0: continue # Calculate properties for this substate @@ -314,29 +315,25 @@ def mode(self) -> Literal["r", "w", "a"]: return self.trajectory_kwargs["mode"] @property - def last_step(self) -> int: - """Get the maximum last step across all trajectory files. + def last_step(self) -> list[int]: + """Get the last logged step across all trajectory files. - Returns the highest step number found across all trajectory files. This is useful for resuming optimizations from where they left off. Returns: - int: The maximum last step number across all trajectories, or 0 if + list[int]: The last step number for each trajectory, or 0 if no trajectories exist or all are empty """ if not self.trajectories: - return 0 - - max_step = 0 + return [] + last_steps = [] for trajectory in self.trajectories: if trajectory._file.isopen: - last_step = trajectory.last_step + last_steps.append(trajectory.last_step) else: with TorchSimTrajectory(trajectory._file.filename, mode="r") as traj: - last_step = traj.last_step - max_step = max(max_step, last_step) - - return max_step + last_steps.append(traj.last_step) + return last_steps def __enter__(self) -> "TrajectoryReporter": """Support the context manager protocol. 
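
Illustration (editor's sketch, not part of the patch): after this change, last-step bookkeeping is per trajectory file rather than a single maximum. The snippet assumes trajectory files left over from an earlier run; the file names are placeholders, and only the reporter/trajectory API shown in this series is used.

import torch_sim as ts
from torch_sim.trajectory import TorchSimTrajectory

# Reopen existing files in append mode; last_step now returns one entry per file.
reporter = ts.TrajectoryReporter(
    ["system_0.h5", "system_1.h5"],
    state_frequency=1,
    trajectory_kwargs=dict(mode="a"),
)
print(reporter.mode)       # "a"
print(reporter.last_step)  # e.g. [5, 3] -- last logged step of each file, 0 if empty
reporter.close()

# The same information is exposed per file by TorchSimTrajectory.last_step.
with TorchSimTrajectory("system_0.h5", mode="r") as traj:
    print(traj.last_step)               # maximum recorded step across all arrays
    print(traj.get_steps("positions"))  # every step at which positions were written

Correspondingly, report() now accepts a per-system step, which is what lets optimize() keep an independent step counter for each system in the batch.
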
From eed1fc17c43dea048c613a24ae1ae93526a5e4cd Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 26 Nov 2025 10:05:46 +0000 Subject: [PATCH 06/29] format --- torch_sim/runners.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 7f60e2e7..fefbfe64 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -176,9 +176,16 @@ def integrate[T: SimState]( # noqa: C901 properties=["kinetic_energy", "potential_energy", "temperature"], ) # Auto-detect initial step from trajectory files for resuming integration - initial_step: torch.LongTensor = torch.full(size=(initial_state.n_systems,), fill_value=1, dtype=torch.long, device=initial_state.device) + initial_step: torch.LongTensor = torch.full( + size=(initial_state.n_systems,), + fill_value=1, + dtype=torch.long, + device=initial_state.device, + ) if trajectory_reporter is not None and trajectory_reporter.mode == "a": - last_logged_steps = torch.tensor(trajectory_reporter.last_step, dtype=torch.long, device=initial_state.device) + last_logged_steps = torch.tensor( + trajectory_reporter.last_step, dtype=torch.long, device=initial_state.device + ) initial_step = initial_step + last_logged_steps final_states: list[T] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None @@ -215,7 +222,9 @@ def integrate[T: SimState]( # noqa: C901 ) if trajectory_reporter: - trajectory_reporter.report(state, _initial_step + steps_so_far, model=model) + trajectory_reporter.report( + state, _initial_step + steps_so_far, model=model + ) # finish the trajectory reporter final_states.append(state) @@ -468,9 +477,13 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 ) # Auto-detect initial step from trajectory files for resuming optimizations - initial_step: torch.LongTensor = torch.full(size=(state.n_systems,), fill_value=1, dtype=torch.long, device=state.device) + initial_step: torch.LongTensor = torch.full( + size=(state.n_systems,), fill_value=1, dtype=torch.long, device=state.device + ) if trajectory_reporter is not None and trajectory_reporter.mode == "a": - last_logged_steps = torch.tensor(trajectory_reporter.last_step, dtype=torch.long, device=state.device) + last_logged_steps = torch.tensor( + trajectory_reporter.last_step, dtype=torch.long, device=state.device + ) initial_step = initial_step + last_logged_steps steps_so_far = 0 @@ -518,11 +531,15 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 state = step_fn(state=state, model=model, **optimizer_kwargs) if trajectory_reporter: - trajectory_reporter.report(state, initial_step + steps_so_far, model=model) + trajectory_reporter.report( + state, initial_step + steps_so_far, model=model + ) steps_so_far += 1 if steps_so_far >= max_steps: # TODO: max steps should be tracked for each structure in the batch - warnings.warn(f"Optimize has reached max steps: {steps_so_far}", stacklevel=2) + warnings.warn( + f"Optimize has reached max steps: {steps_so_far}", stacklevel=2 + ) break convergence_tensor = convergence_fn(state, last_energy) From 031fc567371460b0e9ec4d119e114ea69f2066cf Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 26 Nov 2025 10:57:05 +0000 Subject: [PATCH 07/29] integrate only for (n_steps - initial_step) steps when continuing --- tests/test_trajectory.py | 17 +++++++++++++++-- torch_sim/runners.py | 32 +++++++++++++++++--------------- torch_sim/trajectory.py | 9 +++++---- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git 
a/tests/test_trajectory.py b/tests/test_trajectory.py index 87b1c22a..1d800b02 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -918,6 +918,19 @@ def test_integrate_append_to_trajectory( trajectory_reporter_2 = ts.TrajectoryReporter( traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a") ) + # Nothing to do here, as we already have step 5 in the trajectory + # and n_steps is 5. + state_2 = ts.integrate( + system=int_state, + model=lj_model, + timestep=0.001, + temperature=300.0, + n_steps=5, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter_2, + ) + torch.testing.assert_close(state_2.positions, int_state.positions) + # run two (7 - 5) more steps of integration. _ = ts.integrate( system=int_state, model=lj_model, @@ -929,5 +942,5 @@ def test_integrate_append_to_trajectory( ) for traj in trajectory_reporter_2.trajectories: with TorchSimTrajectory(traj._file.filename, mode="r") as traj: - # Check that the trajectory file now has 12 frames - np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13)) + # Check that the trajectory file now has 7 frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 8)) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index fefbfe64..10b95825 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -120,7 +120,10 @@ def integrate[T: SimState]( # noqa: C901 model (ModelInterface): Neural network model module integrator (Integrator | tuple): Either a key from Integrator or a tuple of (init_func, step_func) functions. - n_steps (int): Number of integration steps + n_steps (int): Number of integration steps. If resuming from a trajectory, this is the + total number of steps to run, not the number of additional steps. That is, + if the trajectory has 10 steps and n_steps=20, the integrator will run + an additional 10 steps. temperature (float | ArrayLike): Temperature or array of temperatures for each step timestep (float): Integration time step @@ -176,17 +179,18 @@ def integrate[T: SimState]( # noqa: C901 properties=["kinetic_energy", "potential_energy", "temperature"], ) # Auto-detect initial step from trajectory files for resuming integration - initial_step: torch.LongTensor = torch.full( - size=(initial_state.n_systems,), - fill_value=1, - dtype=torch.long, - device=initial_state.device, - ) + initial_step: int = 1 if trajectory_reporter is not None and trajectory_reporter.mode == "a": - last_logged_steps = torch.tensor( - trajectory_reporter.last_step, dtype=torch.long, device=initial_state.device + last_logged_steps = trajectory_reporter.last_step + assert len(set(last_logged_steps)) == 1, ( + "All trajectory files must have the same last step for resuming integration." ) - initial_step = initial_step + last_logged_steps + initial_step = initial_step + last_logged_steps[0] + if initial_step >= n_steps: + warnings.warn( + f"Initial step {initial_step} ≥ n_steps {n_steps}. 
Nothing will be done.", + ) + return initial_state # type: ignore[return-value] final_states: list[T] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None @@ -203,16 +207,14 @@ def integrate[T: SimState]( # noqa: C901 state = init_func(state=state, model=model, kT=kTs[0], dt=dt, **integrator_kwargs) # set up trajectory reporters - _initial_step = initial_step if autobatcher and trajectory_reporter is not None and og_filenames is not None: # we must remake the trajectory reporter for each system trajectory_reporter.load_new_trajectories( filenames=[og_filenames[i] for i in system_indices] ) - _initial_step = initial_step[system_indices] # run the simulation - for steps_so_far in range(n_steps): + for steps_so_far in range(n_steps - initial_step + 1): state = step_func( state=state, model=model, @@ -223,7 +225,7 @@ def integrate[T: SimState]( # noqa: C901 if trajectory_reporter: trajectory_reporter.report( - state, _initial_step + steps_so_far, model=model + state, initial_step + steps_so_far, model=model ) # finish the trajectory reporter @@ -532,7 +534,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 if trajectory_reporter: trajectory_reporter.report( - state, initial_step + steps_so_far, model=model + state, (initial_step + steps_so_far).tolist(), model=model ) steps_so_far += 1 if steps_so_far >= max_steps: diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 4cf39aaf..05ca447b 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -212,7 +212,7 @@ def _add_model_arg_to_prop_calculators(self) -> None: self.prop_calculators[frequency][name] = new_fn def report( - self, state: SimState, step: torch.LongTensor, model: ModelInterface | None = None + self, state: SimState, step: int | list[int], model: ModelInterface | None = None ) -> list[dict[str, torch.Tensor]]: """Report a state and step to the trajectory files. @@ -223,8 +223,9 @@ def report( Args: state (SimState): Current system state with n_systems equal to len(filenames) - step (torch.LongTensor): Current simulation step per system, setting step to 0 will write - the state and all properties. + step (int | list[int]): Current simulation step per system, setting step to 0 will write + the state and all properties. If a list is provided, it must have length equal to n_systems. + Otherwise, a single integer step is broadcast to all systems. model (ModelInterface, optional): Model used for simulation. Defaults to None. Must be provided if any prop_calculators are provided. 
@@ -254,7 +255,7 @@ def report( all_props: list[dict[str, torch.Tensor]] = [] # Process each system separately for idx, substate in enumerate(split_states): - _step = step[idx].item() + _step = step[idx] if isinstance(step, list) else step # Write state to trajectory if it's time if ( self.state_frequency From 6fd5cc94f799a9517ca7c4a7b249a410412ca9fb Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 26 Nov 2025 16:44:43 +0000 Subject: [PATCH 08/29] change `load_new_trajectories` behavior --- torch_sim/runners.py | 10 +++---- torch_sim/trajectory.py | 59 +++++++++++++++++++++++++++-------------- 2 files changed, 42 insertions(+), 27 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index bc822472..3db93849 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -212,7 +212,7 @@ def integrate[T: SimState]( # noqa: C901 # set up trajectory reporters if autobatcher and trajectory_reporter is not None and og_filenames is not None: # we must remake the trajectory reporter for each system - trajectory_reporter.load_new_trajectories( + trajectory_reporter.reopen_trajectories( filenames=[og_filenames[i] for i in system_indices] ) @@ -519,13 +519,9 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 and og_filenames is not None and (steps_so_far == 0 or len(converged_states) > 0) ): - mode_before = trajectory_reporter.trajectory_kwargs["mode"] - # temporarily set to "append" mode to avoid overwriting existing files - trajectory_reporter.trajectory_kwargs["mode"] = "a" - trajectory_reporter.load_new_trajectories( + trajectory_reporter.reopen_trajectories( filenames=[og_filenames[i] for i in autobatcher.current_idx] ) - trajectory_reporter.trajectory_kwargs["mode"] = mode_before # Remove initial_step entries for converged states initial_step = initial_step[autobatcher.current_idx] @@ -645,7 +641,7 @@ class StaticState(SimState): # set up trajectory reporters if autobatcher and trajectory_reporter and og_filenames is not None: # we must remake the trajectory reporter for each system - trajectory_reporter.load_new_trajectories( + trajectory_reporter.reopen_trajectories( filenames=[og_filenames[idx] for idx in system_indices] ) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 05ca447b..92a43d01 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -95,7 +95,6 @@ class TrajectoryReporter: state_kwargs: dict[str, Any] metadata: dict[str, str] | None trajectories: list["TorchSimTrajectory"] - filenames: list[str | pathlib.Path] | None def __init__( self, @@ -140,19 +139,35 @@ def __init__( self.metadata = metadata self.trajectories = [] - if filenames is None: - self.filenames = None - else: - self.load_new_trajectories(filenames) + if filenames is not None: + # Initialize trajectories for the first time. Unlike in reopen_trajectories, + # if the user specified "w" mode, we respect that here and start fresh. + self.trajectories = [ + TorchSimTrajectory( + filename=filename, metadata=self.metadata, **self.trajectory_kwargs + ) + for filename in filenames + ] self._add_model_arg_to_prop_calculators() - def load_new_trajectories( + @property + def filenames(self) -> list[str | None]: + """Get the list of trajectory filenames. + + Returns: + list[str | pathlib.Path] | None: List of trajectory file paths, + or None if no trajectories are loaded. 
+ """ + if not self.trajectories: + return None + return [traj._file.filename for traj in self.trajectories] + + def reopen_trajectories( self, filenames: str | pathlib.Path | Sequence[str | pathlib.Path] ) -> None: - """Load new trajectories into the reporter. - - Closes any existing trajectory files and initializes new ones. + """ + Closes any existing trajectory files and reopens new ones given by filenames. Args: filenames (str | pathlib.Path | list[str | pathlib.Path]): Path(s) to save @@ -166,19 +181,23 @@ def load_new_trajectories( filenames = ( [filenames] if isinstance(filenames, (str, pathlib.Path)) else list(filenames) ) - self.filenames = [pathlib.Path(filename) for filename in filenames] - if len(set(self.filenames)) != len(self.filenames): + filenames = [pathlib.Path(filename) for filename in filenames] + if len(set(filenames)) != len(filenames): raise ValueError("All filenames must be unique.") - - self.trajectories = [] - for filename in self.filenames: - self.trajectories.append( - TorchSimTrajectory( - filename=filename, - metadata=self.metadata, - **self.trajectory_kwargs, - ) + # Avoid wiping existing trajectory files when reopening them, hence + # we set to "a" mode temporarily (read mode is unaffected). + _mode = self.trajectory_kwargs.get("mode", "w") + self.trajectory_kwargs["mode"] = "a" if _mode in ["a", "w"] else "r" + self.trajectories = [ + TorchSimTrajectory( + filename=filename, + metadata=self.metadata, + **self.trajectory_kwargs, ) + for filename in filenames + ] + # Restore original mode + self.trajectory_kwargs["mode"] = _mode @property def array_registry(self) -> dict[str, tuple[tuple[int, ...], np.dtype]]: From 4e820ee81598be000debf750572df9edc002a4b8 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Thu, 27 Nov 2025 08:32:52 +0000 Subject: [PATCH 09/29] back to `step` --- torch_sim/runners.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 3db93849..d4869685 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -217,18 +217,18 @@ def integrate[T: SimState]( # noqa: C901 ) # run the simulation - for steps_so_far in range(n_steps - initial_step + 1): + for step in range(initial_step, n_steps + 1): state = step_func( state=state, model=model, dt=dt, - kT=kTs[steps_so_far], + kT=kTs[step - initial_step], **integrator_kwargs, ) if trajectory_reporter: trajectory_reporter.report( - state, initial_step + steps_so_far, model=model + state, step, model=model ) # finish the trajectory reporter From 8b8c449f0f99b6fdbb562625900af4c965970ca6 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Thu, 27 Nov 2025 08:33:26 +0000 Subject: [PATCH 10/29] format --- torch_sim/runners.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index d4869685..64202db0 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -227,9 +227,7 @@ def integrate[T: SimState]( # noqa: C901 ) if trajectory_reporter: - trajectory_reporter.report( - state, step, model=model - ) + trajectory_reporter.report(state, step, model=model) # finish the trajectory reporter final_states.append(state) From 0a230bfac7811b146c1b60cf78b29967a77deb0b Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Thu, 27 Nov 2025 10:31:53 +0000 Subject: [PATCH 11/29] truncate trajectories --- tests/test_trajectory.py | 116 +++++++++++++++++++++++++++++++++++++++ torch_sim/runners.py | 19 +++++-- torch_sim/trajectory.py | 34 ++++++++++++ 3 files changed, 163 
insertions(+), 6 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 1d800b02..e31a4944 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -944,3 +944,119 @@ def test_integrate_append_to_trajectory( with TorchSimTrajectory(traj._file.filename, mode="r") as traj: # Check that the trajectory file now has 7 frames np.testing.assert_allclose(traj.get_steps("positions"), range(1, 8)) + + +def test_truncate_trajectory( + si_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """ + Test trajectory.truncate_to_step(). + """ + + # Create a temporary trajectory file + with tempfile.TemporaryDirectory() as temp_dir: + traj_files = [f"{temp_dir}/truncate_trajectory_{idx}.h5" for idx in range(2)] + + # Initialize model and state + trajectory_reporter = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + prop_calculators={1: {"velocities": lambda state: state.velocities}}, + ) + + # First integration run for 5 steps. + _ = ts.integrate( + system=si_double_sim_state, + model=lj_model, + timestep=0.001, + n_steps=5, + temperature=300.0, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter, + ) + + # Manually remove last two frames from second trajectory to create unevenness + with TorchSimTrajectory(traj_files[1], mode="a") as traj: + traj.truncate_to_step(3) + # Verify that it has 3 frames now. + for array_name in traj.array_registry.keys(): + target_length = 3 + target_steps = [1, 2, 3] + # Special cases: global arrays + if array_name in ["atomic_numbers", "masses"]: + target_length = 1 + target_steps = [0] + if array_name == "pbc": + target_length = 3 + target_steps = [0] + assert len(traj.get_array(array_name)) == target_length + np.testing.assert_allclose(traj.get_steps(array_name), target_steps) + with pytest.raises( + ValueError, + match="Cannot truncate to a step greater than the last step. self.last_step=3 < step=10", + ): + traj.truncate_to_step(10) + with pytest.raises( + ValueError, match="Step must be larger than 0. Got step=0" + ): + traj.truncate_to_step(0) + + +def test_integrate_uneven_trajectory_append( + si_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """ + Test appending to an existing trajectory with uneven frames when running ts.integrate. + Expected behavior: ts.integrate should first truncate all trajectories to the shortest length, + and then append new frames to all trajectories. + """ + + # Create a temporary trajectory file + with tempfile.TemporaryDirectory() as temp_dir: + traj_files = [ + f"{temp_dir}/uneven_integrate_trajectory_{idx}.h5" for idx in range(2) + ] + + # Initialize model and state + trajectory_reporter = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + prop_calculators={1: {"velocities": lambda state: state.velocities}}, + ) + + # First integration run for 5 steps. + _ = ts.integrate( + system=si_double_sim_state, + model=lj_model, + timestep=0.001, + n_steps=5, + temperature=300.0, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter, + ) + + # Manually remove last two frames from second trajectory to create unevenness + with TorchSimTrajectory(traj_files[1], mode="a") as traj: + traj.truncate_to_step(3) + + trajectory_reporter_2 = ts.TrajectoryReporter( + traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a") + ) + # Continue integration for one step, which means we first + # truncate the first trajectory to 3 steps to match the second one, + # and then append 4th step to both. 
+ _ = ts.integrate( + system=si_double_sim_state, + model=lj_model, + timestep=0.001, + temperature=300.0, + n_steps=4, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter_2, + ) + + for traj in trajectory_reporter_2.trajectories: + # both trajectories should have 4 frames now. + with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + expected_steps = [1, 2, 3, 4] + np.testing.assert_allclose(traj.get_steps("positions"), expected_steps) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 64202db0..d1026c28 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -185,15 +185,22 @@ def integrate[T: SimState]( # noqa: C901 initial_step: int = 1 if trajectory_reporter is not None and trajectory_reporter.mode == "a": last_logged_steps = trajectory_reporter.last_step - assert len(set(last_logged_steps)) == 1, ( - "All trajectory files must have the same last step for resuming integration." - ) - initial_step = initial_step + last_logged_steps[0] - if initial_step >= n_steps: + last_logged_step = min(last_logged_steps) + initial_step = initial_step + last_logged_step + if initial_step > n_steps: warnings.warn( - f"Initial step {initial_step} ≥ n_steps {n_steps}. Nothing will be done.", + f"Initial step {initial_step} > n_steps {n_steps}. Nothing will be done.", ) return initial_state # type: ignore[return-value] + if len(set(last_logged_steps)) != 1: + warnings.warn( + "Trajectory files have different last steps. " + "Using the minimum last step for resuming integration." + "This means that some trajectories may be truncated.", + ) + for traj in trajectory_reporter.trajectories: + traj.truncate_to_step(last_logged_step) + final_states: list[T] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 92a43d01..e2cd3b87 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -1096,3 +1096,37 @@ def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReade traj.close() return Trajectory(filename, mode="r") # Reopen in read mode + + def truncate_to_step(self, step: int) -> None: + """Truncate the trajectory to a specified step. + + Removes frames from the end of the trajectory to reduce its length such that the + last logged step is `step`. + + Args: + step (int): Desired last step of the trajectory after truncation + """ + if self.last_step < step: + raise ValueError( + f"Cannot truncate to a step greater than the last step." + f" {self.last_step=} < {step=}" + ) + elif self.last_step == step: + return # No truncation needed + if step <= 0: + raise ValueError(f"Step must be larger than 0. Got {step=}") + for name in self.array_registry: + steps_node = self._file.get_node("/steps/", name=name) + steps_data = steps_node.read() + if set(steps_data) == set([0]): + continue # skip global arrays + # Find the index where the step is less than or equal to the desired step + # We know that it must be at least one index because of the earlier check. 
+ indices = np.where(steps_data <= step)[0] + length = indices[-1] + 1 # +1 because we want to include this index + + data_node = self._file.get_node("/data/", name=name) + data_node.truncate(length) + steps_node.truncate(length) + + self.flush() From e138a3390a6307f6dc996c85164ebb3abf97fd9a Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Thu, 27 Nov 2025 13:30:15 +0000 Subject: [PATCH 12/29] fix tests --- tests/test_trajectory.py | 4 ++-- torch_sim/runners.py | 15 ++++++++------- torch_sim/trajectory.py | 7 ++++++- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index e31a4944..01501e44 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -880,8 +880,8 @@ def test_optimize_append_to_trajectory( ) for traj in trajectory_reporter_2.trajectories: with TorchSimTrajectory(traj._file.filename, mode="r") as traj: - # Check that the trajectory file now has 12 frames - np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13)) + # Check that the trajectory file now has 7 frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 8)) def test_integrate_append_to_trajectory( diff --git a/torch_sim/runners.py b/torch_sim/runners.py index d1026c28..7758be82 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -487,14 +487,14 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 ) # Auto-detect initial step from trajectory files for resuming optimizations - initial_step: torch.LongTensor = torch.full( + og_initial_step: torch.LongTensor = torch.full( size=(state.n_systems,), fill_value=1, dtype=torch.long, device=state.device ) if trajectory_reporter is not None and trajectory_reporter.mode == "a": last_logged_steps = torch.tensor( trajectory_reporter.last_step, dtype=torch.long, device=state.device ) - initial_step = initial_step + last_logged_steps + og_initial_step = og_initial_step + last_logged_steps steps_so_far = 0 last_energy = None @@ -517,6 +517,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 break state, converged_states = result all_converged_states.extend(converged_states) + initial_step = og_initial_step[autobatcher.current_idx] # need to update the trajectory reporter if any states have converged if ( @@ -527,8 +528,6 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 trajectory_reporter.reopen_trajectories( filenames=[og_filenames[i] for i in autobatcher.current_idx] ) - # Remove initial_step entries for converged states - initial_step = initial_step[autobatcher.current_idx] for _step in range(steps_between_swaps): if hasattr(state, "energy"): @@ -541,14 +540,16 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 state, (initial_step + steps_so_far).tolist(), model=model ) steps_so_far += 1 - if steps_so_far >= max_steps: - # TODO: max steps should be tracked for each structure in the batch + exceeded_max_steps = (initial_step + steps_so_far) > max_steps + if exceeded_max_steps.all(): warnings.warn( - f"Optimize has reached max steps: {steps_so_far}", stacklevel=2 + f"All systems have reached the maximum number of steps: {max_steps}." 
) break convergence_tensor = convergence_fn(state, last_energy) + # Mark states that exceeded max steps as converged to remove them from batch + convergence_tensor = convergence_tensor | exceeded_max_steps if tqdm_pbar: # assume convergence_tensor shape is correct tqdm_pbar.update(torch.count_nonzero(convergence_tensor).item()) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index e2cd3b87..d5b78781 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -140,6 +140,11 @@ def __init__( self.trajectories = [] if filenames is not None: + filenames = ( + [filenames] + if isinstance(filenames, (str, pathlib.Path)) + else list(filenames) + ) # Initialize trajectories for the first time. Unlike in reopen_trajectories, # if the user specified "w" mode, we respect that here and start fresh. self.trajectories = [ @@ -301,7 +306,7 @@ def report( if props: all_state_props.update(props) if self.filenames is not None: - self.trajectories[idx].write_arrays(props, step) + self.trajectories[idx].write_arrays(props, _step) all_props.append(all_state_props) return all_props From 82059d50d47163fbfa927a3b2aef7459358ef64a Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Thu, 27 Nov 2025 13:33:14 +0000 Subject: [PATCH 13/29] fix style --- torch_sim/trajectory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index d5b78781..bb13e456 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -41,6 +41,7 @@ from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState + if TYPE_CHECKING: from ase import Atoms from ase.io.trajectory import TrajectoryReader @@ -171,8 +172,7 @@ def filenames(self) -> list[str | None]: def reopen_trajectories( self, filenames: str | pathlib.Path | Sequence[str | pathlib.Path] ) -> None: - """ - Closes any existing trajectory files and reopens new ones given by filenames. + """Closes any existing trajectory files and reopens new ones given by filenames. Args: filenames (str | pathlib.Path | list[str | pathlib.Path]): Path(s) to save @@ -1116,7 +1116,7 @@ def truncate_to_step(self, step: int) -> None: f"Cannot truncate to a step greater than the last step." f" {self.last_step=} < {step=}" ) - elif self.last_step == step: + if self.last_step == step: return # No truncation needed if step <= 0: raise ValueError(f"Step must be larger than 0. 
Got {step=}") From 2e4bb06349a2672a130b86f64c4f384087774bd8 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Thu, 27 Nov 2025 13:40:10 +0000 Subject: [PATCH 14/29] style --- tests/test_trajectory.py | 23 +++++++++++++---------- torch_sim/runners.py | 13 ++++++++----- torch_sim/trajectory.py | 18 ++++++++++++++---- 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 01501e44..1facdde1 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -863,7 +863,7 @@ def test_optimize_append_to_trajectory( ) for traj in trajectory_reporter.trajectories: - with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + with TorchSimTrajectory(traj.filename, mode="r") as traj: # Check that the trajectory file has 5 frames np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6)) @@ -879,7 +879,7 @@ def test_optimize_append_to_trajectory( steps_between_swaps=100, ) for traj in trajectory_reporter_2.trajectories: - with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + with TorchSimTrajectory(traj.filename, mode="r") as traj: # Check that the trajectory file now has 7 frames np.testing.assert_allclose(traj.get_steps("positions"), range(1, 8)) @@ -911,7 +911,7 @@ def test_integrate_append_to_trajectory( ) for traj in trajectory_reporter.trajectories: - with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + with TorchSimTrajectory(traj.filename, mode="r") as traj: # Check that the trajectory file has 5 frames np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6)) @@ -941,7 +941,7 @@ def test_integrate_append_to_trajectory( trajectory_reporter=trajectory_reporter_2, ) for traj in trajectory_reporter_2.trajectories: - with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + with TorchSimTrajectory(traj.filename, mode="r") as traj: # Check that the trajectory file now has 7 frames np.testing.assert_allclose(traj.get_steps("positions"), range(1, 8)) @@ -979,7 +979,7 @@ def test_truncate_trajectory( with TorchSimTrajectory(traj_files[1], mode="a") as traj: traj.truncate_to_step(3) # Verify that it has 3 frames now. - for array_name in traj.array_registry.keys(): + for array_name in traj.array_registry: target_length = 3 target_steps = [1, 2, 3] # Special cases: global arrays @@ -993,7 +993,10 @@ def test_truncate_trajectory( np.testing.assert_allclose(traj.get_steps(array_name), target_steps) with pytest.raises( ValueError, - match="Cannot truncate to a step greater than the last step. self.last_step=3 < step=10", + match=( + "Cannot truncate to a step greater than the last step. " + "self.last_step=3 < step=10" + ), ): traj.truncate_to_step(10) with pytest.raises( @@ -1006,9 +1009,9 @@ def test_integrate_uneven_trajectory_append( si_double_sim_state: SimState, lj_model: LennardJonesModel ) -> None: """ - Test appending to an existing trajectory with uneven frames when running ts.integrate. - Expected behavior: ts.integrate should first truncate all trajectories to the shortest length, - and then append new frames to all trajectories. + Test appending to an existing trajectory with uneven frames running ts.integrate. + Expected behavior: ts.integrate should first truncate all trajectories to the shortest + length, and then append new frames to all trajectories. """ # Create a temporary trajectory file @@ -1057,6 +1060,6 @@ def test_integrate_uneven_trajectory_append( for traj in trajectory_reporter_2.trajectories: # both trajectories should have 4 frames now. 
- with TorchSimTrajectory(traj._file.filename, mode="r") as traj: + with TorchSimTrajectory(traj.filename, mode="r") as traj: expected_steps = [1, 2, 3, 4] np.testing.assert_allclose(traj.get_steps("positions"), expected_steps) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 7758be82..016ef344 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -121,10 +121,10 @@ def integrate[T: SimState]( # noqa: C901 model (ModelInterface): Neural network model module integrator (Integrator | tuple): Either a key from Integrator or a tuple of (init_func, step_func) functions. - n_steps (int): Number of integration steps. If resuming from a trajectory, this is the - total number of steps to run, not the number of additional steps. That is, - if the trajectory has 10 steps and n_steps=20, the integrator will run - an additional 10 steps. + n_steps (int): Number of integration steps. If resuming from a trajectory, this + is the total number of steps to run, not the number of additional steps. + That is, if the trajectory has 10 steps and n_steps=20, the integrator will + run an additional 10 steps. temperature (float | ArrayLike): Temperature or array of temperatures for each step timestep (float): Integration time step @@ -190,6 +190,7 @@ def integrate[T: SimState]( # noqa: C901 if initial_step > n_steps: warnings.warn( f"Initial step {initial_step} > n_steps {n_steps}. Nothing will be done.", + stacklevel=2, ) return initial_state # type: ignore[return-value] if len(set(last_logged_steps)) != 1: @@ -197,6 +198,7 @@ def integrate[T: SimState]( # noqa: C901 "Trajectory files have different last steps. " "Using the minimum last step for resuming integration." "This means that some trajectories may be truncated.", + stacklevel=2, ) for traj in trajectory_reporter.trajectories: traj.truncate_to_step(last_logged_step) @@ -543,7 +545,8 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 exceeded_max_steps = (initial_step + steps_so_far) > max_steps if exceeded_max_steps.all(): warnings.warn( - f"All systems have reached the maximum number of steps: {max_steps}." + f"All systems have reached the maximum number of steps: {max_steps}.", + stacklevel=2, ) break diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index bb13e456..0d81bb79 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -247,9 +247,10 @@ def report( Args: state (SimState): Current system state with n_systems equal to len(filenames) - step (int | list[int]): Current simulation step per system, setting step to 0 will write - the state and all properties. If a list is provided, it must have length equal to n_systems. - Otherwise, a single integer step is broadcast to all systems. + step (int | list[int]): Current simulation step per system, setting step + to 0 will write the state and all properties. If a list is provided, it + must have length equal to n_systems. Otherwise, a single integer step + is broadcast to all systems. model (ModelInterface, optional): Model used for simulation. Defaults to None. Must be provided if any prop_calculators are provided. @@ -661,6 +662,15 @@ def _validate_array(self, name: str, data: np.ndarray, steps: list[int]) -> None f"step {last_step} for array {name}" ) + @property + def filename(self) -> str: + """Get the filename of the trajectory file. 
+ + Returns: + str: Path to the HDF5 file + """ + return self._file.filename + def _serialize_array(self, name: str, data: np.ndarray, steps: list[int]) -> None: """Add additional contents to an array already in the registry. @@ -1123,7 +1133,7 @@ def truncate_to_step(self, step: int) -> None: for name in self.array_registry: steps_node = self._file.get_node("/steps/", name=name) steps_data = steps_node.read() - if set(steps_data) == set([0]): + if set(steps_data) == {0}: continue # skip global arrays # Find the index where the step is less than or equal to the desired step # We know that it must be at least one index because of the earlier check. From 65e61d0e61910af0a3b5a15cf5082d3462a8293e Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Thu, 27 Nov 2025 13:45:25 +0000 Subject: [PATCH 15/29] fix type hint --- torch_sim/trajectory.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 0d81bb79..f81b0aa6 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -41,7 +41,6 @@ from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState - if TYPE_CHECKING: from ase import Atoms from ase.io.trajectory import TrajectoryReader @@ -158,16 +157,16 @@ def __init__( self._add_model_arg_to_prop_calculators() @property - def filenames(self) -> list[str | None]: + def filenames(self) -> list[str] | None: """Get the list of trajectory filenames. Returns: - list[str | pathlib.Path] | None: List of trajectory file paths, + list[str] | None: List of trajectory file paths, or None if no trajectories are loaded. """ if not self.trajectories: return None - return [traj._file.filename for traj in self.trajectories] + return [traj.filename for traj in self.trajectories] def reopen_trajectories( self, filenames: str | pathlib.Path | Sequence[str | pathlib.Path] From c1065e746f5c1b9a7233e37e72181e153c5e2340 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Thu, 27 Nov 2025 13:45:48 +0000 Subject: [PATCH 16/29] style --- torch_sim/trajectory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index f81b0aa6..19aa4502 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -41,6 +41,7 @@ from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState + if TYPE_CHECKING: from ase import Atoms from ase.io.trajectory import TrajectoryReader From 38619afd4019ca9c275764efcc88aada645fc7db Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Fri, 28 Nov 2025 15:36:45 +0000 Subject: [PATCH 17/29] fix kT indexing in integrate, step counting in optimize --- torch_sim/runners.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 016ef344..0c0190e3 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -231,7 +231,7 @@ def integrate[T: SimState]( # noqa: C901 state=state, model=model, dt=dt, - kT=kTs[step - initial_step], + kT=kTs[step - 1], **integrator_kwargs, ) @@ -498,6 +498,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 ) og_initial_step = og_initial_step + last_logged_steps steps_so_far = 0 + step = og_initial_step.clone() last_energy = None all_converged_states: list[T] = [] @@ -519,7 +520,6 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 break state, converged_states = result all_converged_states.extend(converged_states) - initial_step = og_initial_step[autobatcher.current_idx] # need to update the 
trajectory reporter if any states have converged if ( @@ -534,15 +534,14 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 for _step in range(steps_between_swaps): if hasattr(state, "energy"): last_energy = state.energy - state = step_fn(state=state, model=model, **optimizer_kwargs) if trajectory_reporter: trajectory_reporter.report( - state, (initial_step + steps_so_far).tolist(), model=model + state, step[autobatcher.current_idx].tolist(), model=model ) - steps_so_far += 1 - exceeded_max_steps = (initial_step + steps_so_far) > max_steps + step[autobatcher.current_idx] += 1 + exceeded_max_steps = step > max_steps if exceeded_max_steps.all(): warnings.warn( f"All systems have reached the maximum number of steps: {max_steps}.", @@ -552,7 +551,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 convergence_tensor = convergence_fn(state, last_energy) # Mark states that exceeded max steps as converged to remove them from batch - convergence_tensor = convergence_tensor | exceeded_max_steps + convergence_tensor = convergence_tensor | exceeded_max_steps[autobatcher.current_idx] if tqdm_pbar: # assume convergence_tensor shape is correct tqdm_pbar.update(torch.count_nonzero(convergence_tensor).item()) From dde853cf6a08264e34511763d4ecb270bd392681 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Fri, 28 Nov 2025 15:37:28 +0000 Subject: [PATCH 18/29] style --- torch_sim/runners.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 0c0190e3..37f24226 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -551,7 +551,9 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 convergence_tensor = convergence_fn(state, last_energy) # Mark states that exceeded max steps as converged to remove them from batch - convergence_tensor = convergence_tensor | exceeded_max_steps[autobatcher.current_idx] + convergence_tensor = ( + convergence_tensor | exceeded_max_steps[autobatcher.current_idx] + ) if tqdm_pbar: # assume convergence_tensor shape is correct tqdm_pbar.update(torch.count_nonzero(convergence_tensor).item()) From 787a967693f9fa2ca7e7557a5ba596fb4f22df51 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 10 Dec 2025 14:23:47 +0000 Subject: [PATCH 19/29] rename variable --- torch_sim/trajectory.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 19aa4502..7b4fe136 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -41,7 +41,6 @@ from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState - if TYPE_CHECKING: from ase import Atoms from ase.io.trajectory import TrajectoryReader @@ -280,19 +279,21 @@ def report( all_props: list[dict[str, torch.Tensor]] = [] # Process each system separately for idx, substate in enumerate(split_states): - _step = step[idx] if isinstance(step, list) else step + sys_step = step[idx] if isinstance(step, list) else step # Write state to trajectory if it's time if ( self.state_frequency - and _step % self.state_frequency == 0 + and sys_step % self.state_frequency == 0 and self.filenames is not None ): - self.trajectories[idx].write_state(substate, _step, **self.state_kwargs) + self.trajectories[idx].write_state( + substate, sys_step, **self.state_kwargs + ) all_state_props = {} # Process property calculators for this system for report_frequency, calculators in self.prop_calculators.items(): - if _step % report_frequency != 0 or report_frequency == 0: + if 
sys_step % report_frequency != 0 or report_frequency == 0: continue # Calculate properties for this substate @@ -307,7 +308,7 @@ def report( if props: all_state_props.update(props) if self.filenames is not None: - self.trajectories[idx].write_arrays(props, _step) + self.trajectories[idx].write_arrays(props, sys_step) all_props.append(all_state_props) return all_props From 9c637f52bffe041992fb9c4da129f57d6a7e08d9 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 10 Dec 2025 14:26:28 +0000 Subject: [PATCH 20/29] return positions last step --- torch_sim/trajectory.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 7b4fe136..faae2891 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -747,23 +747,15 @@ def get_steps( def last_step(self) -> int: """Get the last step number from the trajectory. - Retrieves the maximum step number across all arrays in the trajectory. - If the trajectory is empty or has no arrays, returns 0. + Retrieves the last time step recorded in the trajectory based on the "positions" array. Returns: - int: The last (maximum) step number in the trajectory, or 0 if empty + int: The last recorded step number, or 0 if no data exists """ if not self.array_registry: return 0 - max_step = 0 - for name in self.array_registry: - steps_node = self.get_steps(name) - if len(steps_node) > 0: - last_step = int(steps_node[-1]) - max_step = max(max_step, last_step) - - return max_step + return self.get_steps("positions")[-1] def __str__(self) -> str: """Get a string representation of the trajectory. From 0031296c1aa3268b1be2b5af84646a549ecb8357 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 10 Dec 2025 14:31:34 +0000 Subject: [PATCH 21/29] truncate to positions last step --- torch_sim/trajectory.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index faae2891..a47ca6ad 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -30,6 +30,7 @@ import copy import inspect import pathlib +import warnings from collections.abc import Callable, Mapping, Sequence from functools import partial from typing import TYPE_CHECKING, Any, Literal, Self @@ -463,6 +464,16 @@ def __init__( self.type_map = self._initialize_type_map( coerce_to_float32=coerce_to_float32, coerce_to_int32=coerce_to_int32 ) + if mode == "a": + inconsistent_step = any( + self.get_steps(name)[-1] > self.last_step for name in self.array_registry + ) + if inconsistent_step: + warnings.warn( + "Inconsistent last steps detected in trajectory arrays. " + "Truncating all arrays to the `positions` array's last step." + ) + self.truncate_to_step(self.last_step) def _initialize_header(self, metadata: dict[str, str] | None = None) -> None: """Initialize the HDF5 file header with metadata. 
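For illustration (not part of the patch series), a minimal sketch of inspecting and truncating a single trajectory file with the `TorchSimTrajectory` API as it stands after the patches above; the file name `md_run.h5` is a placeholder.

```python
from torch_sim.trajectory import TorchSimTrajectory

# Read-only pass: last_step is now derived from the "positions" array, so it
# matches the last frame a resumed run would append after.
with TorchSimTrajectory("md_run.h5", mode="r") as traj:
    print("last logged step:", traj.last_step)
    print("logged steps:", traj.get_steps("positions"))

# Truncation modifies the file, so it needs write access; opening in "a" mode
# keeps the existing frames.
with TorchSimTrajectory("md_run.h5", mode="a") as traj:
    traj.truncate_to_step(3)  # drop every frame recorded after step 3
```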
From 1874dc538879be8848abb36e7d25138a0268f0e6 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 10 Dec 2025 14:41:10 +0000 Subject: [PATCH 22/29] extract methods --- torch_sim/runners.py | 95 ++++++++++++++++++++++++++++++-------------- 1 file changed, 66 insertions(+), 29 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 37f24226..d14dd58e 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -100,6 +100,68 @@ def _configure_batches_iterator( return batches +def _determine_initial_step_for_integrate( + trajectory_reporter: TrajectoryReporter | None, + n_steps: int, +) -> int: + """Determine the initial step for resuming integration from trajectory files. + + Args: + trajectory_reporter (TrajectoryReporter | None): The trajectory reporter to + check for resume information + n_steps (int): The total number of steps to run + + Returns: + int: The initial step to start from (1 if not resuming, otherwise last_step + 1) + """ + initial_step: int = 1 + if trajectory_reporter is not None and trajectory_reporter.mode == "a": + last_logged_steps = trajectory_reporter.last_step + last_logged_step = min(last_logged_steps) + initial_step = initial_step + last_logged_step + if initial_step > n_steps: + warnings.warn( + f"Initial step {initial_step} > n_steps {n_steps}. Nothing will be done.", + stacklevel=2, + ) + if len(set(last_logged_steps)) != 1: + warnings.warn( + "Trajectory files have different last steps. " + "Using the minimum last step for resuming integration." + "This means that some trajectories may be truncated.", + stacklevel=2, + ) + for traj in trajectory_reporter.trajectories: + traj.truncate_to_step(last_logged_step) + return initial_step + + +def _determine_initial_step_for_optimize( + trajectory_reporter: TrajectoryReporter | None, + state: SimState, +) -> torch.LongTensor: + """Determine the initial steps for resuming optimization from trajectory files. + + Args: + trajectory_reporter (TrajectoryReporter | None): The trajectory reporter to + check for resume information + state (SimState): The state being optimized + + Returns: + torch.LongTensor: Tensor of initial steps for each system (1 if not resuming, + otherwise last_step + 1 for each system) + """ + initial_step: torch.LongTensor = torch.full( + size=(state.n_systems,), fill_value=1, dtype=torch.long, device=state.device + ) + if trajectory_reporter is not None and trajectory_reporter.mode == "a": + last_logged_steps = torch.tensor( + trajectory_reporter.last_step, dtype=torch.long, device=state.device + ) + initial_step = initial_step + last_logged_steps + return initial_step + + def integrate[T: SimState]( # noqa: C901 system: StateLike, model: ModelInterface, @@ -182,26 +244,9 @@ def integrate[T: SimState]( # noqa: C901 properties=["kinetic_energy", "potential_energy", "temperature"], ) # Auto-detect initial step from trajectory files for resuming integration - initial_step: int = 1 - if trajectory_reporter is not None and trajectory_reporter.mode == "a": - last_logged_steps = trajectory_reporter.last_step - last_logged_step = min(last_logged_steps) - initial_step = initial_step + last_logged_step - if initial_step > n_steps: - warnings.warn( - f"Initial step {initial_step} > n_steps {n_steps}. Nothing will be done.", - stacklevel=2, - ) - return initial_state # type: ignore[return-value] - if len(set(last_logged_steps)) != 1: - warnings.warn( - "Trajectory files have different last steps. " - "Using the minimum last step for resuming integration." 
- "This means that some trajectories may be truncated.", - stacklevel=2, - ) - for traj in trajectory_reporter.trajectories: - traj.truncate_to_step(last_logged_step) + initial_step = _determine_initial_step_for_integrate(trajectory_reporter, n_steps) + if initial_step > n_steps: + return initial_state # type: ignore[return-value] final_states: list[T] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None @@ -489,16 +534,8 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 ) # Auto-detect initial step from trajectory files for resuming optimizations - og_initial_step: torch.LongTensor = torch.full( - size=(state.n_systems,), fill_value=1, dtype=torch.long, device=state.device - ) - if trajectory_reporter is not None and trajectory_reporter.mode == "a": - last_logged_steps = torch.tensor( - trajectory_reporter.last_step, dtype=torch.long, device=state.device - ) - og_initial_step = og_initial_step + last_logged_steps + step = _determine_initial_step_for_optimize(trajectory_reporter, state) steps_so_far = 0 - step = og_initial_step.clone() last_energy = None all_converged_states: list[T] = [] From 51ca30d0c8ffdce7bb198df42b9405e25ff9dbbd Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 10 Dec 2025 16:09:46 +0000 Subject: [PATCH 23/29] fix tests --- torch_sim/trajectory.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index a47ca6ad..e3596928 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -763,10 +763,9 @@ def last_step(self) -> int: Returns: int: The last recorded step number, or 0 if no data exists """ - if not self.array_registry: + if not self.array_registry or "positions" not in self.array_registry: return 0 - - return self.get_steps("positions")[-1] + return self.get_steps("positions")[-1].item() def __str__(self) -> str: """Get a string representation of the trajectory. From dfd39f06bd7db944514f4a623a493251e0f150da Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 10 Dec 2025 16:14:41 +0000 Subject: [PATCH 24/29] format --- torch_sim/trajectory.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index e3596928..04d3d39c 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -42,6 +42,7 @@ from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState + if TYPE_CHECKING: from ase import Atoms from ase.io.trajectory import TrajectoryReader @@ -471,7 +472,8 @@ def __init__( if inconsistent_step: warnings.warn( "Inconsistent last steps detected in trajectory arrays. " - "Truncating all arrays to the `positions` array's last step." + "Truncating all arrays to the `positions` array's last step.", + stacklevel=2, ) self.truncate_to_step(self.last_step) @@ -758,7 +760,8 @@ def get_steps( def last_step(self) -> int: """Get the last step number from the trajectory. - Retrieves the last time step recorded in the trajectory based on the "positions" array. + Retrieves the last time step recorded in the trajectory based + on the "positions" array. 
Returns: int: The last recorded step number, or 0 if no data exists From 36d915a0023c3b8ec02e199367a74d57212a89a3 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 30 Dec 2025 11:35:04 +0000 Subject: [PATCH 25/29] prek --- torch_sim/runners.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 6ff4cec8..8b433d56 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -160,6 +160,8 @@ def _determine_initial_step_for_optimize( ) initial_step = initial_step + last_logged_steps return initial_step + + def _normalize_temperature_tensor( temperature: float | list | torch.Tensor, n_steps: int, initial_state: SimState ) -> torch.Tensor: From 2470b46a7107f52da039241ee46d2aed65d18334 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 30 Dec 2025 13:03:56 +0000 Subject: [PATCH 26/29] disable auto truncating --- tests/test_trajectory.py | 96 ++++++++++++++++++++++++++-------------- torch_sim/runners.py | 36 ++++++--------- torch_sim/trajectory.py | 25 +++++++++++ 3 files changed, 102 insertions(+), 55 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 1facdde1..0844628c 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -918,19 +918,7 @@ def test_integrate_append_to_trajectory( trajectory_reporter_2 = ts.TrajectoryReporter( traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a") ) - # Nothing to do here, as we already have step 5 in the trajectory - # and n_steps is 5. - state_2 = ts.integrate( - system=int_state, - model=lj_model, - timestep=0.001, - temperature=300.0, - n_steps=5, - integrator=ts.Integrator.nvt_langevin, - trajectory_reporter=trajectory_reporter_2, - ) - torch.testing.assert_close(state_2.positions, int_state.positions) - # run two (7 - 5) more steps of integration. + # run 7 more steps of integration. _ = ts.integrate( system=int_state, model=lj_model, @@ -942,8 +930,8 @@ def test_integrate_append_to_trajectory( ) for traj in trajectory_reporter_2.trajectories: with TorchSimTrajectory(traj.filename, mode="r") as traj: - # Check that the trajectory file now has 7 frames - np.testing.assert_allclose(traj.get_steps("positions"), range(1, 8)) + # Check that the trajectory file now has 12 (5 + 7) frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13)) def test_truncate_trajectory( @@ -1005,6 +993,55 @@ def test_truncate_trajectory( traj.truncate_to_step(0) +def test_truncate_trajectory_reporter( + si_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """ + Test TrajectoryReporter.truncate_to_step(). + """ + + # Create a temporary trajectory file + with tempfile.TemporaryDirectory() as temp_dir: + traj_files = [ + f"{temp_dir}/truncate_reporter_trajectory_{idx}.h5" for idx in range(2) + ] + + # Initialize model and state + trajectory_reporter = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + prop_calculators={1: {"velocities": lambda state: state.velocities}}, + ) + + # First integration run for 5 steps. + _ = ts.integrate( + system=si_double_sim_state, + model=lj_model, + timestep=0.001, + n_steps=5, + temperature=300.0, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter, + ) + + trajectory_reporter.truncate_to_step(step=min(trajectory_reporter.last_step)) + assert trajectory_reporter.last_step == [5, 5] + with pytest.raises( + ValueError, + match=( + "Step 7 is greater than the minimum last step " + r"across trajectories \(5\)\." 
+ ), + ): + trajectory_reporter.truncate_to_step(7) + # try negative number + with pytest.raises(ValueError, match="Step must be greater than 0. Got step=-2"): + trajectory_reporter.truncate_to_step(-2) + # truncate to step 3 + trajectory_reporter.truncate_to_step(3) + assert trajectory_reporter.last_step == [3, 3] + + def test_integrate_uneven_trajectory_append( si_double_sim_state: SimState, lj_model: LennardJonesModel ) -> None: @@ -1045,21 +1082,14 @@ def test_integrate_uneven_trajectory_append( trajectory_reporter_2 = ts.TrajectoryReporter( traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a") ) - # Continue integration for one step, which means we first - # truncate the first trajectory to 3 steps to match the second one, - # and then append 4th step to both. - _ = ts.integrate( - system=si_double_sim_state, - model=lj_model, - timestep=0.001, - temperature=300.0, - n_steps=4, - integrator=ts.Integrator.nvt_langevin, - trajectory_reporter=trajectory_reporter_2, - ) - - for traj in trajectory_reporter_2.trajectories: - # both trajectories should have 4 frames now. - with TorchSimTrajectory(traj.filename, mode="r") as traj: - expected_steps = [1, 2, 3, 4] - np.testing.assert_allclose(traj.get_steps("positions"), expected_steps) + # Should raise a ValueError: + with pytest.raises(ValueError): + _ = ts.integrate( + system=si_double_sim_state, + model=lj_model, + timestep=0.001, + temperature=300.0, + n_steps=4, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter_2, + ) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 8b433d56..b00a8a34 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -102,14 +102,12 @@ def _configure_batches_iterator( def _determine_initial_step_for_integrate( trajectory_reporter: TrajectoryReporter | None, - n_steps: int, ) -> int: """Determine the initial step for resuming integration from trajectory files. Args: trajectory_reporter (TrajectoryReporter | None): The trajectory reporter to check for resume information - n_steps (int): The total number of steps to run Returns: int: The initial step to start from (1 if not resuming, otherwise last_step + 1) @@ -119,20 +117,18 @@ def _determine_initial_step_for_integrate( last_logged_steps = trajectory_reporter.last_step last_logged_step = min(last_logged_steps) initial_step = initial_step + last_logged_step - if initial_step > n_steps: - warnings.warn( - f"Initial step {initial_step} > n_steps {n_steps}. Nothing will be done.", - stacklevel=2, - ) if len(set(last_logged_steps)) != 1: - warnings.warn( - "Trajectory files have different last steps. " - "Using the minimum last step for resuming integration." - "This means that some trajectories may be truncated.", - stacklevel=2, + raise ValueError( + f"Trajectory files have different last steps: {set(last_logged_steps)} " + "Cannot resume integration from inconsistent states." + "You can truncate the trajectories to the same step using:\n\n" + " trajectory_reporter.truncate_to_step(min(trajectory_reporter.last_step))\n\n" + "before calling integrate again." ) - for traj in trajectory_reporter.trajectories: - traj.truncate_to_step(last_logged_step) + print( + f"Detected existing trajectory with last step {last_logged_step}." + f" Resuming integration from step {initial_step}." + ) return initial_step @@ -254,9 +250,7 @@ def integrate[T: SimState]( # noqa: C901 integrator (Integrator | tuple): Either a key from Integrator or a tuple of (init_func, step_func) functions. 
n_steps (int): Number of integration steps. If resuming from a trajectory, this - is the total number of steps to run, not the number of additional steps. - That is, if the trajectory has 10 steps and n_steps=20, the integrator will - run an additional 10 steps. + is the number of additional steps to run. temperature (float | ArrayLike): Temperature or array of temperatures for each step or system: Float: used for all steps and systems @@ -310,9 +304,7 @@ def integrate[T: SimState]( # noqa: C901 properties=["kinetic_energy", "potential_energy", "temperature"], ) # Auto-detect initial step from trajectory files for resuming integration - initial_step = _determine_initial_step_for_integrate(trajectory_reporter, n_steps) - if initial_step > n_steps: - return initial_state # type: ignore[return-value] + initial_step = _determine_initial_step_for_integrate(trajectory_reporter) final_states: list[T] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None @@ -342,12 +334,12 @@ def integrate[T: SimState]( # noqa: C901 ) # run the simulation - for step in range(initial_step, n_steps + 1): + for step in range(initial_step, initial_step + n_steps): state = step_func( state=state, model=model, dt=dt, - kT=batch_kT[step - 1], + kT=batch_kT[step - initial_step], **integrator_kwargs, ) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 04d3d39c..f20a11a7 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -213,6 +213,29 @@ def array_registry(self) -> dict[str, tuple[tuple[int, ...], np.dtype]]: return self.trajectories[0].array_registry return {} + def truncate_to_step(self, step: int) -> None: + """Truncate all trajectory files to the specified step. + **WARNING**: This operation is irreversible and will remove data from + the trajectory files. + + Args: + step (int): The step to truncate to. + """ + if step <= 0: + raise ValueError(f"Step must be greater than 0. Got step={step}.") + if step > min(self.last_step): + raise ValueError( + f"Step {step} is greater than the minimum last step " + f"across trajectories ({min(self.last_step)})." + ) + for trajectory in self.trajectories: + # trajectory file could be closed + if trajectory._file.isopen: + trajectory.truncate_to_step(step) + else: + with TorchSimTrajectory(trajectory.filename, mode="a") as traj: + traj.truncate_to_step(step) + def _add_model_arg_to_prop_calculators(self) -> None: """Add model argument to property calculators that only accept state. @@ -1120,6 +1143,8 @@ def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReade def truncate_to_step(self, step: int) -> None: """Truncate the trajectory to a specified step. + **WARNING**: This operation is irreversible and will permanently + modify the trajectory file. Removes frames from the end of the trajectory to reduce its length such that the last logged step is `step`. 
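For illustration (not part of the patch series), a minimal sketch of resuming an MD run under the semantics introduced in this patch: append mode, explicit truncation instead of auto-truncation, and `n_steps` counting additional steps. The helper name and its arguments are placeholders; a later patch in the series renames the reporter's `last_step` property to `last_steps`.

```python
import torch_sim as ts
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState


def resume_nvt(
    state: SimState, model: ModelInterface, traj_files: list[str]
) -> SimState:
    """Append seven more NVT-Langevin frames to previously written trajectories."""
    reporter = ts.TrajectoryReporter(
        traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a")
    )
    # Mismatched last steps now raise instead of being silently truncated, so
    # bring all files to a common step first, as the error message suggests.
    reporter.truncate_to_step(min(reporter.last_step))
    # n_steps counts additional steps: a file holding 5 frames ends up with 12.
    return ts.integrate(
        system=state,
        model=model,
        timestep=0.001,
        temperature=300.0,
        n_steps=7,
        integrator=ts.Integrator.nvt_langevin,
        trajectory_reporter=reporter,
    )
```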
From 28bb6601a95f57ece4ed018399374db19f57d0c6 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 30 Dec 2025 14:20:47 +0000 Subject: [PATCH 27/29] style --- tests/test_trajectory.py | 4 +++- torch_sim/runners.py | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 0844628c..88ca67a7 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -1083,7 +1083,9 @@ def test_integrate_uneven_trajectory_append( traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a") ) # Should raise a ValueError: - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Cannot resume integration from inconsistent states" + ): _ = ts.integrate( system=si_double_sim_state, model=lj_model, diff --git a/torch_sim/runners.py b/torch_sim/runners.py index b00a8a34..7b4a4b04 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -122,12 +122,13 @@ def _determine_initial_step_for_integrate( f"Trajectory files have different last steps: {set(last_logged_steps)} " "Cannot resume integration from inconsistent states." "You can truncate the trajectories to the same step using:\n\n" - " trajectory_reporter.truncate_to_step(min(trajectory_reporter.last_step))\n\n" + " reporter.truncate_to_step(min(reporter.last_step))\n\n" "before calling integrate again." ) - print( + warnings.warn( f"Detected existing trajectory with last step {last_logged_step}." - f" Resuming integration from step {initial_step}." + f" Resuming integration from step {initial_step}.", + stacklevel=2, ) return initial_step From 6de046abd910f6d1f6336de4f056a80024aaccac Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 30 Dec 2025 14:31:02 +0000 Subject: [PATCH 28/29] rename --- tests/test_trajectory.py | 6 +++--- torch_sim/runners.py | 4 ++-- torch_sim/trajectory.py | 7 +++---- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 88ca67a7..5c4216ef 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -1024,8 +1024,8 @@ def test_truncate_trajectory_reporter( trajectory_reporter=trajectory_reporter, ) - trajectory_reporter.truncate_to_step(step=min(trajectory_reporter.last_step)) - assert trajectory_reporter.last_step == [5, 5] + trajectory_reporter.truncate_to_step(step=min(trajectory_reporter.last_steps)) + assert trajectory_reporter.last_steps == [5, 5] with pytest.raises( ValueError, match=( @@ -1039,7 +1039,7 @@ def test_truncate_trajectory_reporter( trajectory_reporter.truncate_to_step(-2) # truncate to step 3 trajectory_reporter.truncate_to_step(3) - assert trajectory_reporter.last_step == [3, 3] + assert trajectory_reporter.last_steps == [3, 3] def test_integrate_uneven_trajectory_append( diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 7b4a4b04..4ba3a100 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -114,7 +114,7 @@ def _determine_initial_step_for_integrate( """ initial_step: int = 1 if trajectory_reporter is not None and trajectory_reporter.mode == "a": - last_logged_steps = trajectory_reporter.last_step + last_logged_steps = trajectory_reporter.last_steps last_logged_step = min(last_logged_steps) initial_step = initial_step + last_logged_step if len(set(last_logged_steps)) != 1: @@ -153,7 +153,7 @@ def _determine_initial_step_for_optimize( ) if trajectory_reporter is not None and trajectory_reporter.mode == "a": last_logged_steps = torch.tensor( - trajectory_reporter.last_step, dtype=torch.long, 
device=state.device + trajectory_reporter.last_steps, dtype=torch.long, device=state.device ) initial_step = initial_step + last_logged_steps return initial_step diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index f20a11a7..59ace1be 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -42,7 +42,6 @@ from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState - if TYPE_CHECKING: from ase import Atoms from ase.io.trajectory import TrajectoryReader @@ -223,10 +222,10 @@ def truncate_to_step(self, step: int) -> None: """ if step <= 0: raise ValueError(f"Step must be greater than 0. Got step={step}.") - if step > min(self.last_step): + if step > min(self.last_steps): raise ValueError( f"Step {step} is greater than the minimum last step " - f"across trajectories ({min(self.last_step)})." + f"across trajectories ({min(self.last_steps)})." ) for trajectory in self.trajectories: # trajectory file could be closed @@ -367,7 +366,7 @@ def mode(self) -> Literal["r", "w", "a"]: return self.trajectory_kwargs["mode"] @property - def last_step(self) -> list[int]: + def last_steps(self) -> list[int]: """Get the last logged step across all trajectory files. This is useful for resuming optimizations from where they left off. From e801817b5fd93a8d66ae607445416b0db19f8aaf Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 30 Dec 2025 14:55:16 +0000 Subject: [PATCH 29/29] prek --- torch_sim/trajectory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 59ace1be..08ccaac6 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -42,6 +42,7 @@ from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState + if TYPE_CHECKING: from ase import Atoms from ase.io.trajectory import TrajectoryReader
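Finally, for illustration (not part of the patch series), a minimal sketch of sanity-checking appended trajectories against the reporter API as it stands at the end of the series (`mode`, `filenames`, `last_steps`); the helper is hypothetical and assumes the reporter still points at files written by a prior run.

```python
from torch_sim.trajectory import TorchSimTrajectory, TrajectoryReporter


def summarize_appended(reporter: TrajectoryReporter) -> None:
    """Print the last logged step of every trajectory handled by the reporter."""
    assert reporter.mode == "a", "reporter was expected to be in append mode"
    for filename, last in zip(reporter.filenames, reporter.last_steps, strict=True):
        with TorchSimTrajectory(filename, mode="r") as traj:
            steps = traj.get_steps("positions")
            assert steps[-1] == last  # reporter and file agree on the last frame
            print(f"{filename}: {len(steps)} frames, last step {last}")
```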