@@ -829,12 +829,12 @@ def dist(i, j):
829829 warnings .warn (
830830 "TensorFlow sparse indexing is limited, converting to dense"
831831 )
832- plan = [nx .todense (plan [k ]) for k in range (n_proj )]
833- idx_non_zeros = [nx .nonzero ( plan [k ]) for k in range (n_proj )]
832+ plan_dense = [nx .todense (plan [k ]) for k in range (n_proj )]
833+ idx_non_zeros = [nx .where ( plan_dense [k ] != 0 ) for k in range (n_proj )]
834834 costs = [
835835 nx .sum (
836836 dist (idx_non_zeros [k ][0 ], idx_non_zeros [k ][1 ])
837- * plan [k ][idx_non_zeros [k ][0 ], idx_non_zeros [k ][1 ]]
837+ * plan_dense [k ][idx_non_zeros [k ][0 ], idx_non_zeros [k ][1 ]]
838838 )
839839 for k in range (n_proj )
840840 ]
@@ -844,8 +844,8 @@ def dist(i, j):
844844 for k in range (n_proj )
845845 ]
846846
847- if dense and not (str (nx ) == "jax" or str ( nx ) == "tensorflow" ):
848- plan = [ nx . todense ( plan [ k ]) for k in range ( n_proj )]
847+ if dense and not (str (nx ) == "jax" ):
848+ plan = plan_dense . copy ()
849849 elif str (nx ) == "jax" and not is_perm :
850850 warnings .warn ("JAX does not support sparse matrices, converting to dense" )
851851 plan = [nx .todense (plan [k ]) for k in range (n_proj )]
@@ -1101,7 +1101,7 @@ def expected_sliced(
11011101 >>> x=np.array([[3.,3.], [1.,1.]])
11021102 >>> y=np.array([[2.,2.5], [3.,2.]])
11031103 >>> thetas=np.array([[1, 0], [0, 1]])
1104- >>> plan, cost = expected_sliced(x, y, thetas)
1104+ >>> plan, cost = expected_sliced(x, y, thetas=thetas )
11051105 >>> plan
11061106 [[0.25 0.25]
11071107 [0.25 0.25]]
@@ -1138,7 +1138,7 @@ def expected_sliced(
11381138
11391139 log_dict = {}
11401140 G , costs , log_dict_plans = sliced_plans (
1141- X , Y , a , b , metric , p , thetas , n_proj = n_proj , log = True
1141+ X , Y , a , b , metric , p , thetas , n_proj = n_proj , log = True , dense = False
11421142 )
11431143 if log :
11441144 log_dict = {"thetas" : log_dict_plans ["thetas" ], "costs" : costs , "G" : G }
@@ -1154,7 +1154,6 @@ def expected_sliced(
11541154 weights = nx .ones (n_proj ) / n_proj
11551155
11561156 log_dict ["weights" ] = weights
1157-
11581157 weights = nx .concatenate ([G [i ].data * weights [i ] for i in range (len (G ))])
11591158 X_idx = nx .concatenate ([G [i ].row for i in range (len (G ))])
11601159 Y_idx = nx .concatenate ([G [i ].col for i in range (len (G ))])
0 commit comments