|
| 1 | +import warnings |
| 2 | +from functools import wraps |
| 3 | +from typing import TypeAlias |
| 4 | + |
| 5 | +from pytensor.graph.basic import OptionalApplyType, Variable |
| 6 | +from pytensor.tensor.random.basic import normal |
| 7 | +from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type |
| 8 | +from pytensor.tensor.variable import TensorVariable |
| 9 | + |
| 10 | + |
| 11 | +RNG_AND_DRAW: TypeAlias = tuple["RandomGeneratorVariable", TensorVariable] |
| 12 | + |
| 13 | + |
| 14 | +def warn_reuse(func): |
| 15 | + @wraps(func) |
| 16 | + def wrapper(self, *args, **kwargs): |
| 17 | + if hasattr(self.tag, "used"): |
| 18 | + warnings.warn( |
| 19 | + f"RandomGeneratorVariable {self} has already been used. " |
| 20 | + "You probably want to use the RandomGeneratorVariable that was returned then.", |
| 21 | + UserWarning, |
| 22 | + ) |
| 23 | + self.tag.used = True |
| 24 | + return func(self, *args, **kwargs) |
| 25 | + |
| 26 | + return wrapper |
| 27 | + |
| 28 | + |
| 29 | +class _random_generator_py_operators: |
| 30 | + # These can't work because Python requires native output types |
| 31 | + def __bool__(self): |
| 32 | + return True |
| 33 | + |
| 34 | + @warn_reuse |
| 35 | + def normal(self, loc=0, scale=1, size=None) -> RNG_AND_DRAW: |
| 36 | + return normal(loc, scale, size=size, rng=self, return_next_rng=True) |
| 37 | + |
| 38 | + |
| 39 | +class RandomGeneratorVariable( |
| 40 | + _random_generator_py_operators, |
| 41 | + Variable[RandomGeneratorType, OptionalApplyType], |
| 42 | +): |
| 43 | + """The Variable type used for random number generator states.""" |
| 44 | + |
| 45 | + |
| 46 | +RandomGeneratorType.variable_type = RandomGeneratorVariable |
| 47 | + |
| 48 | + |
| 49 | +def rng(name=None) -> RandomGeneratorVariable: |
| 50 | + """Create a new default random number generator variable. |
| 51 | +
|
| 52 | + Returns |
| 53 | + ------- |
| 54 | + RandomGeneratorVariable |
| 55 | + A new random number generator variable initialized with the default |
| 56 | + numpy random generator. |
| 57 | + """ |
| 58 | + |
| 59 | + return random_generator_type(name=name) |
0 commit comments