Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
cdb9d9d
Add save_metric=1 to adapt sampler args
amas0 Nov 18, 2025
62fefae
Add metric files to RunSet
amas0 Nov 18, 2025
c7fef6b
Add metric files runset tests
amas0 Nov 18, 2025
838b78c
Add initial metric parsing logic
amas0 Nov 18, 2025
f7a9ae9
Fix string name collisions cmdstan args test
amas0 Nov 18, 2025
c47791b
Lazily load metric info from file
amas0 Dec 2, 2025
4823796
Fix field_validator to be classmethod
amas0 Dec 2, 2025
d23a9d6
Properly handle one process per chain metric output
amas0 Dec 2, 2025
66512d5
Remove _step_size initialization from assemble_draws
amas0 Dec 2, 2025
4a40ef7
Allow stepsize to be nan
amas0 Dec 2, 2025
cb493b5
Short-circuit metric properties to None when fixed param
amas0 Dec 2, 2025
84ae036
Only enable save_metric=1 when adapt engaged
amas0 Dec 2, 2025
063dfb9
Add metric file output for testing CmdStanMCMC construction from outp…
amas0 Dec 2, 2025
2d076af
Merge branch 'stan-dev:develop' into enable-save-metric
amas0 Dec 2, 2025
b138308
Add metric files for runset-big
amas0 Dec 2, 2025
29ddbfd
Fix metric output filenames test to reflect one proc per chain
amas0 Dec 2, 2025
4502aef
Remove functionality and tests for parsing metric info from CSV
amas0 Dec 2, 2025
d80461d
Add pydantic as a dependency
amas0 Dec 2, 2025
0f5dab8
Add tests of MetricInfo validators
amas0 Dec 2, 2025
82a85d2
Remove unused chain_id from MetricInfo
amas0 Dec 12, 2025
1635915
Remove stringified type hints
amas0 Dec 12, 2025
1f0f0a9
Clarify arbitrary_types_allowed usage
amas0 Dec 12, 2025
5fcee74
Remove _metric_info_parsed
amas0 Dec 15, 2025
2497754
Add tests for invalid metric type
amas0 Dec 16, 2025
37080af
Convert MetricInfo.inv_metric to native Python types
amas0 Dec 16, 2025
b5926c9
Fixup mypy issue in tests
amas0 Dec 17, 2025
915eef5
Convert to list for test clarity
amas0 Dec 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def compose(self, idx: int, cmd: list[str]) -> list[str]:
cmd.append(f'window={self.adapt_metric_window}')
if self.adapt_step_size is not None:
cmd.append('term_buffer={}'.format(self.adapt_step_size))
if self.adapt_engaged:
cmd.append('save_metric=1')
# End adapt subsection

if self.num_chains > 1:
cmd.append('num_chains={}'.format(self.num_chains))

Expand Down
6 changes: 3 additions & 3 deletions cmdstanpy/stanfit/gq.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,18 +423,18 @@ def draws_xr(

@overload
def draws_xr(
self: "CmdStanGQ[CmdStanMCMC]",
self: CmdStanGQ[CmdStanMCMC],
vars: str | list[str] | None = None,
inc_warmup: bool = False,
inc_sample: bool = False,
) -> "xr.Dataset": ...
) -> xr.Dataset: ...

def draws_xr(
self,
vars: str | list[str] | None = None,
inc_warmup: bool = False,
inc_sample: bool = False,
) -> "xr.Dataset":
) -> xr.Dataset:
"""
Returns the generated quantities draws as a xarray Dataset.

Expand Down
4 changes: 3 additions & 1 deletion cmdstanpy/stanfit/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Container for the result of running a laplace approximation.
"""

from __future__ import annotations

from typing import Any, Hashable, MutableMapping

import numpy as np
Expand Down Expand Up @@ -197,7 +199,7 @@ def draws_pd(
def draws_xr(
self,
vars: str | list[str] | None = None,
) -> "xr.Dataset":
) -> xr.Dataset:
"""
Returns the sampler draws as a xarray Dataset.

Expand Down
73 changes: 50 additions & 23 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Container for the result of running the sample (MCMC) method
"""

from __future__ import annotations

import math
import os
from io import StringIO
Expand Down Expand Up @@ -31,7 +33,7 @@
stancsv,
)

