@@ -233,25 +233,6 @@ def test_hetero_mixture_binomial(p_val, size):
233
233
(),
234
234
0 ,
235
235
),
236
- (
237
- (
238
- np .array (0 , dtype = aesara .config .floatX ),
239
- np .array (1 , dtype = aesara .config .floatX ),
240
- ),
241
- (
242
- np .array (0.5 , dtype = aesara .config .floatX ),
243
- np .array (0.5 , dtype = aesara .config .floatX ),
244
- ),
245
- (
246
- np .array (100 , dtype = aesara .config .floatX ),
247
- np .array (1 , dtype = aesara .config .floatX ),
248
- ),
249
- np .array ([0.1 , 0.5 , 0.4 ], dtype = aesara .config .floatX ),
250
- (),
251
- (),
252
- (),
253
- 0 ,
254
- ),
255
236
(
256
237
(
257
238
np .array (0 , dtype = aesara .config .floatX ),
@@ -683,14 +664,118 @@ def test_mixture_with_DiracDelta():
683
664
assert M_rv in logp_res
684
665
685
666
686
- @pytest .mark .parametrize ("op" , [at .switch , ifelse ])
687
- def test_switch_ifelse_mixture (op ):
667
+ @pytest .mark .parametrize (
668
+ "op, X_args, Y_args, p_val, comp_size, idx_size" ,
669
+ [
670
+ [op ] + list (test_args )
671
+ for op in [at .switch , ifelse ]
672
+ for test_args in [
673
+ (
674
+ (
675
+ np .array (- 10 , dtype = aesara .config .floatX ),
676
+ np .array (0.1 , dtype = aesara .config .floatX ),
677
+ ),
678
+ (
679
+ np .array (10 , dtype = aesara .config .floatX ),
680
+ np .array (0.1 , dtype = aesara .config .floatX ),
681
+ ),
682
+ np .array (0.5 , dtype = aesara .config .floatX ),
683
+ (),
684
+ (),
685
+ ),
686
+ (
687
+ (
688
+ np .array (- 10 , dtype = aesara .config .floatX ),
689
+ np .array (0.1 , dtype = aesara .config .floatX ),
690
+ ),
691
+ (
692
+ np .array (10 , dtype = aesara .config .floatX ),
693
+ np .array (0.1 , dtype = aesara .config .floatX ),
694
+ ),
695
+ np .array (0.5 , dtype = aesara .config .floatX ),
696
+ (),
697
+ (6 ,),
698
+ ),
699
+ (
700
+ (
701
+ np .array ([10 , 20 ], dtype = aesara .config .floatX ),
702
+ np .array (0.1 , dtype = aesara .config .floatX ),
703
+ ),
704
+ (
705
+ np .array ([- 10 , - 20 ], dtype = aesara .config .floatX ),
706
+ np .array (0.1 , dtype = aesara .config .floatX ),
707
+ ),
708
+ np .array ([0.9 , 0.1 ], dtype = aesara .config .floatX ),
709
+ (2 ,),
710
+ (2 ,),
711
+ ),
712
+ (
713
+ (
714
+ np .array ([10 , 20 ], dtype = aesara .config .floatX ),
715
+ np .array (0.1 , dtype = aesara .config .floatX ),
716
+ ),
717
+ (
718
+ np .array ([- 10 , - 20 ], dtype = aesara .config .floatX ),
719
+ np .array (0.1 , dtype = aesara .config .floatX ),
720
+ ),
721
+ np .array ([0.9 , 0.1 ], dtype = aesara .config .floatX ),
722
+ None ,
723
+ None ,
724
+ ),
725
+ (
726
+ (
727
+ np .array (- 10 , dtype = aesara .config .floatX ),
728
+ np .array (0.1 , dtype = aesara .config .floatX ),
729
+ ),
730
+ (
731
+ np .array (10 , dtype = aesara .config .floatX ),
732
+ np .array (0.1 , dtype = aesara .config .floatX ),
733
+ ),
734
+ np .array (0.5 , dtype = aesara .config .floatX ),
735
+ (2 , 3 ),
736
+ (2 , 3 ),
737
+ ),
738
+ (
739
+ (
740
+ np .array (10 , dtype = aesara .config .floatX ),
741
+ np .array (0.1 , dtype = aesara .config .floatX ),
742
+ ),
743
+ (
744
+ np .array (- 10 , dtype = aesara .config .floatX ),
745
+ np .array (0.1 , dtype = aesara .config .floatX ),
746
+ ),
747
+ np .array (0.5 , dtype = aesara .config .floatX ),
748
+ (2 , 3 ),
749
+ (),
750
+ ),
751
+ (
752
+ (
753
+ np .array (10 , dtype = aesara .config .floatX ),
754
+ np .array (0.1 , dtype = aesara .config .floatX ),
755
+ ),
756
+ (
757
+ np .array (- 10 , dtype = aesara .config .floatX ),
758
+ np .array (0.1 , dtype = aesara .config .floatX ),
759
+ ),
760
+ np .array (0.5 , dtype = aesara .config .floatX ),
761
+ (3 ,),
762
+ (3 ,),
763
+ ),
764
+ ]
765
+ if not ((test_args [- 1 ] is None or len (test_args [- 1 ]) > 0 ) and op == ifelse )
766
+ ],
767
+ )
768
+ def test_switch_ifelse_mixture (op , X_args , Y_args , p_val , comp_size , idx_size ):
769
+ """
770
+ The argument size is both the input to srng.normal and the expected
771
+ size of the mixture RV Z1_rv
772
+ """
688
773
srng = at .random .RandomStream (29833 )
689
774
690
- X_rv = srng .normal (- 10.0 , 0.1 , name = "X" )
691
- Y_rv = srng .normal (10.0 , 0.1 , name = "Y" )
775
+ X_rv = srng .normal (* X_args , size = comp_size , name = "X" )
776
+ Y_rv = srng .normal (* Y_args , size = comp_size , name = "Y" )
692
777
693
- I_rv = srng .bernoulli (0.5 , name = "I" )
778
+ I_rv = srng .bernoulli (p_val , size = idx_size , name = "I" )
694
779
i_vv = I_rv .clone ()
695
780
i_vv .name = "i"
696
781
0 commit comments