From a2318e4f8822dffe0aa8dd960d80ad5d703f8a14 Mon Sep 17 00:00:00 2001 From: Omer Yuksel <4077805+osyuksel@users.noreply.github.com> Date: Thu, 17 Jul 2025 12:51:38 +0200 Subject: [PATCH 1/9] Replace VarName usage with str in type hints and conversion --- pymc/model/core.py | 4 ++-- pymc/model_graph.py | 32 ++++++++++++++++---------------- pymc/util.py | 4 ++-- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 66e633e15..9a58f0e53 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1939,7 +1939,7 @@ def debug_parameters(rv): def to_graphviz( self, *, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, formatting: str = "plain", save: str | None = None, figsize: tuple[int, int] | None = None, @@ -2143,7 +2143,7 @@ def compile_fn( ) -def Point(*args, filter_model_vars=False, **kwargs) -> dict[VarName, np.ndarray]: +def Point(*args, filter_model_vars=False, **kwargs) -> dict[str, np.ndarray]: """Build a point. Uses same args as dict() does. diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 50fd5227d..499d19930 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -170,7 +170,7 @@ def default_data(var: TensorVariable) -> GraphvizNodeKwargs: } -def get_node_type(var_name: VarName, model) -> NodeType: +def get_node_type(var_name: str, model) -> NodeType: """Return the node type of the variable in the model.""" v = model[var_name] @@ -239,7 +239,7 @@ def __init__(self, model): self._all_vars = {model[var_name] for var_name in self._all_var_names} self.var_list = self.model.named_vars.values() - def get_parent_names(self, var: TensorVariable) -> set[VarName]: + def get_parent_names(self, var: TensorVariable) -> set[str]: if var.owner is None: return set() @@ -258,12 +258,12 @@ def _expand(x): return x.owner.inputs return { - cast(VarName, ancestor.name) # type: ignore[union-attr] + cast(str, ancestor.name) # type: ignore[union-attr] for ancestor in walk(nodes=var.owner.inputs, expand=_expand) if ancestor in named_vars } - def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]: + def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]: if var_names is None: return self._all_var_names @@ -294,12 +294,12 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa return [get_var_name(var) for var in selected_ancestors] def make_compute_graph( - self, var_names: Iterable[VarName] | None = None - ) -> dict[VarName, set[VarName]]: + self, var_names: Iterable[str] | None = None + ) -> dict[str, set[str]]: """Get map of var_name -> set(input var names) for the model.""" model = self.model named_vars = self._all_vars - input_map: dict[VarName, set[VarName]] = defaultdict(set) + input_map: dict[str, set[str]] = defaultdict(set) var_names_to_plot = self.vars_to_plot(var_names) for var_name in var_names_to_plot: @@ -316,7 +316,7 @@ def make_compute_graph( for ancestor in ancestors([obs_var]): if ancestor not in named_vars: continue - obs_name = cast(VarName, ancestor.name) + obs_name = cast(str, ancestor.name) input_map[var_name].discard(obs_name) input_map[obs_name].add(var_name) @@ -324,7 +324,7 @@ def make_compute_graph( def get_plates( self, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, ) -> list[Plate]: """Rough but surprisingly accurate plate detection. @@ -386,8 +386,8 @@ def get_plates( def edges( self, - var_names: Iterable[VarName] | None = None, - ) -> list[tuple[VarName, VarName]]: + var_names: Iterable[str] | None = None, + ) -> list[tuple[str, str]]: """Get edges between the variables in the model. Parameters @@ -402,7 +402,7 @@ def edges( """ return [ - (VarName(child.replace(":", "&")), VarName(parent.replace(":", "&"))) + (str(child.replace(":", "&")), str(parent.replace(":", "&"))) for child, parents in self.make_compute_graph(var_names=var_names).items() for parent in parents ] @@ -419,7 +419,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]: def make_graph( name: str, plates: list[Plate], - edges: list[tuple[VarName, VarName]], + edges: list[tuple[str, str]], formatting: str = "plain", save=None, figsize=None, @@ -493,7 +493,7 @@ def make_graph( def make_networkx( name: str, plates: list[Plate], - edges: list[tuple[VarName, VarName]], + edges: list[tuple[str, str]], formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length, @@ -563,7 +563,7 @@ def make_networkx( def model_to_networkx( model=None, *, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, include_dim_lengths: bool = True, @@ -657,7 +657,7 @@ def model_to_networkx( def model_to_graphviz( model=None, *, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, formatting: str = "plain", save: str | None = None, figsize: tuple[int, int] | None = None, diff --git a/pymc/util.py b/pymc/util.py index 3f108b8b0..d5d439d68 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -214,9 +214,9 @@ def get_default_varnames(var_iterator, include_transformed): return [var for var in var_iterator if not is_transformed_name(get_var_name(var))] -def get_var_name(var) -> VarName: +def get_var_name(var) -> str: """Get an appropriate, plain variable name for a variable.""" - return VarName(str(getattr(var, "name", var))) + return str(getattr(var, "name", var)) def get_transformed(z): From a94b1e6a1c6ee66379a4e6d13fba5fcd316f81e1 Mon Sep 17 00:00:00 2001 From: Omer Yuksel <4077805+osyuksel@users.noreply.github.com> Date: Thu, 17 Jul 2025 12:59:43 +0200 Subject: [PATCH 2/9] Remove VarName type and its imports --- pymc/model/core.py | 1 - pymc/model_graph.py | 2 +- pymc/util.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 9a58f0e53..9990a4a8e 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -65,7 +65,6 @@ ) from pymc.util import ( UNSET, - VarName, WithMemoization, _UnsetType, get_transformed_name, diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 499d19930..b259959df 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -26,7 +26,7 @@ from pytensor.tensor.variable import TensorVariable from pymc.model.core import modelcontext -from pymc.util import VarName, get_default_varnames, get_var_name +from pymc.util import get_default_varnames, get_var_name __all__ = ( "ModelGraph", diff --git a/pymc/util.py b/pymc/util.py index d5d439d68..251523627 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -31,7 +31,6 @@ from pymc.exceptions import BlockModelAccessError -VarName = NewType("VarName", str) class _UnsetType: From 86977db018a2b19ef8110436c0b589193f7990aa Mon Sep 17 00:00:00 2001 From: Omer Yuksel <4077805+osyuksel@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:00:02 +0200 Subject: [PATCH 3/9] Remove VarName in docstring --- pymc/model_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index b259959df..16d232011 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -334,7 +334,7 @@ def get_plates( Returns ------- dict - Maps plate labels to the set of ``VarName``s inside the plate. + Maps plate labels to the set of strings inside the plate. """ plates = defaultdict(set) From 981e20042892a7b54c43c469f98e10cf3c10aa40 Mon Sep 17 00:00:00 2001 From: Omer Yuksel <4077805+osyuksel@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:03:42 +0200 Subject: [PATCH 4/9] Replace getattr with an explicit attribute reference Kept the default behavior. If name is None, then it will convert `var` itself to str. --- pymc/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/util.py b/pymc/util.py index 251523627..e3ee9d444 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -215,7 +215,7 @@ def get_default_varnames(var_iterator, include_transformed): def get_var_name(var) -> str: """Get an appropriate, plain variable name for a variable.""" - return str(getattr(var, "name", var)) + return str(var.name if var.name is not None else var) def get_transformed(z): From 81f1611ec54c176989aea09d789ccc0924e524c3 Mon Sep 17 00:00:00 2001 From: Omer Yuksel <4077805+osyuksel@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:09:40 +0200 Subject: [PATCH 5/9] Formatting --- pymc/model_graph.py | 4 +--- pymc/util.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 16d232011..d853dd3bb 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -293,9 +293,7 @@ def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]: # ordering of self._all_var_names is important return [get_var_name(var) for var in selected_ancestors] - def make_compute_graph( - self, var_names: Iterable[str] | None = None - ) -> dict[str, set[str]]: + def make_compute_graph(self, var_names: Iterable[str] | None = None) -> dict[str, set[str]]: """Get map of var_name -> set(input var names) for the model.""" model = self.model named_vars = self._all_vars diff --git a/pymc/util.py b/pymc/util.py index e3ee9d444..ef8a37c17 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -18,7 +18,7 @@ from collections import namedtuple from collections.abc import Sequence from copy import deepcopy -from typing import NewType, cast +from typing import cast import arviz import cloudpickle @@ -32,7 +32,6 @@ from pymc.exceptions import BlockModelAccessError - class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs.""" From cd857e24dd7900af6be317eea556122bef292706 Mon Sep 17 00:00:00 2001 From: Omer Yuksel <4077805+osyuksel@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:17:21 +0200 Subject: [PATCH 6/9] Revert "Replace getattr with an explicit attribute reference" This reverts commit 981e20042892a7b54c43c469f98e10cf3c10aa40. --- pymc/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/util.py b/pymc/util.py index ef8a37c17..22920eda5 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -214,7 +214,7 @@ def get_default_varnames(var_iterator, include_transformed): def get_var_name(var) -> str: """Get an appropriate, plain variable name for a variable.""" - return str(var.name if var.name is not None else var) + return str(getattr(var, "name", var)) def get_transformed(z): From 5148aa393cbde15d163c0ae3d6ece423116a8447 Mon Sep 17 00:00:00 2001 From: Omer Yuksel <4077805+osyuksel@users.noreply.github.com> Date: Thu, 17 Jul 2025 17:57:08 +0200 Subject: [PATCH 7/9] Implement get_var_name without getattr --- pymc/util.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pymc/util.py b/pymc/util.py index 22920eda5..eeb523652 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -214,7 +214,12 @@ def get_default_varnames(var_iterator, include_transformed): def get_var_name(var) -> str: """Get an appropriate, plain variable name for a variable.""" - return str(getattr(var, "name", var)) + if isinstance(var, str): + return var + elif var.name is not None: + return str(var.name) + else: + return str(var) def get_transformed(z): From fb011bf033d7a5187ab910fdb4dd7928f2248a82 Mon Sep 17 00:00:00 2001 From: Omer Yuksel <4077805+osyuksel@users.noreply.github.com> Date: Thu, 17 Jul 2025 18:05:44 +0200 Subject: [PATCH 8/9] Align get_var_name with the current behavior if var.name is None, current get_var_name returns "None", not str(var) --- pymc/util.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pymc/util.py b/pymc/util.py index eeb523652..32d8d65e7 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -214,12 +214,7 @@ def get_default_varnames(var_iterator, include_transformed): def get_var_name(var) -> str: """Get an appropriate, plain variable name for a variable.""" - if isinstance(var, str): - return var - elif var.name is not None: - return str(var.name) - else: - return str(var) + return var if isinstance(var, str) else str(var.name) def get_transformed(z): From 83e45679024bac74540956a3771fea2371fd8b3e Mon Sep 17 00:00:00 2001 From: Omer Yuksel <4077805+osyuksel@users.noreply.github.com> Date: Fri, 25 Jul 2025 13:30:20 +0200 Subject: [PATCH 9/9] Prepare for merge --- pymc/model_graph.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index d853dd3bb..c23db65cf 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -21,11 +21,11 @@ from typing import Any, cast from pytensor import function -from pytensor.graph.basic import ancestors, walk +from pytensor.graph.basic import Variable, ancestors, walk from pytensor.tensor.shape import Shape -from pytensor.tensor.variable import TensorVariable from pymc.model.core import modelcontext +from pymc.pytensorf import _cheap_eval_mode from pymc.util import get_default_varnames, get_var_name __all__ = ( @@ -73,7 +73,7 @@ def create_plate_label_with_dim_length( def fast_eval(var): - return function([], var, mode="FAST_COMPILE")() + return function([], var, mode=_cheap_eval_mode)() class NodeType(str, Enum): @@ -88,7 +88,7 @@ class NodeType(str, Enum): @dataclass class NodeInfo: - var: TensorVariable + var: Variable node_type: NodeType def __hash__(self): @@ -108,10 +108,10 @@ def __eq__(self, other) -> bool: GraphvizNodeKwargs = dict[str, Any] -NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs] +NodeFormatter = Callable[[Variable], GraphvizNodeKwargs] -def default_potential(var: TensorVariable) -> GraphvizNodeKwargs: +def default_potential(var: Variable) -> GraphvizNodeKwargs: """Return default data for potential in the graph.""" return { "shape": "octagon", @@ -120,17 +120,19 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs: } -def random_variable_symbol(var: TensorVariable) -> str: +def random_variable_symbol(var: Variable) -> str: """Get the symbol of the random variable.""" - symbol = var.owner.op.__class__.__name__ + op = var.owner.op - if symbol.endswith("RV"): - symbol = symbol[:-2] + if name := getattr(op, "name", None): + symbol = name[0].upper() + name[1:] + else: + symbol = op.__class__.__name__.removesuffix("RV") return symbol -def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs: +def default_free_rv(var: Variable) -> GraphvizNodeKwargs: """Return default data for free RV in the graph.""" symbol = random_variable_symbol(var) @@ -141,7 +143,7 @@ def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs: } -def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs: +def default_observed_rv(var: Variable) -> GraphvizNodeKwargs: """Return default data for observed RV in the graph.""" symbol = random_variable_symbol(var) @@ -152,7 +154,7 @@ def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs: } -def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs: +def default_deterministic(var: Variable) -> GraphvizNodeKwargs: """Return default data for the deterministic in the graph.""" return { "shape": "box", @@ -161,7 +163,7 @@ def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs: } -def default_data(var: TensorVariable) -> GraphvizNodeKwargs: +def default_data(var: Variable) -> GraphvizNodeKwargs: """Return default data for the data in the graph.""" return { "shape": "box", @@ -239,7 +241,7 @@ def __init__(self, model): self._all_vars = {model[var_name] for var_name in self._all_var_names} self.var_list = self.model.named_vars.values() - def get_parent_names(self, var: TensorVariable) -> set[str]: + def get_parent_names(self, var: Variable) -> set[str]: if var.owner is None: return set() @@ -343,7 +345,7 @@ def get_plates( dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items() } var_shapes: dict[str, tuple[int, ...]] = { - var_name: tuple(fast_eval(self.model[var_name].shape)) + var_name: tuple(map(int, fast_eval(self.model[var_name].shape))) for var_name in self.vars_to_plot(var_names) }