@@ -691,8 +691,9 @@ def sliced_plans(
691691 metric = "sqeuclidean" ,
692692 p = 2 ,
693693 thetas = None ,
694- warm_theta = False ,
694+ warm_theta = None ,
695695 n_proj = None ,
696+ dense = False ,
696697 log = False ,
697698 backend = None ,
698699):
@@ -723,6 +724,9 @@ def sliced_plans(
723724 Default is None.
724725 warm_theta : array-like, shape (d,), optional
725726 A direction to add to the set of directions. Default is None.
727+ dense: bool, optional
728+ If True, returns dense matrices instead of sparse ones.
729+ Default is False.
726730 n_proj : int, optional
727731 The number of projection directions. Required if thetas is None.
728732 log : bool, optional
@@ -733,18 +737,19 @@ def sliced_plans(
733737
734738 Returns
735739 -------
736- G, sigma, tau, costs
737- G: ndarray, shape (ns, nt) or coo_matrix if dense is False
740+ plan : ndarray, shape (ns, nt) or coo_matrix if dense is False
738741 Optimal transportation matrix for the given parameters
739- sigma : list of elements of array-like
740- All the indices of X sorted along each projection.
741- tau : list of elements of array-like
742- All the indices of Y sorted along each projection.
742+ costs : list of float
743+ The cost associated to each projection.
743744 log_dict : dict, optional
744745 A dictionary containing intermediate computations for logging purposes.
745746 Returned only if `log` is True.
746747 """
747748
749+ X , Y = list_to_array (X , Y )
750+ assert X .ndim == 2 , f"X must be a 2d array, got { X .ndim } d array instead"
751+ assert Y .ndim == 2 , f"Y must be a 2d array, got { Y .ndim } d array instead"
752+
748753 assert (
749754 X .shape [1 ] == Y .shape [1 ]
750755 ), f"X ({ X .shape } ) and Y ({ Y .shape } ) must have the same number of columns"
@@ -758,6 +763,11 @@ def sliced_plans(
758763 m = Y .shape [0 ]
759764 nx = get_backend (X , Y ) if backend is None else backend
760765
766+ is_perm = False
767+ if n == m :
768+ if a is None or b is None or (a == b ).all ():
769+ is_perm = True
770+
761771 do_draw_thetas = thetas is None
762772 if do_draw_thetas : # create thetas (n_proj, d)
763773 assert n_proj is not None , "n_proj must be specified if thetas is None"
@@ -771,12 +781,11 @@ def sliced_plans(
771781 X_theta = X @ thetas .T # shape (n, n_proj)
772782 Y_theta = Y @ thetas .T # shape (m, n_proj)
773783
774- if n == m and ( a is None or b is None or ( a == b ). all ()) :
784+ if is_perm :
775785 # we compute maps (permutations)
776786 # sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj]
777787 sigma = nx .argsort (X_theta , axis = 0 ) # (n, n_proj)
778788 tau = nx .argsort (Y_theta , axis = 0 ) # (m, n_proj)
779-
780789 if metric in ("minkowski" , "euclidean" , "cityblock" ):
781790 costs = [
782791 nx .sum (
@@ -799,37 +808,33 @@ def sliced_plans(
799808 + "from the following list: "
800809 + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
801810 )
802-
803- G = [
804- nx .coo_matrix (
805- np .ones (n ) / n ,
806- sigma [:, k ],
807- tau [:, k ],
808- shape = (n , m ),
809- type_as = X_theta ,
810- )
811+ plan = [
812+ nx .coo_matrix (np .ones (n ) / n , sigma [:, k ], tau [:, k ], shape = (n , m ))
811813 for k in range (n_proj )
812814 ]
813815
814816 else : # we compute plans
815- _ , G = wasserstein_1d (
817+ _ , plan = wasserstein_1d (
816818 X_theta , Y_theta , a , b , p , require_sort = True , return_plan = True
817819 )
818820
819821 if metric in ("minkowski" , "euclidean" , "cityblock" ):
820822 costs = [
821823 nx .sum (
822824 (
823- (nx .sum (nx .abs (X [G [k ].row ] - Y [G [k ].col ]) ** p , axis = 1 ))
825+ (nx .sum (nx .abs (X [plan [k ].row ] - Y [plan [k ].col ]) ** p , axis = 1 ))
824826 ** (1 / p )
825827 )
826- * G [k ].data
828+ * plan [k ].data
827829 )
828830 for k in range (n_proj )
829831 ]
830832 elif metric == "sqeuclidean" :
831833 costs = [
832- nx .sum ((nx .sum ((X [G [k ].row ] - Y [G [k ].col ]) ** 2 , axis = 1 )) * G [k ].data )
834+ nx .sum (
835+ (nx .sum ((X [plan [k ].row ] - Y [plan [k ].col ]) ** 2 , axis = 1 ))
836+ * plan [k ].data
837+ )
833838 for k in range (n_proj )
834839 ]
835840 else :
@@ -839,11 +844,17 @@ def sliced_plans(
839844 + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
840845 )
841846
847+ if dense :
848+ plan = [nx .todense (plan [k ]) for k in range (n_proj )]
849+ elif str (nx ) == "jax" :
850+ warnings .warn ("JAX does not support sparse matrices, converting to dense" )
851+ plan = [nx .todense (plan [k ]) for k in range (n_proj )]
852+
842853 if log :
843854 log_dict = {"X_theta" : X_theta , "Y_theta" : Y_theta , "thetas" : thetas }
844- return costs , G , log_dict
855+ return plan , costs , log_dict
845856 else :
846- return costs , G
857+ return plan , costs
847858
848859
849860def min_pivot_sliced (
@@ -863,11 +874,11 @@ def min_pivot_sliced(
863874 r"""
864875 Computes the cost and permutation associated to the min-Pivot Sliced
865876 Discrepancy (introduced as SWGG in [82] and studied further in [83]). Given
866- the supports `X` and `Y` of two discrete uniform measures with `n` atoms in
867- dimension `d`, the min-Pivot Sliced Discrepancy goes through `n_proj`
868- different projections of the measures on random directions, and retains the
869- permutation that yields the lowest cost between `X` and `Y` (compared
870- in :math:`\mathbb{R}^d`).
877+ the supports `X` and `Y` of two discrete uniform measures with `n` and `m`
878+ atoms in dimension `d`, the min-Pivot Sliced Discrepancy goes through
879+ `n_proj` different projections of the measures on random directions, and
880+ retains the couplings that yields the lowest cost between `X` and `Y`
881+ (compared in :math:`\mathbb{R}^d`). When $n=m$, it gives
871882
872883 .. math::
873884 \mathrm{min\text{-}PS}_p^p(X, Y) \approx
@@ -888,7 +899,7 @@ def min_pivot_sliced(
888899 ----------
889900 X : array-like, shape (n, d)
890901 The first set of vectors.
891- Y : array-like, shape (n , d)
902+ Y : array-like, shape (m , d)
892903 The second set of vectors.
893904 a : ndarray of float64, shape (ns,), optional
894905 Source histogram (default is uniform weight)
@@ -918,10 +929,10 @@ def min_pivot_sliced(
918929
919930 Returns
920931 -------
921- perm : array-like , shape (n,)
922- The permutation that minimizes the cost .
923- min_cost : float
924- The minimum cost corresponding to the optimal permutation.
932+ plan : ndarray , shape (n, m) or coo_matrix if dense is False
933+ Optimal transportation matrix for the given parameters .
934+ cost : float
935+ The cost associated to the optimal permutation.
925936 log_dict : dict, optional
926937 A dictionary containing intermediate computations for logging purposes.
927938 Returned only if `log` is True.
@@ -935,16 +946,32 @@ def min_pivot_sliced(
935946
936947 .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport
937948 Plans. arXiv preprint 2506.03661.
949+
950+ Examples
951+ --------
952+ >>> x=np.array([[3,3], [1,1]])
953+ >>> y=np.array([[2,2.5], [3,2]])
954+ >>> thetas=np.array([[1, 0], [0, 1]])
955+ >>> plan, cost = ot.expected_sliced(x, y, thetas)
956+ >>> plan
957+ [[0 0.5]
958+ [0.5 0]]
959+ >>> cost
960+ 2.125
938961 """
939962
963+ X , Y = list_to_array (X , Y )
964+ assert X .ndim == 2 , f"X must be a 2d array, got { X .ndim } d array instead"
965+ assert Y .ndim == 2 , f"Y must be a 2d array, got { Y .ndim } d array instead"
966+
940967 assert (
941968 X .shape [1 ] == Y .shape [1 ]
942969 ), f"X ({ X .shape } ) and Y ({ Y .shape } ) must have the same number of columns"
943970
944971 nx = get_backend (X , Y ) if backend is None else backend
945972
946973 log_dict = {}
947- costs , G , log_dict_plans = sliced_plans (
974+ G , costs , log_dict_plans = sliced_plans (
948975 X ,
949976 Y ,
950977 a ,
@@ -1000,11 +1027,12 @@ def expected_sliced(
10001027 beta = 0.0 ,
10011028):
10021029 r"""
1003- Computes the Expected Sliced cost and plan between two `(n, d)`
1004- datasets `X` and `Y`. Given a set of `n_proj` projection directions,
1005- the expected sliced plan is obtained by averaging the `n_proj` 1d optimal
1006- transport plans between the projections of `X` and `Y` on each direction.
1007- Expected Sliced was introduced in [84] and further studied in [83].
1030+ Computes the Expected Sliced cost and plan between two datasets `X` and
1031+ `Y` of shapes `(n, d)` and `(m, d)`. Given a set of `n_proj` projection
1032+ directions, the expected sliced plan is obtained by averaging the `n_proj`
1033+ 1d optimal transport plans between the projections of `X` and `Y` on each
1034+ direction. Expected Sliced was introduced in [84] and further studied in
1035+ [83].
10081036
10091037 .. note::
10101038 The computation ignores potential ambiguities in the projections: if
@@ -1020,9 +1048,9 @@ def expected_sliced(
10201048 Parameters
10211049 ----------
10221050 X : torch.Tensor
1023- A tensor of shape (ns , d) representing the first set of vectors.
1051+ A tensor of shape (n , d) representing the first set of vectors.
10241052 Y : torch.Tensor
1025- A tensor of shape (nt , d) representing the second set of vectors.
1053+ A tensor of shape (m , d) representing the second set of vectors.
10261054 thetas : torch.Tensor, optional
10271055 A tensor of shape (n_proj, d) representing the projection directions.
10281056 If None, random directions will be generated. Default is None.
@@ -1031,7 +1059,7 @@ def expected_sliced(
10311059 order : int, optional
10321060 Power to elevate the norm. Default is 2.
10331061 dense: boolean, optional (default=True)
1034- If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt ).
1062+ If True, returns :math:`\gamma` as a dense ndarray of shape (n, m ).
10351063 Otherwise returns a sparse representation using scipy's `coo_matrix`
10361064 format.
10371065 log : bool, optional
@@ -1042,21 +1070,40 @@ def expected_sliced(
10421070
10431071 Returns
10441072 -------
1045- plan : torch.Tensor
1046- A tensor of shape (n_proj, n, n) representing the expected sliced plan.
1073+ plan : ndarray, shape (n, m) or coo_matrix if dense is False
1074+ Optimal transportation matrix for the given parameters.
1075+ cost : float
1076+ The cost associated to the optimal permutation.
10471077 log_dict : dict, optional
10481078 A dictionary containing intermediate computations for logging purposes.
10491079 Returned only if `log` is True.
10501080
10511081 References
10521082 ----------
10531083 .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport
1054- Plans. arXiv preprint 2506.03661.
1055-
1084+ Plans. arXiv preprint 2506.03661.
10561085 .. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi
1057- A., Kolouri, S. (2024). Expected Sliced Transport Plans. International
1058- Conference on Learning Representations.
1086+ A., Kolouri, S. (2024). Expected Sliced Transport Plans.
1087+ International Conference on Learning Representations.
1088+
1089+ Examples
1090+ --------
1091+ >>> x=np.array([[3,3], [1,1]])
1092+ >>> y=np.array([[2,2.5], [3,2]])
1093+ >>> thetas=np.array([[1, 0], [0, 1]])
1094+ >>> plan, cost = ot.expected_sliced(x, y, thetas)
1095+ >>> plan
1096+ [[0.25 0.25]
1097+ [0.25 0.25]]
1098+ >>> cost
1099+ 2.625
10591100 """
1101+
1102+ X , Y = list_to_array (X , Y )
1103+
1104+ assert X .ndim == 2 , f"X must be a 2d array, got { X .ndim } d array instead"
1105+ assert Y .ndim == 2 , f"Y must be a 2d array, got { Y .ndim } d array instead"
1106+
10601107 assert (
10611108 X .shape [1 ] == Y .shape [1 ]
10621109 ), f"X ({ X .shape } ) and Y ({ Y .shape } ) must have the same number of columns"
@@ -1069,11 +1116,11 @@ def expected_sliced(
10691116 "to array assignment."
10701117 )
10711118
1072- ns = X .shape [0 ]
1073- nt = Y .shape [0 ]
1119+ n = X .shape [0 ]
1120+ m = Y .shape [0 ]
10741121
10751122 log_dict = {}
1076- costs , G , log_dict_plans = sliced_plans (
1123+ G , costs , log_dict_plans = sliced_plans (
10771124 X , Y , a , b , metric , p , thetas , n_proj = n_proj , log = True , backend = nx
10781125 )
10791126 if log :
@@ -1087,31 +1134,22 @@ def expected_sliced(
10871134 else : # uniform weights
10881135 if n_proj is None :
10891136 n_proj = thetas .shape [0 ]
1090- weights = nx .ones (n_proj , type_as = X ) / n_proj
1137+ weights = nx .ones (n_proj ) / n_proj
10911138
10921139 log_dict ["weights" ] = weights
10931140
10941141 weights = nx .concatenate ([G [i ].data * weights [i ] for i in range (len (G ))])
10951142 X_idx = nx .concatenate ([G [i ].row for i in range (len (G ))])
10961143 Y_idx = nx .concatenate ([G [i ].col for i in range (len (G ))])
1097- plan = nx .coo_matrix (
1098- weights ,
1099- X_idx ,
1100- Y_idx ,
1101- shape = (ns , nt ),
1102- type_as = X ,
1103- )
1144+ plan = nx .coo_matrix (weights , X_idx , Y_idx , shape = (n , m ))
11041145
11051146 if beta == 0.0 : # otherwise already computed above
11061147 cost = plan .multiply (dist (X , Y , metric = metric , p = p )).sum ()
11071148
11081149 if dense :
11091150 plan = nx .todense (plan )
11101151 elif str (nx ) == "jax" :
1111- warnings .warn (
1112- "JAX does not support sparse matrices, converting to\
1113- dense"
1114- )
1152+ warnings .warn ("JAX does not support sparse matrices, converting to dense" )
11151153 plan = nx .todense (plan )
11161154
11171155 if log :
0 commit comments