@@ -746,7 +746,6 @@ def sliced_plans(
746746 Returned only if `log` is True.
747747 """
748748
749- X , Y = list_to_array (X , Y )
750749 nx = get_backend (X , Y ) if backend is None else backend
751750 assert X .ndim == 2 , f"X must be a 2d array, got { X .ndim } d array instead"
752751 assert Y .ndim == 2 , f"Y must be a 2d array, got { Y .ndim } d array instead"
@@ -903,9 +902,9 @@ def min_pivot_sliced(
903902 The first set of vectors.
904903 Y : array-like, shape (m, d)
905904 The second set of vectors.
906- a : ndarray of float64, shape (ns ,), optional
905+ a : ndarray of float64, shape (n ,), optional
907906 Source histogram (default is uniform weight)
908- b : ndarray of float64, shape (nt ,), optional
907+ b : ndarray of float64, shape (m ,), optional
909908 Target histogram (default is uniform weight)
910909 thetas : array-like, shape (n_proj, d), optional
911910 The projection directions. If None, random directions will be generated
@@ -962,7 +961,6 @@ def min_pivot_sliced(
962961 2.125
963962 """
964963
965- X , Y = list_to_array (X , Y )
966964 nx = get_backend (X , Y ) if backend is None else backend
967965 assert X .ndim == 2 , f"X must be a 2d array, got { X .ndim } d array instead"
968966 assert Y .ndim == 2 , f"Y must be a 2d array, got { Y .ndim } d array instead"
@@ -987,7 +985,7 @@ def min_pivot_sliced(
987985 log = True ,
988986 backend = nx ,
989987 )
990- pos_min = np .argmin (costs )
988+ pos_min = nx .argmin (costs )
991989 cost = costs [pos_min ]
992990 plan = G [pos_min ]
993991
@@ -1050,23 +1048,33 @@ def expected_sliced(
10501048
10511049 Parameters
10521050 ----------
1053- X : torch.Tensor
1054- A tensor of shape (n, d) representing the first set of vectors.
1055- Y : torch.Tensor
1056- A tensor of shape (m, d) representing the second set of vectors.
1051+ X : array-like, shape (n, d)
1052+ The first set of vectors.
1053+ Y : array-like, shape (m, d)
1054+ The second set of vectors.
1055+ a : ndarray of float64, shape (n,), optional
1056+ Source histogram (default is uniform weight)
1057+ b : ndarray of float64, shape (m,), optional
1058+ Target histogram (default is uniform weight)
10571059 thetas : torch.Tensor, optional
10581060 A tensor of shape (n_proj, d) representing the projection directions.
10591061 If None, random directions will be generated. Default is None.
1062+ metric: str, optional (default='sqeuclidean')
1063+ Metric to be used. Only works with either of the strings
1064+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
1065+ p: float, optional (default=2)
1066+ The p-norm to apply for if metric='minkowski'
10601067 n_proj : int, optional
10611068 The number of projection directions. Required if thetas is None.
1062- order : int, optional
1063- Power to elevate the norm. Default is 2.
10641069 dense: boolean, optional (default=True)
10651070 If True, returns :math:`\gamma` as a dense ndarray of shape (n, m).
10661071 Otherwise returns a sparse representation using scipy's `coo_matrix`
10671072 format.
10681073 log : bool, optional
10691074 If True, returns additional logging information. Default is False.
1075+ backend : ot.backend, optional
1076+ Backend to use for computations. If None, the backend is inferred from
1077+ the input arrays. Default is None.
10701078 beta : float, optional
10711079 Inverse-temperature parameter which weights each projection's
10721080 contribution to the expected plan. Default is 0 (uniform weighting).
@@ -1102,7 +1110,6 @@ def expected_sliced(
11021110 2.625
11031111 """
11041112
1105- X , Y = list_to_array (X , Y )
11061113 nx = get_backend (X , Y ) if backend is None else backend
11071114
11081115 assert X .ndim == 2 , f"X must be a 2d array, got { X .ndim } d array instead"
0 commit comments