Skip to content

Replace VarName with built-in str #7855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
5 changes: 2 additions & 3 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
)
from pymc.util import (
UNSET,
VarName,
WithMemoization,
_UnsetType,
get_transformed_name,
Expand Down Expand Up @@ -1945,7 +1944,7 @@ def debug_parameters(rv):
def to_graphviz(
self,
*,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
formatting: str = "plain",
save: str | None = None,
figsize: tuple[int, int] | None = None,
Expand Down Expand Up @@ -2149,7 +2148,7 @@ def compile_fn(
)


def Point(*args, filter_model_vars=False, **kwargs) -> dict[VarName, np.ndarray]:
def Point(*args, filter_model_vars=False, **kwargs) -> dict[str, np.ndarray]:
"""Build a point.

Uses same args as dict() does.
Expand Down
36 changes: 17 additions & 19 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from pymc.model.core import modelcontext
from pymc.pytensorf import _cheap_eval_mode
from pymc.util import VarName, get_default_varnames, get_var_name
from pymc.util import get_default_varnames, get_var_name

__all__ = (
"ModelGraph",
Expand Down Expand Up @@ -172,7 +172,7 @@ def default_data(var: Variable) -> GraphvizNodeKwargs:
}


def get_node_type(var_name: VarName, model) -> NodeType:
def get_node_type(var_name: str, model) -> NodeType:
"""Return the node type of the variable in the model."""
v = model[var_name]

Expand Down Expand Up @@ -241,7 +241,7 @@ def __init__(self, model):
self._all_vars = {model[var_name] for var_name in self._all_var_names}
self.var_list = self.model.named_vars.values()

def get_parent_names(self, var: Variable) -> set[VarName]:
def get_parent_names(self, var: Variable) -> set[str]:
if var.owner is None:
return set()

Expand All @@ -260,12 +260,12 @@ def _expand(x):
return x.owner.inputs

return {
cast(VarName, ancestor.name) # type: ignore[union-attr]
cast(str, ancestor.name) # type: ignore[union-attr]
for ancestor in walk(nodes=var.owner.inputs, expand=_expand)
if ancestor in named_vars
}

def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]:
if var_names is None:
return self._all_var_names

Expand Down Expand Up @@ -295,13 +295,11 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa
# ordering of self._all_var_names is important
return [get_var_name(var) for var in selected_ancestors]

def make_compute_graph(
self, var_names: Iterable[VarName] | None = None
) -> dict[VarName, set[VarName]]:
def make_compute_graph(self, var_names: Iterable[str] | None = None) -> dict[str, set[str]]:
"""Get map of var_name -> set(input var names) for the model."""
model = self.model
named_vars = self._all_vars
input_map: dict[VarName, set[VarName]] = defaultdict(set)
input_map: dict[str, set[str]] = defaultdict(set)

var_names_to_plot = self.vars_to_plot(var_names)
for var_name in var_names_to_plot:
Expand All @@ -318,15 +316,15 @@ def make_compute_graph(
for ancestor in ancestors([obs_var]):
if ancestor not in named_vars:
continue
obs_name = cast(VarName, ancestor.name)
obs_name = cast(str, ancestor.name)
input_map[var_name].discard(obs_name)
input_map[obs_name].add(var_name)

return input_map

def get_plates(
self,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
) -> list[Plate]:
"""Rough but surprisingly accurate plate detection.

Expand All @@ -336,7 +334,7 @@ def get_plates(
Returns
-------
dict
Maps plate labels to the set of ``VarName``s inside the plate.
Maps plate labels to the set of strings inside the plate.
"""
plates = defaultdict(set)

Expand Down Expand Up @@ -388,8 +386,8 @@ def get_plates(

def edges(
self,
var_names: Iterable[VarName] | None = None,
) -> list[tuple[VarName, VarName]]:
var_names: Iterable[str] | None = None,
) -> list[tuple[str, str]]:
"""Get edges between the variables in the model.

Parameters
Expand All @@ -404,7 +402,7 @@ def edges(

"""
return [
(VarName(child.replace(":", "&")), VarName(parent.replace(":", "&")))
(str(child.replace(":", "&")), str(parent.replace(":", "&")))
for child, parents in self.make_compute_graph(var_names=var_names).items()
for parent in parents
]
Expand All @@ -421,7 +419,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
def make_graph(
name: str,
plates: list[Plate],
edges: list[tuple[VarName, VarName]],
edges: list[tuple[str, str]],
formatting: str = "plain",
save=None,
figsize=None,
Expand Down Expand Up @@ -495,7 +493,7 @@ def make_graph(
def make_networkx(
name: str,
plates: list[Plate],
edges: list[tuple[VarName, VarName]],
edges: list[tuple[str, str]],
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
Expand Down Expand Up @@ -565,7 +563,7 @@ def make_networkx(
def model_to_networkx(
model=None,
*,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
include_dim_lengths: bool = True,
Expand Down Expand Up @@ -659,7 +657,7 @@ def model_to_networkx(
def model_to_graphviz(
model=None,
*,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
formatting: str = "plain",
save: str | None = None,
figsize: tuple[int, int] | None = None,
Expand Down
8 changes: 3 additions & 5 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections import namedtuple
from collections.abc import Sequence
from copy import deepcopy
from typing import NewType, cast
from typing import cast

import arviz
import cloudpickle
Expand All @@ -31,8 +31,6 @@

from pymc.exceptions import BlockModelAccessError

VarName = NewType("VarName", str)


class _UnsetType:
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
Expand Down Expand Up @@ -214,9 +212,9 @@ def get_default_varnames(var_iterator, include_transformed):
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]


def get_var_name(var) -> VarName:
def get_var_name(var) -> str:
"""Get an appropriate, plain variable name for a variable."""
return VarName(str(getattr(var, "name", var)))
return var if isinstance(var, str) else str(var.name)


def get_transformed(z):
Expand Down