@@ -180,17 +180,17 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
180
180
class MixtureRV (Op ):
181
181
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""
182
182
183
- __props__ = ("indices_end_idx" , "out_dtype" , "out_broadcastable " )
183
+ __props__ = ("indices_end_idx" , "out_dtype" , "out_shape " )
184
184
185
- def __init__ (self , indices_end_idx , out_dtype , out_broadcastable ):
185
+ def __init__ (self , indices_end_idx , out_dtype , out_shape ):
186
186
super ().__init__ ()
187
187
self .indices_end_idx = indices_end_idx
188
188
self .out_dtype = out_dtype
189
- self .out_broadcastable = out_broadcastable
189
+ self .out_shape = out_shape
190
190
191
191
def make_node (self , * inputs ):
192
192
return Apply (
193
- self , list (inputs ), [TensorType (self .out_dtype , self .out_broadcastable )()]
193
+ self , list (inputs ), [TensorType (self .out_dtype , self .out_shape )()]
194
194
)
195
195
196
196
def perform (self , node , inputs , outputs ):
@@ -285,7 +285,7 @@ def mixture_replace(fgraph, node):
285
285
mix_op = MixtureRV (
286
286
1 + len (mixing_indices ),
287
287
old_mixture_rv .dtype ,
288
- old_mixture_rv .broadcastable ,
288
+ old_mixture_rv .type . shape ,
289
289
)
290
290
new_node = mix_op .make_node (* ([join_axis ] + mixing_indices + mixture_rvs ))
291
291
@@ -337,7 +337,7 @@ def switch_mixture_replace(fgraph, node):
337
337
mix_op = MixtureRV (
338
338
2 ,
339
339
old_mixture_rv .dtype ,
340
- old_mixture_rv .broadcastable ,
340
+ old_mixture_rv .type . shape ,
341
341
)
342
342
if node .inputs [0 ].ndim == 0 :
343
343
# as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs
@@ -397,7 +397,7 @@ def ifelse_mixture_replace(fgraph, node):
397
397
mix_op = MixtureRV (
398
398
2 ,
399
399
old_mixture_rv .dtype ,
400
- old_mixture_rv .broadcastable ,
400
+ old_mixture_rv .type . shape ,
401
401
)
402
402
403
403
if node .inputs [0 ].ndim == 0 :
0 commit comments