diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py index 997df668..37556aa9 100644 --- a/cmdstanpy/cmdstan_args.py +++ b/cmdstanpy/cmdstan_args.py @@ -419,7 +419,7 @@ def __init__( tol_param: float | None = None, history_size: int | None = None, num_psis_draws: int | None = None, - num_paths: int | None = None, + num_paths: int = 4, max_lbfgs_iters: int | None = None, num_draws: int | None = None, num_elbo_draws: int | None = None, diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index 5cea19f2..9124d2f7 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -1328,7 +1328,7 @@ def pathfinder( tol_rel_grad: float | None = None, tol_param: float | None = None, history_size: int | None = None, - num_paths: int | None = None, + num_paths: int = 4, max_lbfgs_iters: int | None = None, draws: int | None = None, num_single_draws: int | None = None, @@ -1352,6 +1352,7 @@ def pathfinder( time_fmt: str = "%Y%m%d%H%M%S", timeout: float | None = None, num_threads: int | None = None, + save_single_paths: bool = False, ) -> CmdStanPathfinder: """ Run CmdStan's Pathfinder variational inference algorithm. @@ -1458,6 +1459,12 @@ def pathfinder( A number other than ``1`` requires the model to have been compiled with STAN_THREADS=True. + :param save_single_paths: Save draws and ELBO evaluations from + individual Pathfinder runs. Draws are saved to CSV files and ELBO + evaluations are saved to JSON files. If ``True``, file paths can be + accessed via ``CmdStanPathfinder.runset.single_path_csv_files`` and + ``CmdStanPathfinder.runset.single_path_json_files``. + :return: A :class:`CmdStanPathfinder` object References @@ -1506,6 +1513,7 @@ def pathfinder( num_elbo_draws=num_elbo_draws, psis_resample=psis_resample, calculate_lp=calculate_lp, + save_single_paths=save_single_paths, ) with temp_single_json(data) as _data, temp_inits(inits) as _inits: diff --git a/cmdstanpy/stanfit/runset.py b/cmdstanpy/stanfit/runset.py index 3764b8c2..89a179f1 100644 --- a/cmdstanpy/stanfit/runset.py +++ b/cmdstanpy/stanfit/runset.py @@ -11,7 +11,7 @@ from time import time from cmdstanpy import _TMPDIR -from cmdstanpy.cmdstan_args import CmdStanArgs, Method +from cmdstanpy.cmdstan_args import CmdStanArgs, Method, PathfinderArgs from cmdstanpy.utils import get_logger @@ -57,6 +57,8 @@ def __init__( self._stdout_files, self._profile_files = [], [] self._csv_files, self._diagnostic_files = [], [] self._config_files = [] + self._single_path_csv_files: list[str] = [] + self._single_path_json_files: list[str] = [] # per-process output files if one_process_per_chain and chains > 1: @@ -101,6 +103,9 @@ def __init__( for id in self._chain_ids ] + if args.method == Method.PATHFINDER: + self.populate_pathfinder_single_path_files() + def __repr__(self) -> str: lines = [ f"RunSet: chains={self._chains}, chain_ids={self._chain_ids}, " @@ -222,6 +227,18 @@ def profile_files(self) -> list[str]: """List of paths to CmdStan profiler files.""" return self._profile_files + @property + def single_path_csv_files(self) -> list[str]: + """List of paths to single-path Pathfinder output CSV files. + Only populated when method is Pathfinder and save_single_paths=True""" + return self._single_path_csv_files + + @property + def single_path_json_files(self) -> list[str]: + """List of paths to single-path Pathfinder output ELBO JSON files. + Only populated when method is Pathfinder and save_single_paths=True""" + return self._single_path_json_files + def gen_file_name( self, suffix: str, *, extra: str = "", id: int | None = None ) -> str: @@ -317,3 +334,23 @@ def raise_for_timeouts(self) -> None: f"{sum(self._timeout_flags)} of {self.num_procs} " "processes timed out" ) + + def populate_pathfinder_single_path_files(self) -> None: + """Properly assigns output files for Pathfinder's + save_single_paths=True option""" + if not isinstance(self._args.method_args, PathfinderArgs): + return + if self._args.method_args.save_single_paths: + num_paths = self._args.method_args.num_paths + if num_paths > 1: + self._single_path_csv_files = [ + self.gen_file_name(".csv", extra="path", id=id) + for id in range(1, num_paths + 1) + ] + self._single_path_json_files = [ + self.gen_file_name(".json", extra="path", id=id) + for id in range(1, num_paths + 1) + ] + else: # num_paths == 1 + self._single_path_csv_files = [self.gen_file_name(".csv")] + self._single_path_json_files = [self.gen_file_name(".json")] diff --git a/test/test_pathfinder.py b/test/test_pathfinder.py index 2eb81123..7558eaaf 100644 --- a/test/test_pathfinder.py +++ b/test/test_pathfinder.py @@ -3,6 +3,7 @@ """ import contextlib +import os from io import StringIO from pathlib import Path @@ -193,3 +194,24 @@ def test_pathfinder_threads() -> None: ) pathfinder = bern_model.pathfinder(data=jdata, num_threads=4) assert pathfinder.draws().shape == (1000, 4) + + +def test_pathfinder_single_path_output() -> None: + + stan = DATAFILES_PATH / 'bernoulli.stan' + bern_model = cmdstanpy.CmdStanModel(stan_file=stan) + jdata = str(DATAFILES_PATH / 'bernoulli.data.json') + + fit = bern_model.pathfinder(data=jdata, num_paths=4, save_single_paths=True) + assert len(fit.runset.single_path_csv_files) == 4 + assert len(fit.runset.single_path_json_files) == 4 + + assert all(os.path.exists(f) for f in fit.runset.single_path_csv_files) + assert all(os.path.exists(f) for f in fit.runset.single_path_json_files) + + fit = bern_model.pathfinder(data=jdata, num_paths=1, save_single_paths=True) + assert len(fit.runset.single_path_csv_files) == 1 + assert len(fit.runset.single_path_json_files) == 1 + + assert all(os.path.exists(f) for f in fit.runset.single_path_csv_files) + assert all(os.path.exists(f) for f in fit.runset.single_path_json_files) diff --git a/test/test_runset.py b/test/test_runset.py index 616b3519..b79d66a4 100644 --- a/test/test_runset.py +++ b/test/test_runset.py @@ -3,7 +3,7 @@ import os from cmdstanpy import _TMPDIR -from cmdstanpy.cmdstan_args import CmdStanArgs, SamplerArgs +from cmdstanpy.cmdstan_args import CmdStanArgs, PathfinderArgs, SamplerArgs from cmdstanpy.stanfit import RunSet from cmdstanpy.utils import EXTENSION @@ -299,3 +299,58 @@ def test_chain_ids() -> None: assert '_11.csv' in runset._csv_files[0] assert 'id=14' in runset.cmd(3) assert '_14.csv' in runset._csv_files[3] + + +def test_output_filenames_pathfinder_single_paths() -> None: + exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION) + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + sampler_args = PathfinderArgs(num_paths=4, save_single_paths=True) + chain_ids = [1] + cmdstan_args = CmdStanArgs( + model_name='bernoulli', + model_exe=exe, + chain_ids=chain_ids, + data=jdata, + method_args=sampler_args, + ) + runset = RunSet(args=cmdstan_args) + assert len(runset.single_path_csv_files) == 4 + assert len(runset.single_path_json_files) == 4 + + assert all( + csv_file.endswith(f"_path_{id}.csv") + for id, csv_file in zip(range(1, 5), runset.single_path_csv_files) + ) + assert all( + json_file.endswith(f"_path_{id}.json") + for id, json_file in zip(range(1, 5), runset.single_path_json_files) + ) + + sampler_args = PathfinderArgs(num_paths=1, save_single_paths=True) + cmdstan_args = CmdStanArgs( + model_name='bernoulli', + model_exe=exe, + chain_ids=chain_ids, + data=jdata, + method_args=sampler_args, + ) + runset = RunSet(args=cmdstan_args) + + assert len(runset.single_path_csv_files) == 1 + assert len(runset.single_path_json_files) == 1 + + assert runset.single_path_csv_files[0].endswith(".csv") + assert runset.single_path_json_files[0].endswith(".json") + + sampler_args = PathfinderArgs(num_paths=1, save_single_paths=False) + cmdstan_args = CmdStanArgs( + model_name='bernoulli', + model_exe=exe, + chain_ids=chain_ids, + data=jdata, + method_args=sampler_args, + ) + runset = RunSet(args=cmdstan_args) + + assert len(runset.single_path_csv_files) == 0 + assert len(runset.single_path_json_files) == 0