diff --git a/pymc/model/core.py b/pymc/model/core.py index 5ec5c0ec3..337827447 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -67,7 +67,6 @@ ) from pymc.util import ( UNSET, - VarName, WithMemoization, _UnsetType, get_transformed_name, @@ -1945,7 +1944,7 @@ def debug_parameters(rv): def to_graphviz( self, *, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, formatting: str = "plain", save: str | None = None, figsize: tuple[int, int] | None = None, @@ -2149,7 +2148,7 @@ def compile_fn( ) -def Point(*args, filter_model_vars=False, **kwargs) -> dict[VarName, np.ndarray]: +def Point(*args, filter_model_vars=False, **kwargs) -> dict[str, np.ndarray]: """Build a point. Uses same args as dict() does. diff --git a/pymc/model_graph.py b/pymc/model_graph.py index df228d638..c23db65cf 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -26,7 +26,7 @@ from pymc.model.core import modelcontext from pymc.pytensorf import _cheap_eval_mode -from pymc.util import VarName, get_default_varnames, get_var_name +from pymc.util import get_default_varnames, get_var_name __all__ = ( "ModelGraph", @@ -172,7 +172,7 @@ def default_data(var: Variable) -> GraphvizNodeKwargs: } -def get_node_type(var_name: VarName, model) -> NodeType: +def get_node_type(var_name: str, model) -> NodeType: """Return the node type of the variable in the model.""" v = model[var_name] @@ -241,7 +241,7 @@ def __init__(self, model): self._all_vars = {model[var_name] for var_name in self._all_var_names} self.var_list = self.model.named_vars.values() - def get_parent_names(self, var: Variable) -> set[VarName]: + def get_parent_names(self, var: Variable) -> set[str]: if var.owner is None: return set() @@ -260,12 +260,12 @@ def _expand(x): return x.owner.inputs return { - cast(VarName, ancestor.name) # type: ignore[union-attr] + cast(str, ancestor.name) # type: ignore[union-attr] for ancestor in walk(nodes=var.owner.inputs, expand=_expand) if ancestor in named_vars } - def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]: + def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]: if var_names is None: return self._all_var_names @@ -295,13 +295,11 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa # ordering of self._all_var_names is important return [get_var_name(var) for var in selected_ancestors] - def make_compute_graph( - self, var_names: Iterable[VarName] | None = None - ) -> dict[VarName, set[VarName]]: + def make_compute_graph(self, var_names: Iterable[str] | None = None) -> dict[str, set[str]]: """Get map of var_name -> set(input var names) for the model.""" model = self.model named_vars = self._all_vars - input_map: dict[VarName, set[VarName]] = defaultdict(set) + input_map: dict[str, set[str]] = defaultdict(set) var_names_to_plot = self.vars_to_plot(var_names) for var_name in var_names_to_plot: @@ -318,7 +316,7 @@ def make_compute_graph( for ancestor in ancestors([obs_var]): if ancestor not in named_vars: continue - obs_name = cast(VarName, ancestor.name) + obs_name = cast(str, ancestor.name) input_map[var_name].discard(obs_name) input_map[obs_name].add(var_name) @@ -326,7 +324,7 @@ def make_compute_graph( def get_plates( self, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, ) -> list[Plate]: """Rough but surprisingly accurate plate detection. @@ -336,7 +334,7 @@ def get_plates( Returns ------- dict - Maps plate labels to the set of ``VarName``s inside the plate. + Maps plate labels to the set of strings inside the plate. """ plates = defaultdict(set) @@ -388,8 +386,8 @@ def get_plates( def edges( self, - var_names: Iterable[VarName] | None = None, - ) -> list[tuple[VarName, VarName]]: + var_names: Iterable[str] | None = None, + ) -> list[tuple[str, str]]: """Get edges between the variables in the model. Parameters @@ -404,7 +402,7 @@ def edges( """ return [ - (VarName(child.replace(":", "&")), VarName(parent.replace(":", "&"))) + (str(child.replace(":", "&")), str(parent.replace(":", "&"))) for child, parents in self.make_compute_graph(var_names=var_names).items() for parent in parents ] @@ -421,7 +419,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]: def make_graph( name: str, plates: list[Plate], - edges: list[tuple[VarName, VarName]], + edges: list[tuple[str, str]], formatting: str = "plain", save=None, figsize=None, @@ -495,7 +493,7 @@ def make_graph( def make_networkx( name: str, plates: list[Plate], - edges: list[tuple[VarName, VarName]], + edges: list[tuple[str, str]], formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length, @@ -565,7 +563,7 @@ def make_networkx( def model_to_networkx( model=None, *, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, include_dim_lengths: bool = True, @@ -659,7 +657,7 @@ def model_to_networkx( def model_to_graphviz( model=None, *, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, formatting: str = "plain", save: str | None = None, figsize: tuple[int, int] | None = None, diff --git a/pymc/util.py b/pymc/util.py index 3f108b8b0..32d8d65e7 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -18,7 +18,7 @@ from collections import namedtuple from collections.abc import Sequence from copy import deepcopy -from typing import NewType, cast +from typing import cast import arviz import cloudpickle @@ -31,8 +31,6 @@ from pymc.exceptions import BlockModelAccessError -VarName = NewType("VarName", str) - class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs.""" @@ -214,9 +212,9 @@ def get_default_varnames(var_iterator, include_transformed): return [var for var in var_iterator if not is_transformed_name(get_var_name(var))] -def get_var_name(var) -> VarName: +def get_var_name(var) -> str: """Get an appropriate, plain variable name for a variable.""" - return VarName(str(getattr(var, "name", var))) + return var if isinstance(var, str) else str(var.name) def get_transformed(z):