Skip to content

Commit b6aa902

Browse files
Merge branch 'aesara-devs:main' into ifelse-mixtures
2 parents 8c9c0f3 + 0959489 commit b6aa902

File tree

12 files changed

+64
-59
lines changed

12 files changed

+64
-59
lines changed

README.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ containing Aesara ``RandomVariable``\s:
2828
2929
from aeppl import joint_logprob, pprint
3030
31+
srng = at.random.RandomStream()
3132
3233
# A simple scale mixture model
33-
S_rv = at.random.invgamma(0.5, 0.5)
34-
Y_rv = at.random.normal(0.0, at.sqrt(S_rv))
34+
S_rv = srng.invgamma(0.5, 0.5)
35+
Y_rv = srng.normal(0.0, at.sqrt(S_rv))
3536
3637
# Compute the joint log-probability
3738
logprob, (y, s) = joint_logprob(Y_rv, S_rv)
@@ -94,8 +95,8 @@ Joint log-probabilities can be computed for some terms that are *derived* from
9495
.. code-block:: python
9596
9697
# Create a switching model from a Bernoulli distributed index
97-
Z_rv = at.random.normal([-100, 100], 1.0, name="Z")
98-
I_rv = at.random.bernoulli(0.5, name="I")
98+
Z_rv = srng.normal([-100, 100], 1.0, name="Z")
99+
I_rv = srng.bernoulli(0.5, name="I")
99100
100101
M_rv = Z_rv[I_rv]
101102
M_rv.name = "M"

aeppl/joint_logprob.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def conditional_logprob(
3131
3232
import aesara.tensor as at
3333
34-
srng = at.random.RandomStream(0)
34+
srng = at.random.RandomStream()
35+
3536
sigma2_rv = srng.invgamma(0.5, 0.5)
3637
Y_rv = srng.normal(0, at.sqrt(sigma2_rv))
3738
@@ -267,7 +268,7 @@ def joint_logprob(
267268
268269
import aesara.tensor as at
269270
270-
srng = at.random.RandomStream(0)
271+
srng = at.random.RandomStream()
271272
sigma2_rv = srng.invgamma(0.5, 0.5)
272273
Y_rv = srng.normal(0, at.sqrt(sigma2_rv))
273274

aeppl/logprob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def categorical_logprob(op, values, *inputs, **kwargs):
569569
)
570570
)
571571
# FIXME: `take_along_axis` drops a broadcastable dimension
572-
# when `value.broadcastable == p.broadcastable == (True, True, False)`.
572+
# when `value.type.shape == p.type.shape == (1, 1, None)`.
573573
else:
574574
res = at.log(p[value])
575575

aeppl/mixture.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,17 +180,17 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
180180
class MixtureRV(Op):
181181
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""
182182

183-
__props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")
183+
__props__ = ("indices_end_idx", "out_dtype", "out_shape")
184184

185-
def __init__(self, indices_end_idx, out_dtype, out_broadcastable):
185+
def __init__(self, indices_end_idx, out_dtype, out_shape):
186186
super().__init__()
187187
self.indices_end_idx = indices_end_idx
188188
self.out_dtype = out_dtype
189-
self.out_broadcastable = out_broadcastable
189+
self.out_shape = out_shape
190190

191191
def make_node(self, *inputs):
192192
return Apply(
193-
self, list(inputs), [TensorType(self.out_dtype, self.out_broadcastable)()]
193+
self, list(inputs), [TensorType(self.out_dtype, shape=self.out_shape)()]
194194
)
195195

196196
def perform(self, node, inputs, outputs):
@@ -284,8 +284,8 @@ def mixture_replace(fgraph, node):
284284
# Replace this sub-graph with a `MixtureRV`
285285
mix_op = MixtureRV(
286286
1 + len(mixing_indices),
287-
old_mixture_rv.dtype,
288-
old_mixture_rv.broadcastable,
287+
old_mixture_rv.type.dtype,
288+
old_mixture_rv.type.shape,
289289
)
290290
new_node = mix_op.make_node(*([join_axis] + mixing_indices + mixture_rvs))
291291

