@@ -1180,15 +1180,19 @@ def _lkj_normalizing_constant(eta, n):
 # _LKJCholeskyCovBaseRV requires a properly shaped `D`, which means the variable can't
 # be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper
 class _LKJCholeskyCovRV(SymbolicRandomVariable):
-    extended_signature = "[rng],(),(),(n)->[rng],(n)"
+    extended_signature = "[rng],[rng],(),(),(n)->[rng],[rng],(n)"
     _print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")

     @classmethod
     def rv_op(cls, n, eta, sd_dist, *, size=None):
         # We don't allow passing `rng` because we don't fully control the rng of the components!
         n = pt.as_tensor(n, dtype="int64", ndim=0)
         eta = pt.as_tensor_variable(eta, ndim=0)
-        rng = pytensor.shared(np.random.default_rng())
+
+        # LKJCorr requires 2 random number generators
+        outer_rng = pytensor.shared(np.random.default_rng())
+        scan_rng = pytensor.shared(np.random.default_rng())
+
         size = normalize_size_param(size)

         # We resize the sd_dist automatically so that it has (size x n) independent
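Every pytensor draw made with an explicit rng returns the advanced generator as
an extra output, which is why the code below repeatedly unpacks `.owner.outputs`.
One generator feeds the initial beta draw; the second is reserved for the scan,
which needs a generator it can own through its updates mechanism. A minimal
sketch of the convention (toy parameters, illustration only):

import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(42))
node = pt.random.beta(alpha=2.0, beta=2.0, size=(3,), rng=rng).owner
next_rng, draw = node.outputs  # (advanced generator state, sampled values)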
@@ -1212,8 +1216,11 @@ def rv_op(cls, n, eta, sd_dist, *, size=None):
         D = sd_dist.type(name="D")  # Make sd_dist opaque to OpFromGraph
         size = D.shape[:-1]

-        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
-        C *= D[..., :, None] * D[..., None, :]
+        next_outer_rng, next_scan_rng, C = LKJCorrRV._random_corr_matrix(
+            outer_rng=outer_rng, scan_rng=scan_rng, n=n, eta=eta, size=size
+        )
+        vec_diag = pt.vectorize(pt.diag, signature="(n)->(n,n)")
+        C = vec_diag(D) @ C @ vec_diag(D)

         tril_idx = pt.tril_indices(n, k=0)
         samples = pt.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]]
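The vectorized diag product is the same rescaling the removed in-place multiply
performed: diag(D) @ C @ diag(D) scales entry (i, j) by D[i] * D[j]. A quick
numpy check of the identity (illustrative values only):

import numpy as np

D = np.array([1.0, 2.0, 3.0])
C = np.array([[1.0, 0.5, 0.2],
              [0.5, 1.0, 0.3],
              [0.2, 0.3, 1.0]])
assert np.allclose(np.diag(D) @ C @ np.diag(D), C * D[..., :, None] * D[..., None, :])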
@@ -1225,12 +1232,12 @@ def rv_op(cls, n, eta, sd_dist, *, size=None):
         samples = pt.reshape(samples, (*size, dist_shape))

         return _LKJCholeskyCovRV(
-            inputs=[rng, n, eta, D],
-            outputs=[next_rng, samples],
-        )(rng, n, eta, sd_dist)
+            inputs=[outer_rng, scan_rng, n, eta, D],
+            outputs=[next_outer_rng, next_scan_rng, samples],
+        )(outer_rng, scan_rng, n, eta, sd_dist)

     def update(self, node):
-        return {node.inputs[0]: node.outputs[0]}
+        return {node.inputs[0]: node.outputs[0], node.inputs[1]: node.outputs[1]}


 class _LKJCholeskyCov(Distribution):
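update() must now map each rng input to its advanced output; any generator left
out of the dict would be silently frozen once the graph is compiled. A standalone
sketch of the mechanism (toy draws, assumed names):

import numpy as np
import pytensor
import pytensor.tensor as pt

rng_a = pytensor.shared(np.random.default_rng(1))
rng_b = pytensor.shared(np.random.default_rng(2))
next_a, x = pt.random.normal(rng=rng_a).owner.outputs
next_b, y = pt.random.normal(rng=rng_b).owner.outputs
# Analogue of the update() dict above: advance both states on every call.
f = pytensor.function([], x + y, updates={rng_a: next_a, rng_b: next_b})
assert f() != f()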
@@ -1258,7 +1265,7 @@ def dist(cls, n, eta, sd_dist, **kwargs):

 @_change_dist_size.register(_LKJCholeskyCovRV)
 def change_LKJCholeksyCovRV_size(op, dist, new_size, expand=False):
-    n, eta, sd_dist = dist.owner.inputs[1:]
+    n, eta, sd_dist = dist.owner.inputs[2:]

     if expand:
         old_size = sd_dist.shape[:-1]
@@ -1268,7 +1275,7 @@ def change_LKJCholeksyCovRV_size(op, dist, new_size, expand=False):


 @_support_point.register(_LKJCholeskyCovRV)
-def _LKJCholeksyCovRV_support_point(op, rv, rng, n, eta, sd_dist):
+def _LKJCholeksyCovRV_support_point(op, rv, outer_rng, scan_rng, n, eta, sd_dist):
     diag_idxs = (pt.cumsum(pt.arange(1, n + 1)) - 1).astype("int32")
     support_point = pt.zeros_like(rv)
     support_point = pt.set_subtensor(support_point[..., diag_idxs], 1)
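The support point sets the diagonal of the packed Cholesky factor to 1. The
index arithmetic can be sanity-checked in plain numpy (toy n, illustration only):

import numpy as np

n = 4
# cumsum(1..n) - 1 gives the positions of the (i, i) entries inside the
# row-major packed lower triangle of length n * (n + 1) / 2.
diag_idxs = np.cumsum(np.arange(1, n + 1)) - 1  # -> [0, 2, 5, 9]
rows, cols = np.tril_indices(n)
assert all(rows[i] == cols[i] for i in diag_idxs)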
@@ -1277,12 +1284,12 @@ def _LKJCholeksyCovRV_support_point(op, rv, rng, n, eta, sd_dist):

 @_default_transform.register(_LKJCholeskyCovRV)
 def _LKJCholeksyCovRV_default_transform(op, rv):
-    _, n, _, _ = rv.owner.inputs
+    _, _, n, _, _ = rv.owner.inputs
     return transforms.CholeskyCovPacked(n)


 @_logprob.register(_LKJCholeskyCovRV)
-def _LKJCholeksyCovRV_logp(op, values, rng, n, eta, sd_dist, **kwargs):
+def _LKJCholeksyCovRV_logp(op, values, outer_rng, scan_rng, n, eta, sd_dist, **kwargs):
     (value,) = values

     if value.ndim > 1:
@@ -1499,10 +1506,10 @@ def helper_deterministics(cls, n, packed_chol):

 class LKJCorrRV(SymbolicRandomVariable):
     name = "lkjcorr"
-    extended_signature = "[rng],[size],(),()->[rng],(n,n)"
+    extended_signature = "[rng],[rng],[size],(),()->[rng],[rng],(n,n)"
     _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")

-    def make_node(self, rng, size, n, eta):
+    def make_node(self, outer_rng, scan_rng, size, n, eta):
         n = pt.as_tensor_variable(n)
         if not all(n.type.broadcastable):
             raise ValueError("n must be a scalar.")
@@ -1511,59 +1518,81 @@ def make_node(self, rng, size, n, eta):
         if not all(eta.type.broadcastable):
             raise ValueError("eta must be a scalar.")

-        return super().make_node(rng, size, n, eta)
+        return super().make_node(outer_rng, scan_rng, size, n, eta)

     @classmethod
-    def rv_op(cls, n: int, eta, *, rng=None, size=None):
-        # HACK: normalize_size_param doesn't handle size=() properly
-        if not size:
-            size = None
-
+    def rv_op(cls, n: int, eta, *, outer_rng=None, scan_rng=None, size=None):
         n = pt.as_tensor(n, ndim=0, dtype=int)
         eta = pt.as_tensor(eta, ndim=0)
-        rng = normalize_rng_param(rng)
+        outer_rng = normalize_rng_param(outer_rng)
+        scan_rng = normalize_rng_param(scan_rng)
         size = normalize_size_param(size)

-        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
+        outer_rng_out, scan_rng_out, C = cls._random_corr_matrix(
+            outer_rng=outer_rng, scan_rng=scan_rng, n=n, eta=eta, size=size
+        )

-        return cls(inputs=[rng, size, n, eta], outputs=[next_rng, C])(rng, size, n, eta)
+        return cls(
+            inputs=[outer_rng, scan_rng, size, n, eta], outputs=[outer_rng_out, scan_rng_out, C]
+        )(outer_rng, scan_rng, size, n, eta)

     @classmethod
     def _random_corr_matrix(
-        cls, rng: Variable, n: int, eta: TensorVariable, size: TensorVariable
-    ) -> tuple[Variable, TensorVariable]:
-        # original implementation in R see:
-        # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
-        size = () if rv_size_is_none(size) else size
+        cls,
+        outer_rng: Variable,
+        scan_rng: Variable,
+        n: int,
+        eta: TensorVariable,
+        size: TensorVariable,
+    ) -> tuple[Variable, Variable, TensorVariable]:
+        size_is_none = rv_size_is_none(size)
+        size = () if size_is_none else size

-        beta = eta - 1.0 + n / 2.0
-        next_rng, beta_rvs = pt.random.beta(alpha=beta, beta=beta, size=size, rng=rng).owner.outputs
-        r12 = 2.0 * beta_rvs - 1.0
+        beta0 = eta - 1.0 + n / 2.0
+
+        outer_rng_out, y0 = pt.random.beta(
+            alpha=beta0, beta=beta0, size=size, rng=outer_rng
+        ).owner.outputs

-        P = pt.full((*size, n, n), pt.eye(n))
-        P = P[..., 0, 1].set(r12)
-        P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
-        n = get_underlying_scalar_constant_value(n)
+        r12 = 2.0 * y0 - 1.0

-        for mp1 in range(2, n):
-            beta -= 0.5
+        P0 = pt.full((*size, n, n), pt.eye(n))
+        P0 = P0[..., 0, 1].set(r12)
+        P0 = P0[..., 1, 1].set(pt.sqrt(1.0 - r12**2))

-            next_rng, y = pt.random.beta(
-                alpha=mp1 / 2.0, beta=beta, size=size, rng=next_rng
+        def step(mp1, beta, P, prev_rng):
+            # Decrement before drawing so iteration mp1 uses beta0 - 0.5 * (mp1 - 1)
+            beta = beta - 0.5
+
+            middle_rng, y = pt.random.beta(
+                alpha=mp1 / 2.0, beta=beta, size=size, rng=prev_rng
             ).owner.outputs

             next_rng, z = pt.random.normal(
-                loc=0, scale=1, size=(*size, mp1), rng=next_rng
+                loc=0, scale=1, size=(*size, mp1), rng=middle_rng
             ).owner.outputs

             ein_sig_z = "i, i->" if z.ndim == 1 else "...ij, ...ij->...i"
-            z = z / pt.sqrt(pt.einsum(ein_sig_z, z, z.copy()))[..., np.newaxis]
-            P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
+
+            z = z / pt.sqrt(pt.einsum(ein_sig_z, z, z.copy()))[..., None]
+            P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., None]) * z)
             P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))

+            return (beta, P), {prev_rng: next_rng}
+
+        (_, P_seq), updates = pytensor.scan(
+            fn=step,
+            outputs_info=[beta0, P0],
+            sequences=[pt.arange(2, n)],
+            non_sequences=[scan_rng],
+            strict=True,
+        )
+
+        P = pytensor.ifelse(n < 3, P0, P_seq[-1])
+
         C = pt.einsum("...ji,...jk->...ik", P, P.copy())
+        (scan_rng_out,) = tuple(updates.values())

-        return next_rng, C
+        return outer_rng_out, scan_rng_out, C


 class _LKJCorr(BoundedContinuous):
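For reference, the recursion the scan implements is the LKJ "onion" construction
(the removed comment pointed at the R original in rmcelreath/rethinking): each
step appends a column whose squared length is Beta-distributed and whose
direction is uniform on the sphere. A single-draw numpy transcription with an
assumed helper name, illustrative only:

import numpy as np

def lkj_onion(n, eta, rng):
    beta = eta - 1.0 + n / 2.0
    r12 = 2.0 * rng.beta(beta, beta) - 1.0
    P = np.eye(n)  # upper-triangular factor, built one column at a time
    P[0, 1] = r12
    P[1, 1] = np.sqrt(1.0 - r12**2)
    for mp1 in range(2, n):
        beta -= 0.5
        y = rng.beta(mp1 / 2.0, beta)  # squared length of the new column
        z = rng.normal(size=mp1)
        z /= np.linalg.norm(z)  # uniform direction on the unit sphere
        P[0:mp1, mp1] = np.sqrt(y) * z
        P[mp1, mp1] = np.sqrt(1.0 - y)
    return P.T @ P  # C = P^T P has unit diagonal by construction

C = lkj_onion(4, eta=2.0, rng=np.random.default_rng(0))
assert np.allclose(np.diag(C), 1.0)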
@@ -1574,6 +1603,14 @@ class _LKJCorr(BoundedContinuous):
     def dist(cls, n, eta, **kwargs):
         n = pt.as_tensor_variable(n).astype(int)
         eta = pt.as_tensor_variable(eta)
+        rng = kwargs.pop("rng", None)
+
+        if isinstance(rng, Variable):
+            rng = rng.get_value()
+
+        kwargs["scan_rng"] = pytensor.shared(np.random.default_rng(rng))
+        kwargs["outer_rng"] = pytensor.shared(np.random.default_rng(rng))
+
         return super().dist([n, eta], **kwargs)

     @staticmethod
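With this change a reproducible draw can be requested through the usual `rng`
keyword: dist() pops it (an integer seed, or a shared Generator whose value is
extracted) and seeds both internal generators from it. A hedged usage sketch,
assuming the internal _LKJCorr class from this diff is importable:

import pymc as pm

corr = _LKJCorr.dist(n=3, eta=2.0, rng=123)
pm.draw(corr)  # 3 x 3 correlation matrix, reproducible for a fixed seed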
@@ -1619,7 +1656,7 @@ def logp(value: TensorVariable, n, eta):

 @_default_transform.register(_LKJCorr)
 def lkjcorr_default_transform(op, rv):
-    rng, shape, n, eta, *_ = rv.owner.inputs
+    rng, scan_rng, shape, n, eta, *_ = rv.owner.inputs
     n = pt.get_scalar_constant_value(n)  # Safely extract scalar value without eval
     return CholeskyCorrTransform(n=n, upper=False)