@@ -1100,16 +1100,13 @@ def test_emd_sparse_backends(nx):
11001100
11011101 rng = np .random .RandomState (42 )
11021102
1103- # Create distributions
11041103 a = ot .utils .unif (n_source )
11051104 b = ot .utils .unif (n_target )
11061105
1107- # Create cost matrix
11081106 x_source = rng .randn (n_source , 2 )
11091107 x_target = rng .randn (n_target , 2 ) + 0.5
11101108 C = ot .dist (x_source , x_target )
11111109
1112- # Create sparse k-NN graph
11131110 rows = []
11141111 cols = []
11151112 data = []
@@ -1156,20 +1153,16 @@ def test_emd_sparse_backends(nx):
11561153 (data_aug , (rows_aug , cols_aug )), shape = (n_source , n_target )
11571154 )
11581155
1159- # Test with numpy weights (baseline)
11601156 _ , log_np = ot .emd (a , b , C_augmented , log = True )
11611157
1162- # Test with backend weights
11631158 ab , bb = nx .from_numpy (a , b )
11641159 _ , log_backend = ot .emd (ab , bb , C_augmented , log = True )
11651160
1166- # Compare costs
11671161 cost_np = log_np ["cost" ]
11681162 cost_backend = nx .to_numpy (log_backend ["cost" ])
11691163
11701164 np .testing .assert_allclose (cost_np , cost_backend , rtol = 1e-5 , atol = 1e-7 )
11711165
1172- # Check flow values match
11731166 np .testing .assert_allclose (
11741167 log_np ["flow_values" ], log_backend ["flow_values" ], rtol = 1e-5 , atol = 1e-7
11751168 )
@@ -1186,16 +1179,13 @@ def test_emd2_sparse_backends(nx):
11861179
11871180 rng = np .random .RandomState (42 )
11881181
1189- # Create distributions
11901182 a = ot .utils .unif (n_source )
11911183 b = ot .utils .unif (n_target )
11921184
1193- # Create cost matrix
11941185 x_source = rng .randn (n_source , 2 )
11951186 x_target = rng .randn (n_target , 2 ) + 0.5
11961187 C = ot .dist (x_source , x_target )
11971188
1198- # Create sparse k-NN graph
11991189 rows = []
12001190 cols = []
12011191 data = []
@@ -1242,14 +1232,11 @@ def test_emd2_sparse_backends(nx):
12421232 (data_aug , (rows_aug , cols_aug )), shape = (n_source , n_target )
12431233 )
12441234
1245- # Test with numpy weights (baseline)
12461235 cost_np = ot .emd2 (a , b , C_augmented )
12471236
1248- # Test with backend weights
12491237 ab , bb = nx .from_numpy (a , b )
12501238 cost_backend = ot .emd2 (ab , bb , C_augmented )
12511239
1252- # Compare costs
12531240 cost_backend_np = nx .to_numpy (cost_backend )
12541241
12551242 np .testing .assert_allclose (cost_np , cost_backend_np , rtol = 1e-5 , atol = 1e-7 )
0 commit comments