Skip to content

Commit 15e0d71

Browse files
committed
PR number, authot update, and small backend fix
1 parent 2f4b675 commit 15e0d71

File tree

6 files changed

+48
-2
lines changed

6 files changed

+48
-2
lines changed

.coveragerc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[run]
2+
omit =
3+
/tmp/*
4+
*/_remote_module_non_scriptable.py
5+
*/site-packages/*
6+
7+
[report]
8+
skip_covered = True

RELEASES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#### New features
66

7-
- Added Sliced OT plans (PR #757)
7+
- Added Sliced OT plans (PR #767)
88

99
## 0.9.6.post1
1010

coverage_help.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Example:
2+
3+
coverage run -m pytest test/test_ot.py
4+
coverage html --rcfile=.coveragerc

debug.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# %%
2+
import numpy as np
3+
from ot.backend import get_backend
4+
import torch
5+
import ot
6+
from torch.optim import Adam
7+
8+
9+
# %%
10+
rng = np.random.RandomState(0)
11+
n = 10
12+
d = 2
13+
X = rng.randn(n, d)
14+
Y = rng.randn(n, d) + np.array([5.0, 0.0])[None, :]
15+
n_proj = 20
16+
P = ot.sliced.get_random_projections(d, n_proj)
17+
a = rng.uniform(0, 1, n)
18+
a /= a.sum()
19+
b = rng.uniform(0, 1, n)
20+
b /= b.sum()
21+
sw2 = ot.sliced.sliced_wasserstein_distance(X, Y, a=a, b=b, projections=P)
22+
23+
# %%
24+
nx = get_backend(torch.tensor([0.0]))
25+
X_t = nx.from_numpy(X)
26+
Y_t = nx.from_numpy(Y)
27+
a_t = nx.from_numpy(a)
28+
b_t = nx.from_numpy(b)
29+
P_t = nx.from_numpy(P)
30+
sw2_t = ot.sliced.sliced_wasserstein_distance(X_t, Y_t, a=a_t, b=b_t, projections=P_t)
31+
32+
# %%

ot/sliced.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Nicolas Courty <ncourty@irisa.fr>
77
# Rémi Flamary <remi.flamary@polytechnique.edu>
88
# Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
9+
# Laetitia Chapel <laetitia.chapel@irisa.fr>
910
#
1011
# License: MIT License
1112

@@ -776,7 +777,7 @@ def sliced_plans(
776777
do_draw_thetas = thetas is None
777778
if do_draw_thetas: # create thetas (n_proj, d)
778779
assert n_proj is not None, "n_proj must be specified if thetas is None"
779-
thetas = get_random_projections(d, n_proj, backend=nx).T
780+
thetas = get_random_projections(d, n_proj, backend=nx, type_as=X).T
780781

781782
if warm_theta is not None:
782783
thetas = nx.concatenate([thetas, warm_theta[:, None].T], axis=0)

test/test_sliced.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
44
# Nicolas Courty <ncourty@irisa.fr>
55
# Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
6+
# Laetitia Chapel <laetitia.chapel@irisa.fr>
67
#
78
# License: MIT License
89

0 commit comments

Comments
 (0)