@@ -766,6 +766,7 @@ def sliced_plans(
766766 assert (
767767 X .shape [1 ] == Y .shape [1 ]
768768 ), f"X ({ X .shape } ) and Y ({ Y .shape } ) must have the same number of columns"
769+
769770 if metric == "euclidean" :
770771 p = 2
771772 elif metric == "cityblock" :
@@ -818,13 +819,18 @@ def dist(i, j):
818819 X_theta , Y_theta , a , b , p , require_sort = True , return_plan = True
819820 )
820821
821- if str (nx ) == "jax" : # dense computation for jax
822+ if str (nx ) == "jax" or str ( nx ) == "tensorflow" :
822823 if not dense :
823- warnings .warn (
824- "JAX does not support sparse matrices, converting to dense"
825- )
824+ if str (nx ) == "jax" :
825+ warnings .warn (
826+ "JAX does not support sparse matrices, converting to dense"
827+ )
828+ else :
829+ warnings .warn (
830+ "TensorFlow sparse indexing is limited, converting to dense"
831+ )
826832 plan = [nx .todense (plan [k ]) for k in range (n_proj )]
827- idx_non_zeros = [np .nonzero (plan [k ]) for k in range (n_proj )]
833+ idx_non_zeros = [nx .nonzero (plan [k ]) for k in range (n_proj )]
828834 costs = [
829835 nx .sum (
830836 dist (idx_non_zeros [k ][0 ], idx_non_zeros [k ][1 ])
@@ -833,25 +839,22 @@ def dist(i, j):
833839 for k in range (n_proj )
834840 ]
835841 else :
836- if str (nx ) == "tensorflow" : # tf does not support multiple indexing
837- plan = [plan [k ].tocsr ().tocoo () for k in range (n_proj )]
838-
839842 costs = [
840843 nx .sum (dist (plan [k ].row , plan [k ].col ) * plan [k ].data )
841844 for k in range (n_proj )
842845 ]
843846
844- if dense and not str (nx ) == "jax" :
847+ if dense and not ( str (nx ) == "jax" or str ( nx ) == "tensorflow" ) :
845848 plan = [nx .todense (plan [k ]) for k in range (n_proj )]
846- elif str (nx ) == "jax" :
849+ elif str (nx ) == "jax" and not is_perm :
847850 warnings .warn ("JAX does not support sparse matrices, converting to dense" )
848851 plan = [nx .todense (plan [k ]) for k in range (n_proj )]
849852
850853 if log :
851854 log_dict = {"X_theta" : X_theta , "Y_theta" : Y_theta , "thetas" : thetas }
852- return plan , costs , log_dict
855+ return plan , nx . stack ( costs ) , log_dict
853856 else :
854- return plan , costs
857+ return plan , nx . stack ( costs )
855858
856859
857860def min_pivot_sliced (
@@ -945,7 +948,7 @@ def min_pivot_sliced(
945948 >>> x=np.array([[3.,3.], [1.,1.]])
946949 >>> y=np.array([[2.,2.5], [3.,2.]])
947950 >>> thetas=np.array([[1, 0], [0, 1]])
948- >>> plan, cost = ot.expected_sliced (x, y, thetas)
951+ >>> plan, cost = min_pivot_sliced (x, y, thetas)
949952 >>> plan
950953 [[0 0.5]
951954 [0.5 0]]
@@ -984,6 +987,7 @@ def min_pivot_sliced(
984987 warm_theta = warm_theta ,
985988 log = True ,
986989 )
990+
987991 pos_min = nx .argmin (costs )
988992 cost = costs [pos_min ]
989993 plan = G [pos_min ]
@@ -1097,7 +1101,7 @@ def expected_sliced(
10971101 >>> x=np.array([[3.,3.], [1.,1.]])
10981102 >>> y=np.array([[2.,2.5], [3.,2.]])
10991103 >>> thetas=np.array([[1, 0], [0, 1]])
1100- >>> plan, cost = ot. expected_sliced(x, y, thetas)
1104+ >>> plan, cost = expected_sliced(x, y, thetas)
11011105 >>> plan
11021106 [[0.25 0.25]
11031107 [0.25 0.25]]
0 commit comments