11from typing import TypeVar
22
33import numpy as np
4+ from numpy .random import Generator
45
56import pytensor
67from pytensor .graph .type import Type
78
89
9- T = TypeVar ("T" , np . random . RandomState , np . random . Generator )
10+ T = TypeVar ("T" )
1011
1112
1213gen_states_keys = {
2425
2526
2627class RandomType (Type [T ]):
27- r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
28-
29- @staticmethod
30- def may_share_memory (a : T , b : T ):
31- return a ._bit_generator is b ._bit_generator # type: ignore[attr-defined]
28+ r"""A Type wrapper for `numpy.random.Generator."""
3229
3330
34- class RandomGeneratorType (RandomType [np . random . Generator ]):
31+ class RandomGeneratorType (RandomType [Generator ]):
3532 r"""A Type wrapper for `numpy.random.Generator`.
3633
3734 The reason this exists (and `Generic` doesn't suffice) is that
@@ -47,6 +44,10 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
4744 def __repr__ (self ):
4845 return "RandomGeneratorType"
4946
47+ @staticmethod
48+ def may_share_memory (a : Generator , b : Generator ):
49+ return a ._bit_generator is b ._bit_generator # type: ignore[attr-defined]
50+
5051 def filter (self , data , strict = False , allow_downcast = None ):
5152 """
5253 XXX: This doesn't convert `data` to the same type of underlying RNG type
@@ -58,7 +59,7 @@ def filter(self, data, strict=False, allow_downcast=None):
5859 `Type.filter`, we need to have it here to avoid surprising circular
5960 dependencies in sub-classes.
6061 """
61- if isinstance (data , np . random . Generator ):
62+ if isinstance (data , Generator ):
6263 return data
6364
6465 if not strict and isinstance (data , dict ):
0 commit comments