@@ -756,6 +756,12 @@ def sliced_plans(
756756 assert X .ndim == 2 , f"X must be a 2d array, got { X .ndim } d array instead"
757757 assert Y .ndim == 2 , f"Y must be a 2d array, got { Y .ndim } d array instead"
758758
759+ assert metric in ("minkowski" , "euclidean" , "cityblock" , "sqeuclidean" ), (
760+ "Sliced plans work only with metrics "
761+ + "from the following list: "
762+ + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
763+ )
764+
759765 assert (
760766 X .shape [1 ] == Y .shape [1 ]
761767 ), f"X ({ X .shape } ) and Y ({ Y .shape } ) must have the same number of columns"
@@ -776,7 +782,7 @@ def sliced_plans(
776782 do_draw_thetas = thetas is None
777783 if do_draw_thetas : # create thetas (n_proj, d)
778784 assert n_proj is not None , "n_proj must be specified if thetas is None"
779- thetas = get_random_projections (d , n_proj , backend = nx ).T
785+ thetas = get_random_projections (d , n_proj , backend = nx , type_as = X ).T
780786
781787 if warm_theta is not None :
782788 thetas = nx .concatenate ([thetas , warm_theta [:, None ].T ], axis = 0 )
@@ -787,8 +793,7 @@ def sliced_plans(
787793 X_theta = X @ thetas .T # shape (n, n_proj)
788794 Y_theta = Y @ thetas .T # shape (m, n_proj)
789795
790- if is_perm :
791- # we compute maps (permutations)
796+ if is_perm : # we compute maps (permutations)
792797 # sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj]
793798 sigma = nx .argsort (X_theta , axis = 0 ) # (n, n_proj)
794799 tau = nx .argsort (Y_theta , axis = 0 ) # (m, n_proj)
@@ -803,17 +808,12 @@ def sliced_plans(
803808 )
804809 for k in range (n_proj )
805810 ]
806- elif metric = = "sqeuclidean" :
811+ else : # metric = "sqeuclidean":
807812 costs = [
808813 nx .sum ((nx .sum ((X [sigma [:, k ]] - Y [tau [:, k ]]) ** 2 , axis = 1 )) / n )
809814 for k in range (n_proj )
810815 ]
811- else :
812- raise ValueError (
813- "Sliced plans work only with metrics "
814- + "from the following list: "
815- + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
816- )
816+
817817 a = nx .ones (n ) / n
818818 plan = [
819819 nx .coo_matrix (a , sigma [:, k ], tau [:, k ], shape = (n , m ), type_as = a )
@@ -825,6 +825,35 @@ def sliced_plans(
825825 X_theta , Y_theta , a , b , p , require_sort = True , return_plan = True
826826 )
827827
828+ plan = plan .tocsr ().tocoo () # especially for tensorflow compatibility
829+
830+ if str (nx ) == "jax" :
831+ plan = [nx .todense (plan [k ]) for k in range (n_proj )]
832+ if not dense :
833+ warnings .warn (
834+ "JAX does not support sparse matrices, converting" "to dense"
835+ )
836+
837+ costs = [
838+ nx .sum (
839+ (
840+ (
841+ nx .sum (
842+ nx .abs (
843+ X [np .nonzero (plan [k ])[0 ]]
844+ - Y [np .nonzero (plan [k ])[1 ]]
845+ )
846+ ** p ,
847+ axis = 1 ,
848+ )
849+ )
850+ ** (1 / p )
851+ )
852+ * plan [np .nonzero (plan [k ])]
853+ )
854+ for k in range (n_proj )
855+ ]
856+
828857 if metric in ("minkowski" , "euclidean" , "cityblock" ):
829858 costs = [
830859 nx .sum (
@@ -836,26 +865,17 @@ def sliced_plans(
836865 )
837866 for k in range (n_proj )
838867 ]
839- elif metric == "sqeuclidean" :
868+ else : # metric = "sqeuclidean"
840869 costs = [
841870 nx .sum (
842871 (nx .sum ((X [plan [k ].row ] - Y [plan [k ].col ]) ** 2 , axis = 1 ))
843872 * plan [k ].data
844873 )
845874 for k in range (n_proj )
846875 ]
847- else :
848- raise ValueError (
849- "Sliced plans work only with metrics "
850- + "from the following list: "
851- + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
852- )
853876
854877 if dense :
855878 plan = [nx .todense (plan [k ]) for k in range (n_proj )]
856- elif str (nx ) == "jax" :
857- warnings .warn ("JAX does not support sparse matrices, converting to dense" )
858- plan = [nx .todense (plan [k ]) for k in range (n_proj )]
859879
860880 if log :
861881 log_dict = {"X_theta" : X_theta , "Y_theta" : Y_theta , "thetas" : thetas }
0 commit comments