Skip to content

Commit 85056c6

Browse files
committed
update tests and doc
1 parent e591737 commit 85056c6

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

ot/sliced.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,6 @@ def sliced_plans(
746746
Returned only if `log` is True.
747747
"""
748748

749-
X, Y = list_to_array(X, Y)
750749
nx = get_backend(X, Y) if backend is None else backend
751750
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
752751
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
@@ -903,9 +902,9 @@ def min_pivot_sliced(
903902
The first set of vectors.
904903
Y : array-like, shape (m, d)
905904
The second set of vectors.
906-
a : ndarray of float64, shape (ns,), optional
905+
a : ndarray of float64, shape (n,), optional
907906
Source histogram (default is uniform weight)
908-
b : ndarray of float64, shape (nt,), optional
907+
b : ndarray of float64, shape (m,), optional
909908
Target histogram (default is uniform weight)
910909
thetas : array-like, shape (n_proj, d), optional
911910
The projection directions. If None, random directions will be generated
@@ -962,7 +961,6 @@ def min_pivot_sliced(
962961
2.125
963962
"""
964963

965-
X, Y = list_to_array(X, Y)
966964
nx = get_backend(X, Y) if backend is None else backend
967965
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
968966
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
@@ -987,7 +985,7 @@ def min_pivot_sliced(
987985
log=True,
988986
backend=nx,
989987
)
990-
pos_min = np.argmin(costs)
988+
pos_min = nx.argmin(costs)
991989
cost = costs[pos_min]
992990
plan = G[pos_min]
993991

@@ -1050,23 +1048,33 @@ def expected_sliced(
10501048
10511049
Parameters
10521050
----------
1053-
X : torch.Tensor
1054-
A tensor of shape (n, d) representing the first set of vectors.
1055-
Y : torch.Tensor
1056-
A tensor of shape (m, d) representing the second set of vectors.
1051+
X : array-like, shape (n, d)
1052+
The first set of vectors.
1053+
Y : array-like, shape (m, d)
1054+
The second set of vectors.
1055+
a : ndarray of float64, shape (n,), optional
1056+
Source histogram (default is uniform weight)
1057+
b : ndarray of float64, shape (m,), optional
1058+
Target histogram (default is uniform weight)
10571059
thetas : torch.Tensor, optional
10581060
A tensor of shape (n_proj, d) representing the projection directions.
10591061
If None, random directions will be generated. Default is None.
1062+
metric: str, optional (default='sqeuclidean')
1063+
Metric to be used. Only works with either of the strings
1064+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
1065+
p: float, optional (default=2)
1066+
The p-norm to apply for if metric='minkowski'
10601067
n_proj : int, optional
10611068
The number of projection directions. Required if thetas is None.
1062-
order : int, optional
1063-
Power to elevate the norm. Default is 2.
10641069
dense: boolean, optional (default=True)
10651070
If True, returns :math:`\gamma` as a dense ndarray of shape (n, m).
10661071
Otherwise returns a sparse representation using scipy's `coo_matrix`
10671072
format.
10681073
log : bool, optional
10691074
If True, returns additional logging information. Default is False.
1075+
backend : ot.backend, optional
1076+
Backend to use for computations. If None, the backend is inferred from
1077+
the input arrays. Default is None.
10701078
beta : float, optional
10711079
Inverse-temperature parameter which weights each projection's
10721080
contribution to the expected plan. Default is 0 (uniform weighting).
@@ -1102,7 +1110,6 @@ def expected_sliced(
11021110
2.625
11031111
"""
11041112

1105-
X, Y = list_to_array(X, Y)
11061113
nx = get_backend(X, Y) if backend is None else backend
11071114

11081115
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"

test/test_sliced.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,8 +776,10 @@ def test_sliced_plans(nx):
776776
b /= b.sum()
777777

778778
x_b, y_b = nx.from_numpy(x, y)
779+
print(x_b)
780+
t_X = torch.tensor(x_b)
781+
t_Y = torch.tensor(y_b)
779782
thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T
780-
print("et là ???", thetas.shape)
781783
thetas_b = nx.from_numpy(thetas)
782784

783785
# test with the minkowski metric

0 commit comments

Comments
 (0)