@@ -232,25 +232,6 @@ def test_hetero_mixture_binomial(p_val, size):
232
232
(),
233
233
0 ,
234
234
),
235
- (
236
- (
237
- np .array (0 , dtype = aesara .config .floatX ),
238
- np .array (1 , dtype = aesara .config .floatX ),
239
- ),
240
- (
241
- np .array (0.5 , dtype = aesara .config .floatX ),
242
- np .array (0.5 , dtype = aesara .config .floatX ),
243
- ),
244
- (
245
- np .array (100 , dtype = aesara .config .floatX ),
246
- np .array (1 , dtype = aesara .config .floatX ),
247
- ),
248
- np .array ([0.1 , 0.5 , 0.4 ], dtype = aesara .config .floatX ),
249
- (),
250
- (),
251
- (),
252
- 0 ,
253
- ),
254
235
(
255
236
(
256
237
np .array (0 , dtype = aesara .config .floatX ),
@@ -684,14 +665,118 @@ def test_mixture_with_DiracDelta():
684
665
assert M_rv in logp_res
685
666
686
667
687
- @pytest .mark .parametrize ("op" , [at .switch , ifelse ])
688
- def test_switch_ifelse_mixture (op ):
668
+ @pytest .mark .parametrize (
669
+ "op, X_args, Y_args, p_val, comp_size, idx_size" ,
670
+ [
671
+ [op ] + list (test_args )
672
+ for op in [at .switch , ifelse ]
673
+ for test_args in [
674
+ (
675
+ (
676
+ np .array (- 10 , dtype = aesara .config .floatX ),
677
+ np .array (0.1 , dtype = aesara .config .floatX ),
678
+ ),
679
+ (
680
+ np .array (10 , dtype = aesara .config .floatX ),
681
+ np .array (0.1 , dtype = aesara .config .floatX ),
682
+ ),
683
+ np .array (0.5 , dtype = aesara .config .floatX ),
684
+ (),
685
+ (),
686
+ ),
687
+ (
688
+ (
689
+ np .array (- 10 , dtype = aesara .config .floatX ),
690
+ np .array (0.1 , dtype = aesara .config .floatX ),
691
+ ),
692
+ (
693
+ np .array (10 , dtype = aesara .config .floatX ),
694
+ np .array (0.1 , dtype = aesara .config .floatX ),
695
+ ),
696
+ np .array (0.5 , dtype = aesara .config .floatX ),
697
+ (),
698
+ (6 ,),
699
+ ),
700
+ (
701
+ (
702
+ np .array ([10 , 20 ], dtype = aesara .config .floatX ),
703
+ np .array (0.1 , dtype = aesara .config .floatX ),
704
+ ),
705
+ (
706
+ np .array ([- 10 , - 20 ], dtype = aesara .config .floatX ),
707
+ np .array (0.1 , dtype = aesara .config .floatX ),
708
+ ),
709
+ np .array ([0.9 , 0.1 ], dtype = aesara .config .floatX ),
710
+ (2 ,),
711
+ (2 ,),
712
+ ),
713
+ (
714
+ (
715
+ np .array ([10 , 20 ], dtype = aesara .config .floatX ),
716
+ np .array (0.1 , dtype = aesara .config .floatX ),
717
+ ),
718
+ (
719
+ np .array ([- 10 , - 20 ], dtype = aesara .config .floatX ),
720
+ np .array (0.1 , dtype = aesara .config .floatX ),
721
+ ),
722
+ np .array ([0.9 , 0.1 ], dtype = aesara .config .floatX ),
723
+ None ,
724
+ None ,
725
+ ),
726
+ (
727
+ (
728
+ np .array (- 10 , dtype = aesara .config .floatX ),
729
+ np .array (0.1 , dtype = aesara .config .floatX ),
730
+ ),
731
+ (
732
+ np .array (10 , dtype = aesara .config .floatX ),
733
+ np .array (0.1 , dtype = aesara .config .floatX ),
734
+ ),
735
+ np .array (0.5 , dtype = aesara .config .floatX ),
736
+ (2 , 3 ),
737
+ (2 , 3 ),
738
+ ),
739
+ (
740
+ (
741
+ np .array (10 , dtype = aesara .config .floatX ),
742
+ np .array (0.1 , dtype = aesara .config .floatX ),
743
+ ),
744
+ (
745
+ np .array (- 10 , dtype = aesara .config .floatX ),
746
+ np .array (0.1 , dtype = aesara .config .floatX ),
747
+ ),
748
+ np .array (0.5 , dtype = aesara .config .floatX ),
749
+ (2 , 3 ),
750
+ (),
751
+ ),
752
+ (
753
+ (
754
+ np .array (10 , dtype = aesara .config .floatX ),
755
+ np .array (0.1 , dtype = aesara .config .floatX ),
756
+ ),
757
+ (
758
+ np .array (- 10 , dtype = aesara .config .floatX ),
759
+ np .array (0.1 , dtype = aesara .config .floatX ),
760
+ ),
761
+ np .array (0.5 , dtype = aesara .config .floatX ),
762
+ (3 ,),
763
+ (3 ,),
764
+ ),
765
+ ]
766
+ if not ((test_args [- 1 ] is None or len (test_args [- 1 ]) > 0 ) and op == ifelse )
767
+ ],
768
+ )
769
+ def test_switch_ifelse_mixture (op , X_args , Y_args , p_val , comp_size , idx_size ):
770
+ """
771
+ The argument size is both the input to srng.normal and the expected
772
+ size of the mixture RV Z1_rv
773
+ """
689
774
srng = at .random .RandomStream (29833 )
690
775
691
- X_rv = srng .normal (- 10.0 , 0.1 , name = "X" )
692
- Y_rv = srng .normal (10.0 , 0.1 , name = "Y" )
776
+ X_rv = srng .normal (* X_args , size = comp_size , name = "X" )
777
+ Y_rv = srng .normal (* Y_args , size = comp_size , name = "Y" )
693
778
694
- I_rv = srng .bernoulli (0.5 , name = "I" )
779
+ I_rv = srng .bernoulli (p_val , size = idx_size , name = "I" )
695
780
i_vv = I_rv .clone ()
696
781
i_vv .name = "i"
697
782
0 commit comments