@@ -729,7 +729,7 @@ def test_linear_sliced_sphere_backend_type_devices(nx):
729729 np .testing .assert_almost_equal (sw_np , nx .to_numpy (valb ))
730730
731731
732- def test_sliced_permutations (nx ):
732+ def test_sliced_permutations ():
733733 n = 4
734734 n_proj = 10
735735 d = 2
@@ -738,15 +738,7 @@ def test_sliced_permutations(nx):
738738 x = rng .randn (n , 2 )
739739 y = rng .randn (n , 2 )
740740
741- x_b , y_b = nx .from_numpy (x , y )
742741 thetas = ot .sliced .get_random_projections (d , n_proj , seed = 0 ).T
743- thetas_b = nx .from_numpy (thetas )
744-
745- plan , _ = ot .sliced .sliced_plans (x , y , thetas = thetas , dense = True )
746- plan_b , _ , _ = ot .sliced .sliced_plans (
747- x_b , y_b , thetas = thetas_b , log = True , dense = True , backend = nx
748- )
749- np .testing .assert_almost_equal (plan , nx .to_numpy (plan_b ))
750742
751743 # test without provided thetas
752744 _ , _ = ot .sliced .sliced_plans (x , y , n_proj = n_proj )
@@ -756,7 +748,7 @@ def test_sliced_permutations(nx):
756748 ot .sliced .sliced_plans (x [:, 1 :], y , thetas = thetas )
757749
758750
759- def test_sliced_plans (nx ):
751+ def test_sliced_plans ():
760752 x = [1 , 2 ]
761753 with pytest .raises (AssertionError ):
762754 ot .sliced .min_pivot_sliced (x , x , n_proj = 2 )
@@ -775,39 +767,23 @@ def test_sliced_plans(nx):
775767 b = rng .uniform (0 , 1 , m )
776768 b /= b .sum ()
777769
778- x_b , y_b = nx .from_numpy (x , y )
779- print (x_b )
780- t_X = torch .tensor (x_b )
781- t_Y = torch .tensor (y_b )
782770 thetas = ot .sliced .get_random_projections (d , n_proj , seed = 0 ).T
783- thetas_b = nx .from_numpy (thetas )
771+
772+ # test with a and b not uniform
773+ ot .sliced .sliced_plans (x , y , a , b , thetas = thetas , dense = True )
784774
785775 # test with the minkowski metric
786- ot .sliced .min_pivot_sliced (x , y , thetas = thetas , metric = "minkowski" )
776+ ot .sliced .sliced_plans (x , y , thetas = thetas , metric = "minkowski" )
787777
788778 # test with an unsupported metric
789779 with pytest .raises (ValueError ):
790- ot .sliced .min_pivot_sliced (x , y , thetas = thetas , metric = "mahalanobis" )
780+ ot .sliced .sliced_plans (x , y , thetas = thetas , metric = "mahalanobis" )
791781
792782 # test with a warm theta
793- ot .sliced .min_pivot_sliced (x , y , n_proj = 10 , warm_theta = thetas [- 1 ])
783+ ot .sliced .sliced_plans (x , y , n_proj = 10 , warm_theta = thetas [- 1 ])
794784
795- # test with a and b uniform
796- plan , _ = ot .sliced .sliced_plans (x , y , thetas = thetas , dense = True )
797- plan_b , _ , _ = ot .sliced .sliced_plans (
798- x_b , y_b , thetas = thetas_b , log = True , dense = True , backend = nx
799- )
800- np .testing .assert_almost_equal (plan , nx .to_numpy (plan_b ))
801785
802- # test with a and b not uniform
803- plan , _ = ot .sliced .sliced_plans (x , y , a , b , thetas = thetas , dense = True )
804- plan_b , _ , _ = ot .sliced .sliced_plans (
805- x_b , y_b , a , b , thetas = thetas_b , log = True , dense = True , backend = nx
806- )
807- np .testing .assert_almost_equal (plan , nx .to_numpy (plan_b ))
808-
809-
810- def test_min_pivot_sliced (nx ):
786+ def test_min_pivot_sliced ():
811787 x = [1 , 2 ]
812788 with pytest .raises (AssertionError ):
813789 ot .sliced .min_pivot_sliced (x , x , n_proj = 2 )
@@ -825,17 +801,13 @@ def test_min_pivot_sliced(nx):
825801 b = rng .uniform (0 , 1 , m )
826802 b /= b .sum ()
827803
828- x_b , y_b = nx .from_numpy (x , y )
829804 thetas = ot .sliced .get_random_projections (d , n_proj , seed = 0 ).T
830- thetas_b = nx .from_numpy (thetas )
831805
832- G , min_cost = ot .sliced .min_pivot_sliced (x , y , a , b , thetas = thetas , dense = True )
833- G_b , min_cost_b , _ = ot .sliced .min_pivot_sliced (
834- x_b , y_b , a , b , thetas = thetas_b , log = True , dense = True
835- )
806+ # identity of the indiscernibles
807+ _ , min_cost = ot .min_pivot_sliced (x , x , a , a , n_proj = 10 )
808+ np .testing .assert_almost_equal (min_cost , 0.0 )
836809
837- np .testing .assert_almost_equal (G , nx .to_numpy (G_b ))
838- np .testing .assert_almost_equal (min_cost , nx .to_numpy (min_cost_b ))
810+ _ , min_cost = ot .sliced .min_pivot_sliced (x , y , a , b , thetas = thetas , dense = True )
839811
840812 # result should be an upper-bound of W2 and relatively close
841813 w2 = ot .emd2 (a , b , ot .dist (x , y ))
@@ -849,8 +821,30 @@ def test_min_pivot_sliced(nx):
849821 with pytest .raises (AssertionError ):
850822 ot .sliced .min_pivot_sliced (x [:, 1 :], y , thetas = thetas )
851823
824+ # test the logs
825+ _ , min_cost , log = ot .sliced .min_pivot_sliced (
826+ x , y , a , b , thetas = thetas , dense = False , log = True
827+ )
828+ assert len (log ) == 5
829+ costs = log ["costs" ]
830+ assert len (costs ) == thetas .shape [0 ]
831+ assert len (log ["min_theta" ]) == d
832+ assert (log ["thetas" ] == thetas ).all ()
833+ for c in costs :
834+ assert c > 0
835+
836+ # test with the minkowski metric
837+ ot .sliced .min_pivot_sliced (x , y , thetas = thetas , metric = "minkowski" )
838+
839+ # test with an unsupported metric
840+ with pytest .raises (ValueError ):
841+ ot .sliced .min_pivot_sliced (x , y , thetas = thetas , metric = "mahalanobis" )
852842
853- def test_expected_sliced (nx ):
843+ # test with a warm theta
844+ ot .sliced .min_pivot_sliced (x , y , n_proj = 10 , warm_theta = thetas [- 1 ])
845+
846+
847+ def test_expected_sliced ():
854848 x = [1 , 2 ]
855849 with pytest .raises (AssertionError ):
856850 ot .sliced .min_pivot_sliced (x , x , n_proj = 2 )
@@ -868,9 +862,67 @@ def test_expected_sliced(nx):
868862 b = rng .uniform (0 , 1 , m )
869863 b /= b .sum ()
870864
871- x_b , y_b = nx .from_numpy (x , y )
872865 thetas = ot .sliced .get_random_projections (d , n_proj , seed = 0 ).T
873- thetas_b = nx .from_numpy (thetas )
866+
867+ _ , expected_cost = ot .sliced .expected_sliced (x , y , a , b , dense = True , thetas = thetas )
868+ # result should be a coarse upper-bound of W2
869+ w2 = ot .emd2 (a , b , ot .dist (x , y ))
870+ assert expected_cost >= w2
871+ assert expected_cost <= 3 * w2
872+
873+ # test without provided thetas
874+ ot .sliced .expected_sliced (x , y , n_proj = n_proj , log = True )
875+
876+ # test with invalid shapes
877+ with pytest .raises (AssertionError ):
878+ ot .sliced .min_pivot_sliced (x [:, 1 :], y , thetas = thetas )
879+
880+ # with a small temperature (i.e. large beta), the cost should be close
881+ # to min_pivot
882+ _ , expected_cost = ot .sliced .expected_sliced (
883+ x , y , a , b , thetas = thetas , dense = True , beta = 100.0
884+ )
885+ _ , min_cost = ot .sliced .min_pivot_sliced (x , y , a , b , thetas = thetas , dense = True )
886+ np .testing .assert_almost_equal (expected_cost , min_cost , decimal = 3 )
887+
888+ # test the logs
889+ _ , min_cost , log = ot .sliced .expected_sliced (
890+ x , y , a , b , thetas = thetas , dense = False , log = True
891+ )
892+ assert len (log ) == 4
893+ costs = log ["costs" ]
894+ assert len (costs ) == thetas .shape [0 ]
895+ assert len (log ["weights" ]) == thetas .shape [0 ]
896+ assert (log ["thetas" ] == thetas ).all ()
897+ for c in costs :
898+ assert c > 0
899+
900+ # test with the minkowski metric
901+ ot .sliced .expected_sliced (x , y , thetas = thetas , metric = "minkowski" )
902+
903+ # test with an unsupported metric
904+ with pytest .raises (ValueError ):
905+ ot .sliced .expected_sliced (x , y , thetas = thetas , metric = "mahalanobis" )
906+
907+
908+ def test_sliced_plans_backends (nx ):
909+ n = 10
910+ m = 24
911+ n_proj = 10
912+ d = 2
913+ rng = np .random .RandomState (0 )
914+
915+ x = rng .randn (n , 2 )
916+ y = rng .randn (m , 2 )
917+ a = rng .uniform (0 , 1 , n )
918+ a /= a .sum ()
919+ b = rng .uniform (0 , 1 , m )
920+ b /= b .sum ()
921+
922+ x_b , y_b , a_b , b_b = nx .from_numpy (x , y , a , b )
923+
924+ thetas = ot .sliced .get_random_projections (d , n_proj , seed = 0 , backend = nx ).T
925+ thetas_t = nx .to_numpy (thetas )
874926
875927 context = (
876928 nullcontext ()
@@ -879,32 +931,25 @@ def test_expected_sliced(nx):
879931 )
880932
881933 with context :
882- expected_plan , expected_cost = ot .sliced .expected_sliced (
883- x , y , a , b , dense = True , thetas = thetas
934+ _ , expected_cost_b = ot .sliced .expected_sliced (
935+ x_b , y_b , a_b , b_b , dense = True , thetas = thetas_t
884936 )
885- expected_plan_b , expected_cost_b , _ = ot .sliced .expected_sliced (
886- x_b , y_b , a , b , thetas = thetas_b , dense = True , log = True
937+ # result should be the same than numpy version
938+ _ , expected_cost = ot .sliced .expected_sliced (
939+ x , y , a , b , dense = True , thetas = thetas
887940 )
941+ np .testing .assert_almost_equal (expected_cost_b , expected_cost )
888942
889- np .testing .assert_almost_equal (expected_plan , nx .to_numpy (expected_plan_b ))
890- np .testing .assert_almost_equal (expected_cost , nx .to_numpy (expected_cost_b ))
891-
892- # result should be a coarse upper-bound of W2
893- w2 = ot .emd2 (a , b , ot .dist (x , y ))
894- assert expected_cost >= w2
895- assert expected_cost <= 3 * w2
896-
897- # test without provided thetas
898- ot .sliced .expected_sliced (x , y , n_proj = n_proj , log = True )
943+ # for min_pivot
944+ _ , min_cost_b = ot .sliced .min_pivot_sliced (
945+ x_b , y_b , a_b , b_b , dense = True , thetas = thetas_t
946+ )
947+ # result should be the same than numpy version
948+ _ , min_cost = ot .sliced .min_pivot_sliced (x , y , a , b , dense = True , thetas = thetas )
949+ np .testing .assert_almost_equal (min_cost_b , min_cost )
899950
900- # test with invalid shapes
901- with pytest .raises (AssertionError ):
902- ot .sliced .min_pivot_sliced (x [:, 1 :], y , thetas = thetas )
951+ # for sliced_plans
952+ thetas = ot .sliced .get_random_projections (d , n_proj , seed = 0 , backend = nx ).T
903953
904- # with a small temperature (i.e. large beta), the cost should be close
905- # to min_pivot
906- _ , expected_cost = ot .sliced .expected_sliced (
907- x , y , a , b , thetas = thetas , dense = True , beta = 100.0
908- )
909- _ , min_cost = ot .sliced .min_pivot_sliced (x , y , a , b , thetas = thetas , dense = True )
910- np .testing .assert_almost_equal (expected_cost , min_cost , decimal = 3 )
954+ # test with the minkowski metric
955+ ot .sliced .min_pivot_sliced (x , y , thetas = thetas , metric = "minkowski" )
0 commit comments