Skip to content

Commit f475bc9

Browse files
author
Larry Dong
committed
Replace use of broadcastable with shape in MixtureRV
1 parent 74a3649 commit f475bc9

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

aeppl/mixture.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,17 +180,17 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
180180
class MixtureRV(Op):
181181
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""
182182

183-
__props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")
183+
__props__ = ("indices_end_idx", "out_dtype", "out_shape")
184184

185-
def __init__(self, indices_end_idx, out_dtype, out_broadcastable):
185+
def __init__(self, indices_end_idx, out_dtype, out_shape):
186186
super().__init__()
187187
self.indices_end_idx = indices_end_idx
188188
self.out_dtype = out_dtype
189-
self.out_broadcastable = out_broadcastable
189+
self.out_shape = out_shape
190190

191191
def make_node(self, *inputs):
192192
return Apply(
193-
self, list(inputs), [TensorType(self.out_dtype, self.out_broadcastable)()]
193+
self, list(inputs), [TensorType(self.out_dtype, self.out_shape)()]
194194
)
195195

196196
def perform(self, node, inputs, outputs):
@@ -285,7 +285,7 @@ def mixture_replace(fgraph, node):
285285
mix_op = MixtureRV(
286286
1 + len(mixing_indices),
287287
old_mixture_rv.dtype,
288-
old_mixture_rv.broadcastable,
288+
old_mixture_rv.type.shape,
289289
)
290290
new_node = mix_op.make_node(*([join_axis] + mixing_indices + mixture_rvs))
291291

@@ -337,7 +337,7 @@ def switch_mixture_replace(fgraph, node):
337337
mix_op = MixtureRV(
338338
2,
339339
old_mixture_rv.dtype,
340-
old_mixture_rv.broadcastable,
340+
old_mixture_rv.type.shape,
341341
)
342342
if node.inputs[0].ndim == 0:
343343
# as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs
@@ -397,7 +397,7 @@ def ifelse_mixture_replace(fgraph, node):
397397
mix_op = MixtureRV(
398398
2,
399399
old_mixture_rv.dtype,
400-
old_mixture_rv.broadcastable,
400+
old_mixture_rv.type.shape,
401401
)
402402

403403
if node.inputs[0].ndim == 0:

0 commit comments

Comments
 (0)