@@ -364,8 +364,8 @@ def ifelse_mixture_replace(fgraph, node):
364364
"""
365365
mix_op = MixtureRV(
366366
2,
367-
old_mixture_rv.dtype,
368-
old_mixture_rv.broadcastable,
367+
old_mixture_rv.type.dtype,
368+
old_mixture_rv.type.shape,
369369
)
370370

371371
if node.inputs[0].ndim == 0:

aeppl/printing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,8 @@ class PreamblePPrinter(PPrinter):
417417
-------
418418
>>> import aesara.tensor as at
419419
>>> from aeppl.printing import pprint
420-
>>> X_rv = at.random.normal(at.scalar('\\mu'), at.scalar('\\sigma'), name='X')
420+
>>> srng = at.random.RandomStream()
421+
>>> X_rv = srng.normal(at.scalar('\\mu'), at.scalar('\\sigma'), name='X')
421422
>>> print(pprint(X_rv))
422423
\\mu in R
423424
\\sigma in R

docs/source/api/distributions.rst

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ The :py:func:`aeppl.logprob.logprob` function can be called on any random variab
1111
import aesara.tensor as at
1212
from aeppl.logprob import _logprob
1313
14-
srng = at.random.RandomStream(0)
14+
srng = at.random.RandomStream()
1515
1616
mu = at.scalar("mu")
1717
sigma = at.scalar("sigma")
@@ -29,7 +29,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
2929
3030
import aesara.tensor as at
3131
32-
srng = at.random.RandomStream(0)
32+
srng = at.random.RandomStream()
3333
3434
p = at.scalar("p")
3535
x_rv = snrg.bernoulli(p)
@@ -43,7 +43,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
4343
4444
import aesara.tensor as at
4545
46-
srng = at.random.RandomStream(0)
46+
srng = at.random.RandomStream()
4747
4848
a = at.scalar("a")
4949
b = at.scalar("b")
@@ -59,7 +59,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
5959
6060
import aesara.tensor as at
6161
62-
srng = at.random.RandomStream(0)
62+
srng = at.random.RandomStream()
6363
6464
n = at.iscalar("n")
6565
a = at.scalar("a")
@@ -76,7 +76,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
7676
7777
import aesara.tensor as at
7878
79-
srng = at.random.RandomStream(0)
79+
srng = at.random.RandomStream()
8080
8181
n = at.iscalar("n")
8282
p = at.scalar("p")
@@ -92,7 +92,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
9292
9393
import aesara.tensor as at
9494
95-
srng = at.random.RandomStream(0)
95+
srng = at.random.RandomStream()
9696
9797
loc = at.scalar("loc")
9898
scale = at.scalar("scale")
@@ -107,7 +107,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
107107
108108
import aesara.tensor as at
109109
110-
srng = at.random.RandomStream(0)
110+
srng = at.random.RandomStream()
111111
112112
p = at.vector("p")
113113
x_rv = snrg.categorical(p)
@@ -121,7 +121,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
121121
122122
import aesara.tensor as at
123123
124-
srng = at.random.RandomStream(0)
124+
srng = at.random.RandomStream()
125125
126126
df = at.scalar("df")
127127
x_rv = snrg.chisquare(df)
@@ -148,7 +148,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
148148
149149
import aesara.tensor as at
150150
151-
srng = at.random.RandomStream(0)
151+
srng = at.random.RandomStream()
152152
153153
alpha = at.vector("alpha")
154154
x_rv = snrg.dirichlet(alpha)
@@ -167,7 +167,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
167167
168168
import aesara.tensor as at
169169
170-
srng = at.random.RandomStream(0)
170+
srng = at.random.RandomStream()
171171
172172
beta = at.scalar("beta")
173173
x_rv = snrg.exponential(beta)
@@ -181,7 +181,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
181181
182182
import aesara.tensor as at
183183
184-
srng = at.random.RandomStream(0)
184+
srng = at.random.RandomStream()
185185
186186
alpha = at.scalar('alpha')
187187
beta = at.scalar('beta')
@@ -196,7 +196,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
196196
197197
import aesara.tensor as at
198198
199-
srng = at.random.RandomStream(0)
199+
srng = at.random.RandomStream()
200200
201201
p = at.scalar("p")
202202
x_rv = snrg.geometric(p)
@@ -210,7 +210,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
210210
211211
import aesara.tensor as at
212212
213-
srng = at.random.RandomStream(0)
213+
srng = at.random.RandomStream()
214214
215215
mu = at.scalar('mu')
216216
beta = at.scalar('beta')
@@ -225,7 +225,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
225225
226226
import aesara.tensor as at
227227
228-
srng = at.random.RandomStream(0)
228+
srng = at.random.RandomStream()
229229
230230
x0 = at.scalar('x0')
231231
gamma = at.scalar('gamma')
@@ -240,7 +240,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
240240
241241
import aesara.tensor as at
242242
243-
srng = at.random.RandomStream(0)
243+
srng = at.random.RandomStream()
244244
245245
mu = at.scalar('mu')
246246
sigma = at.scalar('sigma')
@@ -255,7 +255,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
255255
256256
import aesara.tensor as at
257257
258-
srng = at.random.RandomStream(0)
258+
srng = at.random.RandomStream()
259259
260260
ngood = at.scalar("ngood")
261261
nbad = at.scalar("nbad")
@@ -271,7 +271,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
271271
272272
import aesara.tensor as at
273273
274-
srng = at.random.RandomStream(0)
274+
srng = at.random.RandomStream()
275275
276276
alpha = at.scalar('alpha')
277277
beta = at.scalar('beta')
@@ -286,7 +286,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
286286
287287
import aesara.tensor as at
288288
289-
srng = at.random.RandomStream(0)
289+
srng = at.random.RandomStream()
290290
291291
mu = at.scalar("mu")
292292
lmbda = at.scalar("lambda")
@@ -301,7 +301,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
301301
302302
import aesara.tensor as at
303303
304-
srng = at.random.RandomStream(0)
304+
srng = at.random.RandomStream()
305305
306306
mu = at.scalar("mu")
307307
s = at.scalar("s")
@@ -316,7 +316,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
316316
317317
import aesara.tensor as at
318318
319-
srng = at.random.RandomStream(0)
319+
srng = at.random.RandomStream()
320320
321321
mu = at.scalar("mu")
322322
sigma = at.scalar("sigma")
@@ -331,7 +331,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
331331
332332
import aesara.tensor as at
333333
334-
srng = at.random.RandomStream(0)
334+
srng = at.random.RandomStream()
335335
336336
n = at.iscalar("n")
337337
p = at.vector("p")
@@ -346,7 +346,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
346346
347347
import aesara.tensor as at
348348
349-
srng = at.random.RandomStream(0)
349+
srng = at.random.RandomStream()
350350
351351
mu = at.vector('mu')
352352
Sigma = at.matrix('sigma')
@@ -362,7 +362,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
362362
363363
import aesara.tensor as at
364364
365-
srng = at.random.RandomStream(0)
365+
srng = at.random.RandomStream()
366366
367367
n = at.iscalar("n")
368368
p = at.scalar("p")
@@ -377,7 +377,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
377377
378378
import aesara.tensor as at
379379
380-
srng = at.random.RandomStream(0)
380+
srng = at.random.RandomStream()
381381
382382
mu = at.scalar('mu')
383383
sigma = at.scalar('sigma')
@@ -392,7 +392,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
392392
393393
import aesara.tensor as at
394394
395-
srng = at.random.RandomStream(0)
395+
srng = at.random.RandomStream()
396396
397397
b = at.scalar("b")
398398
scale = at.scalar("scale")
@@ -407,7 +407,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
407407
408408
import aesara.tensor as at
409409
410-
srng = at.random.RandomStream(0)
410+
srng = at.random.RandomStream()
411411
412412
lmbda = at.scalar("lambda")
413413
x_rv = snrg.poisson(lmbda)
@@ -421,7 +421,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
421421
422422
import aesara.tensor as at
423423
424-
srng = at.random.RandomStream(0)
424+
srng = at.random.RandomStream()
425425
426426
df = at.scalar('df')
427427
loc = at.scalar('loc')
@@ -437,7 +437,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
437437
438438
import aesara.tensor as at
439439
440-
srng = at.random.RandomStream(0)
440+
srng = at.random.RandomStream()
441441
442442
left = at.scalar('left')
443443
mode = at.scalar('mode')
@@ -453,7 +453,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
453453
454454
import aesara.tensor as at
455455
456-
srng = at.random.RandomStream(0)
456+
srng = at.random.RandomStream()
457457
458458
low = at.scalar('low')
459459
high = at.scalar('high')
@@ -468,7 +468,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
468468
469469
import aesara.tensor as at
470470
471-
srng = at.random.RandomStream(0)
471+
srng = at.random.RandomStream()
472472
473473
mu = at.scalar('mu')
474474
kappa = at.scalar('kappa')
@@ -483,7 +483,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
483483
484484
import aesara.tensor as at
485485
486-
srng = at.random.RandomStream(0)
486+
srng = at.random.RandomStream()
487487
488488
mu = at.scalar('mu')
489489
lmbda = at.scalar('lambda')
@@ -499,7 +499,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
499499
500500
import aesara.tensor as at
501501
502-
srng = at.random.RandomStream(0)
502+
srng = at.random.RandomStream()
503503
504504
k = at.scalar('k')
505505
x_rv = srng.weibull(k)

0 commit comments

Comments
 (0)