Skip to content
Draft
583 changes: 583 additions & 0 deletions notebooks/structural_components_dataclass.ipynb

Large diffs are not rendered by default.

261 changes: 261 additions & 0 deletions pymc_extras/statespace/core/properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
from collections.abc import Iterator
from dataclasses import dataclass, fields
from typing import Generic, Self, TypeVar

from pymc_extras.statespace.core import PyMCStateSpace
from pymc_extras.statespace.utils.constants import (
ALL_STATE_AUX_DIM,
ALL_STATE_DIM,
OBS_STATE_AUX_DIM,
OBS_STATE_DIM,
SHOCK_AUX_DIM,
SHOCK_DIM,
)


@dataclass(frozen=True)
class Property:
def __str__(self) -> str:
return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self))


T = TypeVar("T", bound=Property)


@dataclass(frozen=True)
class Info(Generic[T]):
items: tuple[T, ...]
key_field: str = "name"
_index: dict[str, T] | None = None

def __post_init__(self):
index = {}
missing_attr = []
for item in self.items:
if not hasattr(item, self.key_field):
missing_attr.append(item)
continue
key = getattr(item, self.key_field)
if key in index:
raise ValueError(f"Duplicate {self.key_field} '{key}' detected.")
index[key] = item
if missing_attr:
raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}")
object.__setattr__(self, "_index", index)

def _key(self, item: T) -> str:
return getattr(item, self.key_field)

def get(self, key: str, default=None) -> T | None:
return self._index.get(key, default)

def __getitem__(self, key: str) -> T:
try:
return self._index[key]
except KeyError as e:
available = ", ".join(self._index.keys())
raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e

def __contains__(self, key: object) -> bool:
return key in self._index

def __iter__(self) -> Iterator[str]:
return iter(self.items)

def __len__(self) -> int:
return len(self.items)

def __str__(self) -> str:
return f"{self.key_field}s: {list(self._index.keys())}"

@property
def names(self) -> tuple[str, ...]:
return tuple(self._index.keys())


@dataclass(frozen=True)
class Parameter(Property):
name: str
shape: tuple[int, ...]
dims: tuple[str, ...]
constraints: str | None = None


@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
def __init__(self, parameters: list[Parameter]):
super().__init__(items=tuple(parameters), key_field="name")

def add(self, parameter: Parameter) -> "ParameterInfo":
# return a new ParameterInfo with parameter appended
return ParameterInfo(parameters=[*list(self.items), parameter])

def merge(self, other: "ParameterInfo") -> "ParameterInfo":
"""Combine parameters from two ParameterInfo objects."""
if not isinstance(other, ParameterInfo):
raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo")

overlapping = set(self.names) & set(other.names)
if overlapping:
raise ValueError(f"Duplicate parameter names found: {overlapping}")

return ParameterInfo(parameters=list(self.items) + list(other.items))


@dataclass(frozen=True)
class Data(Property):
name: str
shape: tuple[int, ...]
dims: tuple[str, ...]
is_exogenous: bool


@dataclass(frozen=True)
class DataInfo(Info[Data]):
def __init__(self, data: list[Data]):
super().__init__(items=tuple(data), key_field="name")

@property
def needs_exogenous_data(self) -> bool:
return any(d.is_exogenous for d in self.items)

def __str__(self) -> str:
return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}"

def add(self, data: Data) -> "DataInfo":
# return a new DataInfo with data appended
return DataInfo(data=[*list(self.items), data])

def merge(self, other: "DataInfo") -> "DataInfo":
"""Combine data from two DataInfo objects."""
if not isinstance(other, DataInfo):
raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo")

overlapping = set(self.names) & set(other.names)
if overlapping:
raise ValueError(f"Duplicate data names found: {overlapping}")

return DataInfo(data=list(self.items) + list(other.items))


@dataclass(frozen=True)
class Coord(Property):
dimension: str
labels: tuple[str, ...]


@dataclass(frozen=True)
class CoordInfo(Info[Coord]):
def __init__(self, coords: list[Coord]):
super().__init__(items=tuple(coords), key_field="dimension")

