Skip to content

Commit 58f7e11

Browse files
committed
update sliced plans with sparse matrix for tf compatibility
1 parent 2f4b675 commit 58f7e11

File tree

2 files changed

+61
-21
lines changed

2 files changed

+61
-21
lines changed

ot/sliced.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,12 @@ def sliced_plans(
756756
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
757757
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
758758

759+
assert metric in ("minkowski", "euclidean", "cityblock", "sqeuclidean"), (
760+
"Sliced plans work only with metrics "
761+
+ "from the following list: "
762+
+ "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
763+
)
764+
759765
assert (
760766
X.shape[1] == Y.shape[1]
761767
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
@@ -776,7 +782,7 @@ def sliced_plans(
776782
do_draw_thetas = thetas is None
777783
if do_draw_thetas: # create thetas (n_proj, d)
778784
assert n_proj is not None, "n_proj must be specified if thetas is None"
779-
thetas = get_random_projections(d, n_proj, backend=nx).T
785+
thetas = get_random_projections(d, n_proj, backend=nx, type_as=X).T
780786

781787
if warm_theta is not None:
782788
thetas = nx.concatenate([thetas, warm_theta[:, None].T], axis=0)
@@ -787,8 +793,7 @@ def sliced_plans(
787793
X_theta = X @ thetas.T # shape (n, n_proj)
788794
Y_theta = Y @ thetas.T # shape (m, n_proj)
789795

790-
if is_perm:
791-
# we compute maps (permutations)
796+
if is_perm: # we compute maps (permutations)
792797
# sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj]
793798
sigma = nx.argsort(X_theta, axis=0) # (n, n_proj)
794799
tau = nx.argsort(Y_theta, axis=0) # (m, n_proj)
@@ -803,17 +808,12 @@ def sliced_plans(
803808
)
804809
for k in range(n_proj)
805810
]
806-
elif metric == "sqeuclidean":
811+
else: # metric = "sqeuclidean":
807812
costs = [
808813
nx.sum((nx.sum((X[sigma[:, k]] - Y[tau[:, k]]) ** 2, axis=1)) / n)
809814
for k in range(n_proj)
810815
]
811-
else:
812-
raise ValueError(
813-
"Sliced plans work only with metrics "
814-
+ "from the following list: "
815-
+ "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
816-
)
816+
817817
a = nx.ones(n) / n
818818
plan = [
819819
nx.coo_matrix(a, sigma[:, k], tau[:, k], shape=(n, m), type_as=a)
@@ -825,6 +825,35 @@ def sliced_plans(
825825
X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True
826826
)
827827

828+
plan = plan.tocsr().tocoo() # especially for tensorflow compatibility
829+
830+
if str(nx) == "jax":
831+
plan = [nx.todense(plan[k]) for k in range(n_proj)]
832+
if not dense:
833+
warnings.warn(
834+
"JAX does not support sparse matrices, converting" "to dense"
835+
)
836+
837+
costs = [
838+
nx.sum(
839+
(
840+
(
841+
nx.sum(
842+
nx.abs(
843+
X[np.nonzero(plan[k])[0]]
844+
- Y[np.nonzero(plan[k])[1]]
845+
)
846+
** p,
847+
axis=1,
848+
)
849+
)
850+
** (1 / p)
851+
)
852+
* plan[np.nonzero(plan[k])]
853+
)
854+
for k in range(n_proj)
855+
]
856+
828857
if metric in ("minkowski", "euclidean", "cityblock"):
829858
costs = [
830859
nx.sum(
@@ -836,26 +865,17 @@ def sliced_plans(
836865
)
837866
for k in range(n_proj)
838867
]
839-
elif metric == "sqeuclidean":
868+
else: # metric = "sqeuclidean"
840869
costs = [
841870
nx.sum(
842871
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
843872
* plan[k].data
844873
)
845874
for k in range(n_proj)
846875
]
847-
else:
848-
raise ValueError(
849-
"Sliced plans work only with metrics "
850-
+ "from the following list: "
851-
+ "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
852-
)
853876

854877
if dense:
855878
plan = [nx.todense(plan[k]) for k in range(n_proj)]
856-
elif str(nx) == "jax":
857-
warnings.warn("JAX does not support sparse matrices, converting to dense")
858-
plan = [nx.todense(plan[k]) for k in range(n_proj)]
859879

860880
if log:
861881
log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas}

test/test_sliced.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,24 @@ def test_sliced_plans():
782782
# test with a warm theta
783783
ot.sliced.sliced_plans(x, y, n_proj=10, warm_theta=thetas[-1])
784784

785+
# test permutations
786+
n = 5
787+
m = 5
788+
n_proj = 10
789+
d = 2
790+
rng = np.random.RandomState(0)
791+
792+
x = rng.randn(n, 2)
793+
y = rng.randn(m, 2)
794+
795+
a = rng.uniform(0, 1, n)
796+
a /= a.sum()
797+
b = rng.uniform(0, 1, m)
798+
b /= b.sum()
799+
800+
# test with the minkowski metric
801+
ot.sliced.sliced_plans(x, y, metric="minkowski")
802+
785803

786804
def test_min_pivot_sliced():
787805
x = [1, 2]
@@ -924,7 +942,9 @@ def test_sliced_plans_backends(nx):
924942

925943
x_b, y_b, a_b, b_b = nx.from_numpy(x, y, a, b)
926944

927-
thetas_b = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T
945+
thetas_b = ot.sliced.get_random_projections(
946+
d, n_proj, seed=0, backend=nx, type_as=x
947+
).T
928948
thetas = nx.to_numpy(thetas_b)
929949

930950
context = (

0 commit comments

Comments
 (0)