Skip to content

Commit 83b653b

Browse files
committed
update tests and doc
1 parent 86749e0 commit 83b653b

File tree

3 files changed

+182
-86
lines changed

3 files changed

+182
-86
lines changed

examples/sliced-wasserstein/plot_sliced_plans.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import ot
2929
import matplotlib.pyplot as plt
3030
from ot.sliced import get_random_projections
31-
from ot.lp import wasserstein_1d
3231

3332

3433
seed = 0

ot/sliced.py

Lines changed: 102 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,9 @@ def sliced_plans(
691691
metric="sqeuclidean",
692692
p=2,
693693
thetas=None,
694-
warm_theta=False,
694+
warm_theta=None,
695695
n_proj=None,
696+
dense=False,
696697
log=False,
697698
backend=None,
698699
):
@@ -723,6 +724,9 @@ def sliced_plans(
723724
Default is None.
724725
warm_theta : array-like, shape (d,), optional
725726
A direction to add to the set of directions. Default is None.
727+
dense : bool, optional
728+
If True, returns dense matrices instead of sparse ones.
729+
Default is False.
726730
n_proj : int, optional
727731
The number of projection directions. Required if thetas is None.
728732
log : bool, optional
@@ -733,18 +737,19 @@ def sliced_plans(
733737
734738
Returns
735739
-------
736-
G, sigma, tau, costs
737-
G: ndarray, shape (ns, nt) or coo_matrix if dense is False
740+
plan : ndarray, shape (ns, nt) or coo_matrix if dense is False
738741
Optimal transportation matrix for the given parameters
739-
sigma : list of elements of array-like
740-
All the indices of X sorted along each projection.
741-
tau : list of elements of array-like
742-
All the indices of Y sorted along each projection.
742+
costs : list of float
743+
The cost associated to each projection.
743744
log_dict : dict, optional
744745
A dictionary containing intermediate computations for logging purposes.
745746
Returned only if `log` is True.
746747
"""
747748

749+
X, Y = list_to_array(X, Y)
750+
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
751+
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
752+
748753
assert (
749754
X.shape[1] == Y.shape[1]
750755
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
@@ -758,6 +763,11 @@ def sliced_plans(
758763
m = Y.shape[0]
759764
nx = get_backend(X, Y) if backend is None else backend
760765

766+
is_perm = False
767+
if n == m:
768+
if a is None or b is None or (a == b).all():
769+
is_perm = True
770+
761771
do_draw_thetas = thetas is None
762772
if do_draw_thetas: # create thetas (n_proj, d)
763773
assert n_proj is not None, "n_proj must be specified if thetas is None"
@@ -771,12 +781,11 @@ def sliced_plans(
771781
X_theta = X @ thetas.T # shape (n, n_proj)
772782
Y_theta = Y @ thetas.T # shape (m, n_proj)
773783

774-
if n == m and (a is None or b is None or (a == b).all()):
784+
if is_perm:
775785
# we compute maps (permutations)
776786
# sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj]
777787
sigma = nx.argsort(X_theta, axis=0) # (n, n_proj)
778788
tau = nx.argsort(Y_theta, axis=0) # (m, n_proj)
779-
780789
if metric in ("minkowski", "euclidean", "cityblock"):
781790
costs = [
782791
nx.sum(
@@ -799,37 +808,33 @@ def sliced_plans(
799808
+ "from the following list: "
800809
+ "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
801810
)
802-
803-
G = [
804-
nx.coo_matrix(
805-
np.ones(n) / n,
806-
sigma[:, k],
807-
tau[:, k],
808-
shape=(n, m),
809-
type_as=X_theta,
810-
)
811+
plan = [
812+
nx.coo_matrix(np.ones(n) / n, sigma[:, k], tau[:, k], shape=(n, m))
811813
for k in range(n_proj)
812814
]
813815

814816
else: # we compute plans
815-
_, G = wasserstein_1d(
817+
_, plan = wasserstein_1d(
816818
X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True
817819
)
818820

819821
if metric in ("minkowski", "euclidean", "cityblock"):
820822
costs = [
821823
nx.sum(
822824
(
823-
(nx.sum(nx.abs(X[G[k].row] - Y[G[k].col]) ** p, axis=1))
825+
(nx.sum(nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1))
824826
** (1 / p)
825827
)
826-
* G[k].data
828+
* plan[k].data
827829
)
828830
for k in range(n_proj)
829831
]
830832
elif metric == "sqeuclidean":
831833
costs = [
832-
nx.sum((nx.sum((X[G[k].row] - Y[G[k].col]) ** 2, axis=1)) * G[k].data)
834+
nx.sum(
835+
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
836+
* plan[k].data
837+
)
833838
for k in range(n_proj)
834839
]
835840
else:
@@ -839,11 +844,17 @@ def sliced_plans(
839844
+ "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
840845
)
841846

847+
if dense:
848+
plan = [nx.todense(plan[k]) for k in range(n_proj)]
849+
elif str(nx) == "jax":
850+
warnings.warn("JAX does not support sparse matrices, converting to dense")
851+
plan = [nx.todense(plan[k]) for k in range(n_proj)]
852+
842853
if log:
843854
log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas}
844-
return costs, G, log_dict
855+
return plan, costs, log_dict
845856
else:
846-
return costs, G
857+
return plan, costs
847858

848859

849860
def min_pivot_sliced(
@@ -863,11 +874,11 @@ def min_pivot_sliced(
863874
r"""
864875
Computes the cost and permutation associated to the min-Pivot Sliced
865876
Discrepancy (introduced as SWGG in [82] and studied further in [83]). Given
866-
the supports `X` and `Y` of two discrete uniform measures with `n` atoms in
867-
dimension `d`, the min-Pivot Sliced Discrepancy goes through `n_proj`
868-
different projections of the measures on random directions, and retains the
869-
permutation that yields the lowest cost between `X` and `Y` (compared
870-
in :math:`\mathbb{R}^d`).
877+
the supports `X` and `Y` of two discrete uniform measures with `n` and `m`
878+
atoms in dimension `d`, the min-Pivot Sliced Discrepancy goes through
879+
`n_proj` different projections of the measures on random directions, and
880+
retains the coupling that yields the lowest cost between `X` and `Y`
881+
(compared in :math:`\mathbb{R}^d`). When :math:`n = m`, it gives
871882
872883
.. math::
873884
\mathrm{min\text{-}PS}_p^p(X, Y) \approx
@@ -888,7 +899,7 @@ def min_pivot_sliced(
888899
----------
889900
X : array-like, shape (n, d)
890901
The first set of vectors.
891-
Y : array-like, shape (n, d)
902+
Y : array-like, shape (m, d)
892903
The second set of vectors.
893904
a : ndarray of float64, shape (ns,), optional
894905
Source histogram (default is uniform weight)
@@ -918,10 +929,10 @@ def min_pivot_sliced(
918929
919930
Returns
920931
-------
921-
perm : array-like, shape (n,)
922-
The permutation that minimizes the cost.
923-
min_cost : float
924-
The minimum cost corresponding to the optimal permutation.
932+
plan : ndarray, shape (n, m) or coo_matrix if dense is False
933+
Optimal transportation matrix for the given parameters.
934+
cost : float
935+
The cost associated to the optimal permutation.
925936
log_dict : dict, optional
926937
A dictionary containing intermediate computations for logging purposes.
927938
Returned only if `log` is True.
@@ -935,16 +946,32 @@ def min_pivot_sliced(
935946
936947
.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport
937948
Plans. arXiv preprint 2506.03661.
949+
950+
Examples
951+
--------
952+
>>> x=np.array([[3,3], [1,1]])
953+
>>> y=np.array([[2,2.5], [3,2]])
954+
>>> thetas=np.array([[1, 0], [0, 1]])
955+
>>> plan, cost = ot.min_pivot_sliced(x, y, thetas=thetas)
956+
>>> plan
957+
[[0 0.5]
958+
[0.5 0]]
959+
>>> cost
960+
2.125
938961
"""
939962

963+
X, Y = list_to_array(X, Y)
964+
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
965+
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
966+
940967
assert (
941968
X.shape[1] == Y.shape[1]
942969
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
943970

944971
nx = get_backend(X, Y) if backend is None else backend
945972

946973
log_dict = {}
947-
costs, G, log_dict_plans = sliced_plans(
974+
G, costs, log_dict_plans = sliced_plans(
948975
X,
949976
Y,
950977
a,
@@ -1000,11 +1027,12 @@ def expected_sliced(
10001027
beta=0.0,
10011028
):
10021029
r"""
1003-
Computes the Expected Sliced cost and plan between two `(n, d)`
1004-
datasets `X` and `Y`. Given a set of `n_proj` projection directions,
1005-
the expected sliced plan is obtained by averaging the `n_proj` 1d optimal
1006-
transport plans between the projections of `X` and `Y` on each direction.
1007-
Expected Sliced was introduced in [84] and further studied in [83].
1030+
Computes the Expected Sliced cost and plan between two datasets `X` and
1031+
`Y` of shapes `(n, d)` and `(m, d)`. Given a set of `n_proj` projection
1032+
directions, the expected sliced plan is obtained by averaging the `n_proj`
1033+
1d optimal transport plans between the projections of `X` and `Y` on each
1034+
direction. Expected Sliced was introduced in [84] and further studied in
1035+
[83].
10081036
10091037
.. note::
10101038
The computation ignores potential ambiguities in the projections: if
@@ -1020,9 +1048,9 @@ def expected_sliced(
10201048
Parameters
10211049
----------
10221050
X : torch.Tensor
1023-
A tensor of shape (ns, d) representing the first set of vectors.
1051+
A tensor of shape (n, d) representing the first set of vectors.
10241052
Y : torch.Tensor
1025-
A tensor of shape (nt, d) representing the second set of vectors.
1053+
A tensor of shape (m, d) representing the second set of vectors.
10261054
thetas : torch.Tensor, optional
10271055
A tensor of shape (n_proj, d) representing the projection directions.
10281056
If None, random directions will be generated. Default is None.
@@ -1031,7 +1059,7 @@ def expected_sliced(
10311059
order : int, optional
10321060
Power to elevate the norm. Default is 2.
10331061
dense: boolean, optional (default=True)
1034-
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
1062+
If True, returns :math:`\gamma` as a dense ndarray of shape (n, m).
10351063
Otherwise returns a sparse representation using scipy's `coo_matrix`
10361064
format.
10371065
log : bool, optional
@@ -1042,21 +1070,40 @@ def expected_sliced(
10421070
10431071
Returns
10441072
-------
1045-
plan : torch.Tensor
1046-
A tensor of shape (n_proj, n, n) representing the expected sliced plan.
1073+
plan : ndarray, shape (n, m) or coo_matrix if dense is False
1074+
The expected sliced transportation matrix for the given parameters.
1075+
cost : float
1076+
The cost associated with the expected sliced plan.
10471077
log_dict : dict, optional
10481078
A dictionary containing intermediate computations for logging purposes.
10491079
Returned only if `log` is True.
10501080
10511081
References
10521082
----------
10531083
.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport
1054-
Plans. arXiv preprint 2506.03661.
1055-
1084+
Plans. arXiv preprint 2506.03661.
10561085
.. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi
1057-
A., Kolouri, S. (2024). Expected Sliced Transport Plans. International
1058-
Conference on Learning Representations.
1086+
A., Kolouri, S. (2024). Expected Sliced Transport Plans.
1087+
International Conference on Learning Representations.
1088+
1089+
Examples
1090+
--------
1091+
>>> x=np.array([[3,3], [1,1]])
1092+
>>> y=np.array([[2,2.5], [3,2]])
1093+
>>> thetas=np.array([[1, 0], [0, 1]])
1094+
>>> plan, cost = ot.expected_sliced(x, y, thetas)
1095+
>>> plan
1096+
[[0.25 0.25]
1097+
[0.25 0.25]]
1098+
>>> cost
1099+
2.625
10591100
"""
1101+
1102+
X, Y = list_to_array(X, Y)
1103+
1104+
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
1105+
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
1106+
10601107
assert (
10611108
X.shape[1] == Y.shape[1]
10621109
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
@@ -1069,11 +1116,11 @@ def expected_sliced(
10691116
"to array assignment."
10701117
)
10711118

1072-
ns = X.shape[0]
1073-
nt = Y.shape[0]
1119+
n = X.shape[0]
1120+
m = Y.shape[0]
10741121

10751122
log_dict = {}
1076-
costs, G, log_dict_plans = sliced_plans(
1123+
G, costs, log_dict_plans = sliced_plans(
10771124
X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True, backend=nx
10781125
)
10791126
if log:
@@ -1087,31 +1134,22 @@ def expected_sliced(
10871134
else: # uniform weights
10881135
if n_proj is None:
10891136
n_proj = thetas.shape[0]
1090-
weights = nx.ones(n_proj, type_as=X) / n_proj
1137+
weights = nx.ones(n_proj) / n_proj
10911138

10921139
log_dict["weights"] = weights
10931140

10941141
weights = nx.concatenate([G[i].data * weights[i] for i in range(len(G))])
10951142
X_idx = nx.concatenate([G[i].row for i in range(len(G))])
10961143
Y_idx = nx.concatenate([G[i].col for i in range(len(G))])
1097-
plan = nx.coo_matrix(
1098-
weights,
1099-
X_idx,
1100-
Y_idx,
1101-
shape=(ns, nt),
1102-
type_as=X,
1103-
)
1144+
plan = nx.coo_matrix(weights, X_idx, Y_idx, shape=(n, m))
11041145

11051146
if beta == 0.0: # otherwise already computed above
11061147
cost = plan.multiply(dist(X, Y, metric=metric, p=p)).sum()
11071148

11081149
if dense:
11091150
plan = nx.todense(plan)
11101151
elif str(nx) == "jax":
1111-
warnings.warn(
1112-
"JAX does not support sparse matrices, converting to\
1113-
dense"
1114-
)
1152+
warnings.warn("JAX does not support sparse matrices, converting to dense")
11151153
plan = nx.todense(plan)
11161154

11171155
if log:

0 commit comments

Comments
 (0)