def __str__(self) -> str:
base = "coordinates:"
for coord in self.items:
coord_str = str(coord)
indented = "\n".join(" " + line for line in coord_str.splitlines())
base += "\n" + indented + "\n"
return base

@classmethod
def default_coords_from_model(
cls, model: PyMCStateSpace
) -> (
Self
): # TODO: Need to figure out how to include Component type was causing circular import issues
states = tuple(model.state_names)
obs_states = tuple(model.observed_state_names)
shocks = tuple(model.shock_names)

dim_to_labels = (
(ALL_STATE_DIM, states),
(ALL_STATE_AUX_DIM, states),
(OBS_STATE_DIM, obs_states),
(OBS_STATE_AUX_DIM, obs_states),
(SHOCK_DIM, shocks),
(SHOCK_AUX_DIM, shocks),
)

coords = [Coord(dimension=dim, labels=labels) for dim, labels in dim_to_labels]
return cls(coords)

def to_dict(self):
return {coord.dimension: coord.labels for coord in self.items if len(coord.labels) > 0}

def add(self, coord: Coord) -> "CoordInfo":
# return a new CoordInfo with data appended
return CoordInfo(coords=[*list(self.items), coord])

def merge(self, other: "CoordInfo") -> "CoordInfo":
"""Combine data from two CoordInfo objects."""
if not isinstance(other, CoordInfo):
raise TypeError(f"Cannot merge {type(other).__name__} with CoordInfo")

overlapping = set(self.names) & set(other.names)
if overlapping:
raise ValueError(f"Duplicate coord names found: {overlapping}")

return CoordInfo(coords=list(self.items) + list(other.items))


@dataclass(frozen=True)
class State(Property):
name: str
observed: bool
shared: bool


@dataclass(frozen=True)
class StateInfo(Info[State]):
def __init__(self, states: list[State]):
super().__init__(items=tuple(states), key_field="name")

def __str__(self) -> str:
return (
f"states: {[s.name for s in self.items]}\nobserved: {[s.observed for s in self.items]}"
)

@property
def observed_states(self) -> tuple[State, ...]:
return tuple(s for s in self.items if s.observed)

def add(self, state: State) -> "StateInfo":
# return a new StateInfo with state appended
return StateInfo(states=[*list(self.items), state])

def merge(self, other: "StateInfo") -> "StateInfo":
"""Combine states from two StateInfo objects."""
if not isinstance(other, StateInfo):
raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo")

overlapping = set(self.names) & set(other.names)
if overlapping:
raise ValueError(f"Duplicate state names found: {overlapping}")

return StateInfo(states=list(self.items) + list(other.items))


@dataclass(frozen=True)
class Shock(Property):
name: str


@dataclass(frozen=True)
class ShockInfo(Info[Shock]):
def __init__(self, shocks: list[Shock]):
super().__init__(items=tuple(shocks), key_field="name")

def add(self, shock: Shock) -> "ShockInfo":
# return a new ShockInfo with shock appended
return ShockInfo(shocks=[*list(self.items), shock])

def merge(self, other: "ShockInfo") -> "ShockInfo":
"""Combine shocks from two ShockInfo objects."""
if not isinstance(other, ShockInfo):
raise TypeError(f"Cannot merge {type(other).__name__} with ShockInfo")

overlapping = set(self.names) & set(other.names)
if overlapping:
raise ValueError(f"Duplicate shock names found: {overlapping}")

return ShockInfo(shocks=list(self.items) + list(other.items))
4 changes: 4 additions & 0 deletions pymc_extras/statespace/models/structural/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from pymc_extras.statespace.models.structural.components.level_trend import LevelTrendComponent
from pymc_extras.statespace.models.structural.components.measurement_error import MeasurementError
from pymc_extras.statespace.models.structural.components.regression import RegressionComponent
from pymc_extras.statespace.models.structural.components.regression_dataclass import (
RegressionComponent as RegressionComponentDataClass,
)
from pymc_extras.statespace.models.structural.components.seasonality import (
FrequencySeasonality,
TimeSeasonality,
Expand All @@ -17,5 +20,6 @@
"LevelTrendComponent",
"MeasurementError",
"RegressionComponent",
"RegressionComponentDataClass",
"TimeSeasonality",
]
Loading
Loading