Skip to content

Commit 8e0488d

Browse files
committed
.WIP new explicit RNG API
1 parent 619fe66 commit 8e0488d

File tree

4 files changed

+89
-8
lines changed

4 files changed

+89
-8
lines changed

pytensor/tensor/random/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from pytensor.tensor.random.basic import *
55
from pytensor.tensor.random.op import default_rng
66
from pytensor.tensor.random.utils import RandomStream
7+
from pytensor.tensor.random.variable import rng

pytensor/tensor/random/op.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,16 @@ def infer_shape(self, fgraph, node, input_shapes):
314314

315315
return [None, list(shape)]
316316

317-
def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
317+
def __call__(
318+
self,
319+
*args,
320+
size=None,
321+
name=None,
322+
rng=None,
323+
dtype=None,
324+
return_next_rng: bool = False,
325+
**kwargs,
326+
):
318327
if dtype is None:
319328
dtype = self.dtype
320329
if dtype == "floatX":
@@ -332,15 +341,26 @@ def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
332341
props["dtype"] = dtype
333342
new_op = type(self)(**props)
334343
return new_op.__call__(
335-
*args, size=size, name=name, rng=rng, dtype=dtype, **kwargs
344+
*args,
345+
size=size,
346+
name=name,
347+
rng=rng,
348+
dtype=dtype,
349+
return_next_rng=return_next_rng,
350+
**kwargs,
336351
)
337352

338-
res = super().__call__(rng, size, *args, **kwargs)
339-
353+
node = self.make_node(rng, size, *args)
354+
outputs = node.outputs
340355
if name is not None:
341-
res.name = name
342-
343-
return res
356+
outputs[self.default_output].name = name
357+
if return_next_rng:
358+
return outputs
359+
else:
360+
out = outputs[self.default_output]
361+
if kwargs.get("return_list", False):
362+
return [out]
363+
return out
344364

345365
def make_node(self, rng, size, *dist_params):
346366
"""Create a random variable node.

pytensor/tensor/random/var.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared_constructor
66
from pytensor.tensor.random.type import random_generator_type
7+
from pytensor.tensor.random.variable import RandomGeneratorVariable
78

89

9-
class RandomGeneratorSharedVariable(SharedVariable):
10+
class RandomGeneratorSharedVariable(SharedVariable, RandomGeneratorVariable):
1011
def __str__(self):
1112
return self.name or f"RNG({self.container!r})"
1213

pytensor/tensor/random/variable.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import warnings
2+
from functools import wraps
3+
from typing import TypeAlias
4+
5+
from pytensor.graph.basic import OptionalApplyType, Variable
6+
from pytensor.tensor.random.basic import normal
7+
from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type
8+
from pytensor.tensor.variable import TensorVariable
9+
10+
11+
RNG_AND_DRAW: TypeAlias = tuple["RandomGeneratorVariable", TensorVariable]
12+
13+
14+
def warn_reuse(func):
15+
@wraps(func)
16+
def wrapper(self, *args, **kwargs):
17+
if hasattr(self.tag, "used"):
18+
warnings.warn(
19+
f"RandomGeneratorVariable {self} has already been used. "
20+
"You probably want to use the RandomGeneratorVariable that was returned then.",
21+
UserWarning,
22+
)
23+
self.tag.used = True
24+
return func(self, *args, **kwargs)
25+
26+
return wrapper
27+
28+
29+
class _random_generator_py_operators:
30+
# These can't work because Python requires native output types
31+
def __bool__(self):
32+
return True
33+
34+
@warn_reuse
35+
def normal(self, loc=0, scale=1, size=None) -> RNG_AND_DRAW:
36+
return normal(loc, scale, size=size, rng=self, return_next_rng=True)
37+
38+
39+
class RandomGeneratorVariable(
40+
_random_generator_py_operators,
41+
Variable[RandomGeneratorType, OptionalApplyType],
42+
):
43+
"""The Variable type used for random number generator states."""
44+
45+
46+
RandomGeneratorType.variable_type = RandomGeneratorVariable
47+
48+
49+
def rng(name=None) -> RandomGeneratorVariable:
50+
"""Create a new default random number generator variable.
51+
52+
Returns
53+
-------
54+
RandomGeneratorVariable
55+
A new random number generator variable initialized with the default
56+
numpy random generator.
57+
"""
58+
59+
return random_generator_type(name=name)

0 commit comments

Comments
 (0)