diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 7f7252cf..5c4216ef 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -1,5 +1,6 @@ import os import sys +import tempfile from collections.abc import Callable, Generator from pathlib import Path @@ -834,3 +835,263 @@ def test_write_ase_trajectory_importerror( with pytest.raises(ImportError, match="ASE is required to convert to ASE trajectory"): traj.write_ase_trajectory(tmp_path / "dummy.traj") traj.close() + + +def test_optimize_append_to_trajectory( + si_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """Test appending to an existing trajectory when running ts.optimize.""" + + # Create a temporary trajectory file + with tempfile.TemporaryDirectory() as temp_dir: + traj_files = [f"{temp_dir}/optimize_trajectory_{idx}.h5" for idx in range(2)] + + # Initialize model and state + trajectory_reporter = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + ) + + # First optimization run + opt_state = ts.optimize( + system=si_double_sim_state, + model=lj_model, + max_steps=5, + optimizer=ts.Optimizer.fire, + trajectory_reporter=trajectory_reporter, + steps_between_swaps=100, + ) + + for traj in trajectory_reporter.trajectories: + with TorchSimTrajectory(traj.filename, mode="r") as traj: + # Check that the trajectory file has 5 frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6)) + + trajectory_reporter_2 = ts.TrajectoryReporter( + traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a") + ) + _ = ts.optimize( + system=opt_state, + model=lj_model, + max_steps=7, + optimizer=ts.Optimizer.fire, + trajectory_reporter=trajectory_reporter_2, + steps_between_swaps=100, + ) + for traj in trajectory_reporter_2.trajectories: + with TorchSimTrajectory(traj.filename, mode="r") as traj: + # Check that the trajectory file now has 7 frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 8)) + + +def 
test_integrate_append_to_trajectory( + si_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """Test appending to an existing trajectory when running ts.integrate.""" + + # Create a temporary trajectory file + with tempfile.TemporaryDirectory() as temp_dir: + traj_files = [f"{temp_dir}/integrate_trajectory_{idx}.h5" for idx in range(2)] + + # Initialize model and state + trajectory_reporter = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + ) + + # First integration run + int_state = ts.integrate( + system=si_double_sim_state, + model=lj_model, + timestep=0.001, + n_steps=5, + temperature=300.0, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter, + ) + + for traj in trajectory_reporter.trajectories: + with TorchSimTrajectory(traj.filename, mode="r") as traj: + # Check that the trajectory file has 5 frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6)) + + trajectory_reporter_2 = ts.TrajectoryReporter( + traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a") + ) + # run 7 more steps of integration. + _ = ts.integrate( + system=int_state, + model=lj_model, + timestep=0.001, + temperature=300.0, + n_steps=7, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter_2, + ) + for traj in trajectory_reporter_2.trajectories: + with TorchSimTrajectory(traj.filename, mode="r") as traj: + # Check that the trajectory file now has 12 (5 + 7) frames + np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13)) + + +def test_truncate_trajectory( + si_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """ + Test trajectory.truncate_to_step(). 
+ """ + + # Create a temporary trajectory file + with tempfile.TemporaryDirectory() as temp_dir: + traj_files = [f"{temp_dir}/truncate_trajectory_{idx}.h5" for idx in range(2)] + + # Initialize model and state + trajectory_reporter = ts.TrajectoryReporter( + traj_files, + state_frequency=1, + prop_calculators={1: {"velocities": lambda state: state.velocities}}, + ) + + # First integration run for 5 steps. + _ = ts.integrate( + system=si_double_sim_state, + model=lj_model, + timestep=0.001, + n_steps=5, + temperature=300.0, + integrator=ts.Integrator.nvt_langevin, + trajectory_reporter=trajectory_reporter, + ) + + # Manually remove last two frames from second trajectory to create unevenness + with TorchSimTrajectory(traj_files[1], mode="a") as traj: + traj.truncate_to_step(3) + # Verify that it has 3 frames now. + for array_name in traj.array_registry: + target_length = 3 + target_steps = [1, 2, 3] + # Special cases: global arrays + if array_name in ["atomic_numbers", "masses"]: + target_length = 1 + target_steps = [0] + if array_name == "pbc": + target_length = 3 + target_steps = [0] + assert len(traj.get_array(array_name)) == target_length + np.testing.assert_allclose(traj.get_steps(array_name), target_steps) + with pytest.raises( + ValueError, + match=( + "Cannot truncate to a step greater than the last step. " + "self.last_step=3 < step=10" + ), + ): + traj.truncate_to_step(10) + with pytest.raises( + ValueError, match="Step must be larger than 0. Got step=0" + ): + traj.truncate_to_step(0) + + +def test_truncate_trajectory_reporter( + si_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """ + Test TrajectoryReporter.truncate_to_step(). 
+    """
+
+    # Create a temporary directory to hold the trajectory files
+    with tempfile.TemporaryDirectory() as temp_dir:
+        traj_files = [
+            f"{temp_dir}/truncate_reporter_trajectory_{idx}.h5" for idx in range(2)
+        ]
+
+        # Set up the trajectory reporter (also logs velocities every step)
+        trajectory_reporter = ts.TrajectoryReporter(
+            traj_files,
+            state_frequency=1,
+            prop_calculators={1: {"velocities": lambda state: state.velocities}},
+        )
+
+        # First integration run for 5 steps.
+        _ = ts.integrate(
+            system=si_double_sim_state,
+            model=lj_model,
+            timestep=0.001,
+            n_steps=5,
+            temperature=300.0,
+            integrator=ts.Integrator.nvt_langevin,
+            trajectory_reporter=trajectory_reporter,
+        )
+
+        trajectory_reporter.truncate_to_step(step=min(trajectory_reporter.last_steps))
+        assert trajectory_reporter.last_steps == [5, 5]
+        with pytest.raises(
+            ValueError,
+            match=(
+                "Step 7 is greater than the minimum last step "
+                r"across trajectories \(5\)\."
+            ),
+        ):
+            trajectory_reporter.truncate_to_step(7)
+        # A negative step must be rejected
+        with pytest.raises(ValueError, match="Step must be greater than 0. Got step=-2"):
+            trajectory_reporter.truncate_to_step(-2)
+        # Truncating to step 3 must shorten every trajectory to step 3
+        trajectory_reporter.truncate_to_step(3)
+        assert trajectory_reporter.last_steps == [3, 3]
+
+
+def test_integrate_uneven_trajectory_append(
+    si_double_sim_state: SimState, lj_model: LennardJonesModel
+) -> None:
+    """
+    Test appending to an existing trajectory with uneven frames running ts.integrate.
+    Expected behavior: ts.integrate should refuse to resume and raise a ValueError,
+    because the trajectory files do not share the same last step.
+    """
+
+    # Create a temporary directory to hold the trajectory files
+    with tempfile.TemporaryDirectory() as temp_dir:
+        traj_files = [
+            f"{temp_dir}/uneven_integrate_trajectory_{idx}.h5" for idx in range(2)
+        ]
+
+        # Set up the trajectory reporter (also logs velocities every step)
+        trajectory_reporter = ts.TrajectoryReporter(
+            traj_files,
+            state_frequency=1,
+            prop_calculators={1: {"velocities": lambda state: state.velocities}},
+        )
+
+        # First integration run for 5 steps.
+        _ = ts.integrate(
+            system=si_double_sim_state,
+            model=lj_model,
+            timestep=0.001,
+            n_steps=5,
+            temperature=300.0,
+            integrator=ts.Integrator.nvt_langevin,
+            trajectory_reporter=trajectory_reporter,
+        )
+
+        # Manually remove last two frames from second trajectory to create unevenness
+        with TorchSimTrajectory(traj_files[1], mode="a") as traj:
+            traj.truncate_to_step(3)
+
+        trajectory_reporter_2 = ts.TrajectoryReporter(
+            traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a")
+        )
+        # Should raise a ValueError:
+        with pytest.raises(
+            ValueError, match="Cannot resume integration from inconsistent states"
+        ):
+            _ = ts.integrate(
+                system=si_double_sim_state,
+                model=lj_model,
+                timestep=0.001,
+                temperature=300.0,
+                n_steps=4,
+                integrator=ts.Integrator.nvt_langevin,
+                trajectory_reporter=trajectory_reporter_2,
+            )
diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index 91f5a340..4ba3a100 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -100,6 +100,65 @@ def _configure_batches_iterator(
     return batches
 
 
+def _determine_initial_step_for_integrate(
+    trajectory_reporter: TrajectoryReporter | None,
+) -> int:
+    """Determine the initial step for resuming integration from trajectory files.
+
+    Args:
+        trajectory_reporter (TrajectoryReporter | None): The trajectory reporter to
+            check for resume information
+
+    Returns:
+        int: The initial step to start from (1 if not resuming, otherwise last_step + 1)
+    """
+    initial_step: int = 1
+    reporter = trajectory_reporter  # file-less reporters have nothing to resume
+    if reporter is not None and reporter.trajectories and reporter.mode == "a":
+        last_logged_steps = reporter.last_steps
+        if len(set(last_logged_steps)) != 1:
+            raise ValueError(
+                f"Trajectory files have different last steps: {set(last_logged_steps)} "
+                "Cannot resume integration from inconsistent states. "
+                "You can truncate the trajectories to the same step using:\n\n"
+                "    reporter.truncate_to_step(min(reporter.last_steps))\n\n"
+                "before calling integrate again."
+            )
+        initial_step += last_logged_steps[0]
+        warnings.warn(
+            f"Detected existing trajectory with last step {last_logged_steps[0]}."
+            f" Resuming integration from step {initial_step}.",
+            stacklevel=2,
+        )
+    return initial_step
+
+
+def _determine_initial_step_for_optimize(
+    trajectory_reporter: TrajectoryReporter | None,
+    state: SimState,
+) -> torch.LongTensor:
+    """Determine the initial steps for resuming optimization from trajectory files.
+
+    Args:
+        trajectory_reporter (TrajectoryReporter | None): The trajectory reporter to
+            check for resume information
+        state (SimState): The state being optimized
+
+    Returns:
+        torch.LongTensor: Tensor of initial steps for each system (1 if not resuming,
+            otherwise last_step + 1 for each system)
+    """
+    initial_step: torch.LongTensor = torch.full(
+        size=(state.n_systems,), fill_value=1, dtype=torch.long, device=state.device
+    )
+    reporter = trajectory_reporter  # file-less reporters have nothing to resume
+    if reporter is not None and reporter.trajectories and reporter.mode == "a":
+        initial_step = initial_step + torch.tensor(
+            reporter.last_steps, dtype=torch.long, device=state.device
+        )
+    return initial_step
+
+
 def _normalize_temperature_tensor(
     temperature: float | list | torch.Tensor, n_steps: int, initial_state: SimState
 ) -> torch.Tensor:
@@ -191,7 +250,8 @@ def integrate[T: SimState](  # noqa: C901
         model (ModelInterface): Neural network model module
         integrator (Integrator | tuple): Either a key from Integrator or a tuple of
             (init_func, step_func) functions.
-        n_steps (int): Number of integration steps
+        n_steps (int): Number of integration steps. If resuming from a trajectory, this
+            is the number of additional steps to run.
temperature (float | ArrayLike): Temperature or array of temperatures for each step or system: Float: used for all steps and systems @@ -244,6 +304,8 @@ def integrate[T: SimState]( # noqa: C901 trajectory_reporter, properties=["kinetic_energy", "potential_energy", "temperature"], ) + # Auto-detect initial step from trajectory files for resuming integration + initial_step = _determine_initial_step_for_integrate(trajectory_reporter) final_states: list[T] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None @@ -268,17 +330,17 @@ def integrate[T: SimState]( # noqa: C901 # set up trajectory reporters if autobatcher and trajectory_reporter is not None and og_filenames is not None: # we must remake the trajectory reporter for each system - trajectory_reporter.load_new_trajectories( + trajectory_reporter.reopen_trajectories( filenames=[og_filenames[i] for i in system_indices] ) # run the simulation - for step in range(1, n_steps + 1): + for step in range(initial_step, initial_step + n_steps): state = step_func( state=state, model=model, dt=dt, - kT=batch_kT[step - 1], + kT=batch_kT[step - initial_step], **integrator_kwargs, ) @@ -535,7 +597,10 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 trajectory_reporter, properties=["potential_energy"] ) - step: int = 1 + # Auto-detect initial step from trajectory files for resuming optimizations + step = _determine_initial_step_for_optimize(trajectory_reporter, state) + steps_so_far = 0 + last_energy = None all_converged_states: list[T] = [] convergence_tensor = None @@ -561,27 +626,35 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 if ( trajectory_reporter is not None and og_filenames is not None - and (step == 1 or len(converged_states) > 0) + and (steps_so_far == 0 or len(converged_states) > 0) ): - trajectory_reporter.load_new_trajectories( + trajectory_reporter.reopen_trajectories( filenames=[og_filenames[i] for i in autobatcher.current_idx] ) for _step in range(steps_between_swaps): 
if hasattr(state, "energy"): last_energy = state.energy - state = step_fn(state=state, model=model, **optimizer_kwargs) if trajectory_reporter: - trajectory_reporter.report(state, step, model=model) - step += 1 - if step > max_steps: - # TODO: max steps should be tracked for each structure in the batch - warnings.warn(f"Optimize has reached max steps: {step}", stacklevel=2) + trajectory_reporter.report( + state, step[autobatcher.current_idx].tolist(), model=model + ) + step[autobatcher.current_idx] += 1 + exceeded_max_steps = step > max_steps + if exceeded_max_steps.all(): + warnings.warn( + f"All systems have reached the maximum number of steps: {max_steps}.", + stacklevel=2, + ) break convergence_tensor = convergence_fn(state, last_energy) + # Mark states that exceeded max steps as converged to remove them from batch + convergence_tensor = ( + convergence_tensor | exceeded_max_steps[autobatcher.current_idx] + ) if tqdm_pbar: # assume convergence_tensor shape is correct tqdm_pbar.update(torch.count_nonzero(convergence_tensor).item()) @@ -679,7 +752,7 @@ class StaticState(SimState): # set up trajectory reporters if autobatcher and trajectory_reporter and og_filenames is not None: # we must remake the trajectory reporter for each system - trajectory_reporter.load_new_trajectories( + trajectory_reporter.reopen_trajectories( filenames=[og_filenames[idx] for idx in system_indices] ) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index ab841a85..08ccaac6 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -30,6 +30,7 @@ import copy import inspect import pathlib +import warnings from collections.abc import Callable, Mapping, Sequence from functools import partial from typing import TYPE_CHECKING, Any, Literal, Self @@ -96,7 +97,6 @@ class TrajectoryReporter: state_kwargs: dict[str, Any] metadata: dict[str, str] | None trajectories: list["TorchSimTrajectory"] - filenames: list[str | pathlib.Path] | None def __init__( self, @@ -141,19 
+141,39 @@ def __init__( self.metadata = metadata self.trajectories = [] - if filenames is None: - self.filenames = None - else: - self.load_new_trajectories(filenames) + if filenames is not None: + filenames = ( + [filenames] + if isinstance(filenames, (str, pathlib.Path)) + else list(filenames) + ) + # Initialize trajectories for the first time. Unlike in reopen_trajectories, + # if the user specified "w" mode, we respect that here and start fresh. + self.trajectories = [ + TorchSimTrajectory( + filename=filename, metadata=self.metadata, **self.trajectory_kwargs + ) + for filename in filenames + ] self._add_model_arg_to_prop_calculators() - def load_new_trajectories( + @property + def filenames(self) -> list[str] | None: + """Get the list of trajectory filenames. + + Returns: + list[str] | None: List of trajectory file paths, + or None if no trajectories are loaded. + """ + if not self.trajectories: + return None + return [traj.filename for traj in self.trajectories] + + def reopen_trajectories( self, filenames: str | pathlib.Path | Sequence[str | pathlib.Path] ) -> None: - """Load new trajectories into the reporter. - - Closes any existing trajectory files and initializes new ones. + """Closes any existing trajectory files and reopens new ones given by filenames. 
Args: filenames (str | pathlib.Path | list[str | pathlib.Path]): Path(s) to save @@ -167,19 +187,23 @@ def load_new_trajectories( filenames = ( [filenames] if isinstance(filenames, (str, pathlib.Path)) else list(filenames) ) - self.filenames = [pathlib.Path(filename) for filename in filenames] - if len(set(self.filenames)) != len(self.filenames): + filenames = [pathlib.Path(filename) for filename in filenames] + if len(set(filenames)) != len(filenames): raise ValueError("All filenames must be unique.") - - self.trajectories = [] - for filename in self.filenames: - self.trajectories.append( - TorchSimTrajectory( - filename=filename, - metadata=self.metadata, - **self.trajectory_kwargs, - ) + # Avoid wiping existing trajectory files when reopening them, hence + # we set to "a" mode temporarily (read mode is unaffected). + _mode = self.trajectory_kwargs.get("mode", "w") + self.trajectory_kwargs["mode"] = "a" if _mode in ["a", "w"] else "r" + self.trajectories = [ + TorchSimTrajectory( + filename=filename, + metadata=self.metadata, + **self.trajectory_kwargs, ) + for filename in filenames + ] + # Restore original mode + self.trajectory_kwargs["mode"] = _mode @property def array_registry(self) -> dict[str, tuple[tuple[int, ...], np.dtype]]: @@ -189,6 +213,29 @@ def array_registry(self) -> dict[str, tuple[tuple[int, ...], np.dtype]]: return self.trajectories[0].array_registry return {} + def truncate_to_step(self, step: int) -> None: + """Truncate all trajectory files to the specified step. + **WARNING**: This operation is irreversible and will remove data from + the trajectory files. + + Args: + step (int): The step to truncate to. + """ + if step <= 0: + raise ValueError(f"Step must be greater than 0. Got step={step}.") + if step > min(self.last_steps): + raise ValueError( + f"Step {step} is greater than the minimum last step " + f"across trajectories ({min(self.last_steps)})." 
+ ) + for trajectory in self.trajectories: + # trajectory file could be closed + if trajectory._file.isopen: + trajectory.truncate_to_step(step) + else: + with TorchSimTrajectory(trajectory.filename, mode="a") as traj: + traj.truncate_to_step(step) + def _add_model_arg_to_prop_calculators(self) -> None: """Add model argument to property calculators that only accept state. @@ -213,7 +260,7 @@ def _add_model_arg_to_prop_calculators(self) -> None: self.prop_calculators[frequency][name] = new_fn def report( - self, state: SimState, step: int, model: ModelInterface | None = None + self, state: SimState, step: int | list[int], model: ModelInterface | None = None ) -> list[dict[str, torch.Tensor]]: """Report a state and step to the trajectory files. @@ -224,8 +271,10 @@ def report( Args: state (SimState): Current system state with n_systems equal to len(filenames) - step (int): Current simulation step, setting step to 0 will write - the state and all properties. + step (int | list[int]): Current simulation step per system, setting step + to 0 will write the state and all properties. If a list is provided, it + must have length equal to n_systems. Otherwise, a single integer step + is broadcast to all systems. model (ModelInterface, optional): Model used for simulation. Defaults to None. Must be provided if any prop_calculators are provided. 
@@ -255,18 +304,21 @@ def report(
         all_props: list[dict[str, torch.Tensor]] = []
         # Process each system separately
         for idx, substate in enumerate(split_states):
+            sys_step = step[idx] if isinstance(step, list) else step
             # Write state to trajectory if it's time
             if (
                 self.state_frequency
-                and step % self.state_frequency == 0
+                and sys_step % self.state_frequency == 0
                 and self.filenames is not None
             ):
-                self.trajectories[idx].write_state(substate, step, **self.state_kwargs)
+                self.trajectories[idx].write_state(
+                    substate, sys_step, **self.state_kwargs
+                )
 
             all_state_props = {}
             # Process property calculators for this system
             for report_frequency, calculators in self.prop_calculators.items():
-                if step % report_frequency != 0 or report_frequency == 0:
+                if report_frequency == 0 or sys_step % report_frequency != 0:
                     continue
 
                 # Calculate properties for this substate
@@ -281,7 +333,7 @@ def report(
                 if props:
                     all_state_props.update(props)
                     if self.filenames is not None:
-                        self.trajectories[idx].write_arrays(props, step)
+                        self.trajectories[idx].write_arrays(props, sys_step)
             all_props.append(all_state_props)
 
         return all_props
@@ -302,6 +354,39 @@ def close(self) -> None:
         for trajectory in self.trajectories:
             trajectory.close()
 
+    @property
+    def mode(self) -> Literal["r", "w", "a"]:
+        """Get the file mode this reporter opens its trajectory files with.
+
+        Returns:
+            "r" | "w" | "a": Mode from the trajectory_kwargs used during initialization.
+        """
+        if not self.trajectories:
+            raise ValueError("No trajectories loaded.")
+        # Key is guaranteed to exist because we set it during initialization.
+        return self.trajectory_kwargs["mode"]
+
+    @property
+    def last_steps(self) -> list[int]:
+        """Get the last logged step of each trajectory file.
+
+        This is useful for resuming optimizations from where they left off.
+
+        Returns:
+            list[int]: The last step number for each trajectory (0 for an empty
+                trajectory), or an empty list if no trajectories are loaded
+        """
+        if not self.trajectories:
+            return []
+        last_steps = []
+        for trajectory in self.trajectories:
+            if trajectory._file.isopen:
+                last_steps.append(trajectory.last_step)
+            else:
+                with TorchSimTrajectory(trajectory.filename, mode="r") as traj:
+                    last_steps.append(traj.last_step)
+        return last_steps
+
     def __enter__(self) -> "TrajectoryReporter":
         """Support the context manager protocol.
 
@@ -403,6 +488,17 @@ def __init__(
         self.type_map = self._initialize_type_map(
             coerce_to_float32=coerce_to_float32, coerce_to_int32=coerce_to_int32
         )
+        if mode == "a":
+            inconsistent_step = any(
+                self.get_steps(name)[-1] > self.last_step for name in self.array_registry
+            )
+            if inconsistent_step:
+                warnings.warn(
+                    "Inconsistent last steps detected in trajectory arrays. "
+                    "Truncating all arrays to the `positions` array's last step.",
+                    stacklevel=2,
+                )
+                self.truncate_to_step(self.last_step)
 
     def _initialize_header(self, metadata: dict[str, str] | None = None) -> None:
         """Initialize the HDF5 file header with metadata.
@@ -594,7 +690,7 @@ def _validate_array(self, name: str, data: np.ndarray, steps: list[int]) -> None
             )
 
         # Validate step is monotonically increasing by checking HDF5 file directly
-        steps_node = self._file.get_node("/steps/", name=name)
+        steps_node = self.get_steps(name)
         if len(steps_node) > 0:
             last_step = steps_node[-1]  # Get the last recorded step
             if steps[0] <= last_step:
@@ -603,6 +699,15 @@ def _validate_array(self, name: str, data: np.ndarray, steps: list[int]) -> None
                     f"step {last_step} for array {name}"
                 )
 
+    @property
+    def filename(self) -> str:
+        """Get the filename of the trajectory file.
+
+        Returns:
+            str: Path to the HDF5 file
+        """
+        return self._file.filename
+
     def _serialize_array(self, name: str, data: np.ndarray, steps: list[int]) -> None:
         """Add additional contents to an array already in the registry.
 
@@ -658,9 +763,6 @@ def get_array(
     def get_steps(
         self,
         name: str,
-        start: int | None = None,
-        stop: int | None = None,
-        step: int = 1,
     ) -> np.ndarray:
         """Get the steps for an array.
 
@@ -675,9 +777,21 @@ def get_steps(
         Returns:
             np.ndarray: Array of step numbers with shape [n_selected_frames]
         """
-        return self._file.root.steps.__getitem__(name).read(
-            start=start, stop=stop, step=step
-        )
+        return self._file.get_node("/steps/", name=name).read()
+
+    @property
+    def last_step(self) -> int:
+        """Get the last step number from the trajectory.
+
+        Retrieves the last time step recorded in the trajectory based
+        on the "positions" array.
+
+        Returns:
+            int: The last recorded step number, or 0 if no data exists
+        """
+        if not self.array_registry or "positions" not in self.array_registry:
+            return 0
+        return self.get_steps("positions")[-1].item()
 
     def __str__(self) -> str:
         """Get a string representation of the trajectory.
@@ -1026,3 +1140,39 @@ def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReade
         traj.close()
 
         return Trajectory(filename, mode="r")  # Reopen in read mode
+
+    def truncate_to_step(self, step: int) -> None:
+        """Truncate the trajectory to a specified step.
+        **WARNING**: This operation is irreversible and will permanently
+        modify the trajectory file.
+
+        Removes frames from the end of the trajectory to reduce its length such that the
+        last logged step is `step`.
+
+        Args:
+            step (int): Desired last step of the trajectory after truncation
+        """
+        if self.last_step < step:
+            raise ValueError(
+                f"Cannot truncate to a step greater than the last step."
+                f" {self.last_step=} < {step=}"
+            )
+        if self.last_step == step:
+            return  # No truncation needed
+        if step <= 0:
+            raise ValueError(f"Step must be larger than 0. Got {step=}")
+        for name in self.array_registry:
+            steps_node = self._file.get_node("/steps/", name=name)
+            steps_data = steps_node.read()
+            if not steps_data.any():
+                continue  # skip global (step-0 only) and empty arrays
+            # Keep every entry logged at or before the desired step.
+            # NOTE(review): raises IndexError if the array was only logged after `step`.
+            indices = np.where(steps_data <= step)[0]
+            length = indices[-1] + 1  # +1 because we want to include this index
+
+            data_node = self._file.get_node("/data/", name=name)
+            data_node.truncate(length)
+            steps_node.truncate(length)
+
+        self.flush()