from .metadata import InferenceMetadata
from .metadata import InferenceMetadata, MetricInfo
from .runset import RunSet


Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(
# info from CSV values, instantiated lazily
self._draws: np.ndarray = np.array(())
# only valid when not is_fixed_param
self._metric_type: str | None = None
self._metric: np.ndarray = np.array(())
self._step_size: np.ndarray = np.array(())
self._divergences: np.ndarray = np.zeros(self.runset.chains, dtype=int)
Expand All @@ -92,6 +95,8 @@ def __init__(
# info from CSV header and initial and final comment blocks
config = self._validate_csv_files()
self._metadata: InferenceMetadata = InferenceMetadata(config)
self._chain_metric_info: list[MetricInfo] = []

if not self._is_fixed_param:
self._check_sampler_diagnostics()

Expand Down Expand Up @@ -216,11 +221,13 @@ def metric_type(self) -> str | None:
to CmdStan arg 'metric'.
When sampler algorithm 'fixed_param' is specified, metric_type is None.
"""
return (
self._metadata.cmdstan_config['metric']
if not self._is_fixed_param
else None
)
if self._is_fixed_param:
return None

if self._metric_type is None:
self._parse_metric_info()

return self._metric_type

@property
def inv_metric(self) -> np.ndarray | None:
Expand All @@ -230,10 +237,15 @@ def inv_metric(self) -> np.ndarray | None:
a ``nchains x nparams x nparams`` array when metric_type is 'dense_e',
or ``None`` when metric_type is 'unit_e' or algorithm is 'fixed_param'.
"""
if self._is_fixed_param or self.metric_type == 'unit_e':
if self._is_fixed_param:
return None

if self._metric_type is None:
self._parse_metric_info()

if self.metric_type == 'unit_e':
return None

self._assemble_draws()
return self._metric

@property
Expand All @@ -242,8 +254,13 @@ def step_size(self) -> np.ndarray | None:
Step size used by sampler for each chain.
When sampler algorithm 'fixed_param' is specified, step size is None.
"""
self._assemble_draws()
return self._step_size if not self._is_fixed_param else None
if self._is_fixed_param:
return None

if self._metric_type is None:
self._parse_metric_info()

return self._step_size

@property
def thin(self) -> int:
Expand Down Expand Up @@ -382,6 +399,27 @@ def _validate_csv_files(self) -> dict[str, Any]:
self._max_treedepths[i] = drest['ct_max_treedepth']
return dzero

def _parse_metric_info(self) -> None:
"""Extracts metric type, inv_metric, and step size information from the
parsed metric JSONs."""
self._chain_metric_info = []
for mf in self.runset.metric_files:
with open(mf) as f:
self._chain_metric_info.append(
MetricInfo.model_validate_json(f.read())
)

metric_types = {cmi.metric_type for cmi in self._chain_metric_info}
if len(metric_types) != 1:
raise ValueError("Inconsistent metric types found across chains")
self._metric_type = self._chain_metric_info[0].metric_type
self._metric = np.asarray(
[cmi.inv_metric for cmi in self._chain_metric_info]
)
self._step_size = np.asarray(
[cmi.stepsize for cmi in self._chain_metric_info]
)

def _check_sampler_diagnostics(self) -> None:
"""
Warn if any iterations ended in divergences or hit maxtreedepth.
Expand Down Expand Up @@ -424,13 +462,11 @@ def _assemble_draws(self) -> None:
dtype=np.float64,
order='F',
)
self._step_size = np.empty(self.chains, dtype=np.float64)

mass_matrix_per_chain = []
for chain in range(self.chains):
try:
(
comments,
_,
header,
draws,
) = stancsv.parse_comments_header_and_draws(
Expand All @@ -443,20 +479,11 @@ def _assemble_draws(self) -> None:
draws_np = np.empty((0, n_cols))

self._draws[:, chain, :] = draws_np
if not self._is_fixed_param:
(
self._step_size[chain],
mass_matrix,
) = stancsv.parse_hmc_adaptation_lines(comments)
mass_matrix_per_chain.append(mass_matrix)
except Exception as exc:
raise ValueError(
f"Parsing output from {self.runset.csv_files[chain]} failed"
) from exc

if all(mm is not None for mm in mass_matrix_per_chain):
self._metric = np.array(mass_matrix_per_chain)

assert self._draws is not None

def summary(
Expand Down Expand Up @@ -652,7 +679,7 @@ def draws_pd(

def draws_xr(
self, vars: str | list[str] | None = None, inc_warmup: bool = False
) -> "xr.Dataset":
) -> xr.Dataset:
"""
Returns the sampler draws as a xarray Dataset.

Expand Down
50 changes: 48 additions & 2 deletions cmdstanpy/stanfit/metadata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Container for metadata parsed from the output of a CmdStan run"""

from __future__ import annotations

import copy
import math
import os
from typing import Any, Iterator
from typing import Any, Iterator, Literal

import stanio
from pydantic import BaseModel, field_validator, model_validator

from cmdstanpy.utils import stancsv

Expand Down Expand Up @@ -34,7 +38,7 @@ def __init__(
@classmethod
def from_csv(
cls, stan_csv: str | os.PathLike | Iterator[bytes]
) -> 'InferenceMetadata':
) -> InferenceMetadata:
try:
comments, header, _ = stancsv.parse_comments_header_and_draws(
stan_csv
Expand Down Expand Up @@ -79,3 +83,45 @@ def stan_vars(self) -> dict[str, stanio.Variable]:
These are the user-defined variables in the Stan program.
"""
return self._stan_vars


class MetricInfo(BaseModel):
"""Structured representation of HMC-NUTS metric information,
as output by CmdStan"""

stepsize: float
metric_type: Literal["diag_e", "dense_e", "unit_e"]
inv_metric: list[float] | list[list[float]]

@field_validator("stepsize")
@classmethod
def validate_stepsize(cls, v: float) -> float:
if not math.isnan(v) and v <= 0:
raise ValueError("stepsize must be greater than 0 or NaN")
return v

@model_validator(mode="after")
def validate_inv_metric_shape(self) -> MetricInfo:
if not self.inv_metric: # Empty inv_metric, e.g. from no parameters
return self

is_1d = isinstance(self.inv_metric[0], float)

if self.metric_type in ("diag_e", "unit_e") and not is_1d:
raise ValueError(
"inv_metric must be 1D for diag_e and unit_e metric type"
)
if self.metric_type == "dense_e":
if is_1d:
raise ValueError("Dense inv_metric must be 2D")

if any(not row for row in self.inv_metric):
raise ValueError("Dense inv_metric cannot contain empty rows")

n_rows = len(self.inv_metric)
if not all(
len(row) == n_rows for row in self.inv_metric # type: ignore
):
raise ValueError("Dense inv_metric must be square")

return self
24 changes: 24 additions & 0 deletions cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
self._stdout_files, self._profile_files = [], []
self._csv_files, self._diagnostic_files = [], []
self._config_files = []
self._metric_files = []

# per-process output files
if one_process_per_chain and chains > 1:
Expand Down Expand Up @@ -87,6 +88,10 @@ def __init__(
# per-chain output files
if chains == 1:
self._csv_files = [self.gen_file_name(".csv")]
if args.method == Method.SAMPLE:
self._metric_files = [
self.gen_file_name(".json", extra="metric")
]
if args.save_latent_dynamics:
self._diagnostic_files = [
self.gen_file_name(".csv", extra="diagnostic")
Expand All @@ -95,6 +100,20 @@ def __init__(
self._csv_files = [
self.gen_file_name(".csv", id=id) for id in self._chain_ids
]
if args.method == Method.SAMPLE:
if one_process_per_chain:
self._metric_files = [
os.path.join(
self._outdir,
f"{self._base_outfile}_{id}_metric.json",
)
for id in self._chain_ids
]
else:
self._metric_files = [
self.gen_file_name(".json", extra="metric", id=id)
for id in self._chain_ids
]
if args.save_latent_dynamics:
self._diagnostic_files = [
self.gen_file_name(".csv", extra="diagnostic", id=id)
Expand Down Expand Up @@ -222,6 +241,11 @@ def profile_files(self) -> list[str]:
"""List of paths to CmdStan profiler files."""
return self._profile_files

@property
def metric_files(self) -> list[str]:
"""List of paths to CmdStan NUTS-HMC sampler metric files."""
return self._metric_files

def gen_file_name(
self, suffix: str, *, extra: str = "", id: int | None = None
) -> str:
Expand Down
4 changes: 3 additions & 1 deletion cmdstanpy/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
CmdStanPy logging
"""

from __future__ import annotations

import functools
import logging
import types
Expand Down Expand Up @@ -39,7 +41,7 @@ def __init__(self, disable: bool) -> None:
def __repr__(self) -> str:
return ""

def __enter__(self) -> "ToggleLogging":
def __enter__(self) -> ToggleLogging:
self.logger.disabled = self.disable
return self

Expand Down
Loading