Skip to content

Commit fae9f02

Browse files
committed
fix : Quick test file fix
1 parent aa5f1c9 commit fae9f02

File tree

1 file changed

+0
-13
lines changed

1 file changed

+0
-13
lines changed

test/test_ot.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,16 +1100,13 @@ def test_emd_sparse_backends(nx):
11001100

11011101
rng = np.random.RandomState(42)
11021102

1103-
# Create distributions
11041103
a = ot.utils.unif(n_source)
11051104
b = ot.utils.unif(n_target)
11061105

1107-
# Create cost matrix
11081106
x_source = rng.randn(n_source, 2)
11091107
x_target = rng.randn(n_target, 2) + 0.5
11101108
C = ot.dist(x_source, x_target)
11111109

1112-
# Create sparse k-NN graph
11131110
rows = []
11141111
cols = []
11151112
data = []
@@ -1156,20 +1153,16 @@ def test_emd_sparse_backends(nx):
11561153
(data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)
11571154
)
11581155

1159-
# Test with numpy weights (baseline)
11601156
_, log_np = ot.emd(a, b, C_augmented, log=True)
11611157

1162-
# Test with backend weights
11631158
ab, bb = nx.from_numpy(a, b)
11641159
_, log_backend = ot.emd(ab, bb, C_augmented, log=True)
11651160

1166-
# Compare costs
11671161
cost_np = log_np["cost"]
11681162
cost_backend = nx.to_numpy(log_backend["cost"])
11691163

11701164
np.testing.assert_allclose(cost_np, cost_backend, rtol=1e-5, atol=1e-7)
11711165

1172-
# Check flow values match
11731166
np.testing.assert_allclose(
11741167
log_np["flow_values"], log_backend["flow_values"], rtol=1e-5, atol=1e-7
11751168
)
@@ -1186,16 +1179,13 @@ def test_emd2_sparse_backends(nx):
11861179

11871180
rng = np.random.RandomState(42)
11881181

1189-
# Create distributions
11901182
a = ot.utils.unif(n_source)
11911183
b = ot.utils.unif(n_target)
11921184

1193-
# Create cost matrix
11941185
x_source = rng.randn(n_source, 2)
11951186
x_target = rng.randn(n_target, 2) + 0.5
11961187
C = ot.dist(x_source, x_target)
11971188

1198-
# Create sparse k-NN graph
11991189
rows = []
12001190
cols = []
12011191
data = []
@@ -1242,14 +1232,11 @@ def test_emd2_sparse_backends(nx):
12421232
(data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)
12431233
)
12441234

1245-
# Test with numpy weights (baseline)
12461235
cost_np = ot.emd2(a, b, C_augmented)
12471236

1248-
# Test with backend weights
12491237
ab, bb = nx.from_numpy(a, b)
12501238
cost_backend = ot.emd2(ab, bb, C_augmented)
12511239

1252-
# Compare costs
12531240
cost_backend_np = nx.to_numpy(cost_backend)
12541241

12551242
np.testing.assert_allclose(cost_np, cost_backend_np, rtol=1e-5, atol=1e-7)

0 commit comments

Comments
 (0)