@@ -2120,16 +2120,12 @@ def make_node(self, x, ilist):
21202120 out_shape = (ilist_ .type .shape [0 ], * x_ .type .shape [1 :])
21212121 return Apply (self , [x_ , ilist_ ], [TensorType (dtype = x .dtype , shape = out_shape )()])
21222122
def perform(self, node, inp, output_storage):
    """Python implementation: select entries of ``x`` along axis 0.

    Parameters
    ----------
    node
        The Apply node being evaluated (not used here).
    inp
        A pair ``(x, i)``: the array to index and the indices to take
        along its first axis.
    output_storage
        Storage cells for the outputs; the result is written to
        ``output_storage[0][0]``.
    """
    data, indices = inp
    out_cell = output_storage[0]

    # A fresh array is always allocated instead of reusing any buffer
    # already present in the cell: numpy's take is slower when ``out``
    # is provided.
    # https://github.com/numpy/numpy/issues/28636
    out_cell[0] = data.take(indices, axis=0)
21332129
21342130 def connection_pattern (self , node ):
21352131 rval = [[True ], * ([False ] for _ in node .inputs [1 :])]
@@ -2174,42 +2170,83 @@ def c_code(self, node, name, input_names, output_names, sub):
21742170 "c_code defined for AdvancedSubtensor1, not for child class" ,
21752171 type (self ),
21762172 )
2173+ x , idxs = node .inputs
2174+ if self ._idx_may_be_invalid (x , idxs ):
2175+ mode = "NPY_RAISE"
2176+ else :
2177+ # We can know ahead of time that all indices are valid, so we can use a faster mode
2178+ mode = "NPY_WRAP" # This seems to be faster than NPY_CLIP
2179+
21772180 a_name , i_name = input_names [0 ], input_names [1 ]
21782181 output_name = output_names [0 ]
21792182 fail = sub ["fail" ]
2180- return f"""
2181- if ({ output_name } != NULL) {{
2182- npy_intp nd, i, *shape;
2183- nd = PyArray_NDIM({ a_name } ) + PyArray_NDIM({ i_name } ) - 1;
2184- if (PyArray_NDIM({ output_name } ) != nd) {{
2183+ if mode == "NPY_RAISE" :
2184+ # numpy_take always makes an intermediate copy if NPY_RAISE which is slower than just allocating a new buffer
2185+ # We can remove this special case after https://github.com/numpy/numpy/issues/28636
2186+ manage_pre_allocated_out = f"""
2187+ if ({ output_name } != NULL) {{
2188+ // Numpy TakeFrom is always slower when copying
2189+ // https://github.com/numpy/numpy/issues/28636
21852190 Py_CLEAR({ output_name } );
21862191 }}
2187- else {{
2188- shape = PyArray_DIMS( { output_name } );
2189- for (i = 0; i < PyArray_NDIM( { i_name } ); i++) {{
2190- if (shape[i] != PyArray_DIMS( { i_name } )[i] ) {{
2191- Py_CLEAR( { output_name } ) ;
2192- break;
2193- }}
2192+ """
2193+ else :
2194+ manage_pre_allocated_out = f"""
2195+ if ({ output_name } != NULL ) {{
2196+ npy_intp nd = PyArray_NDIM( { a_name } ) + PyArray_NDIM( { i_name } ) - 1 ;
2197+ if (PyArray_NDIM( { output_name } ) != nd) {{
2198+ Py_CLEAR( { output_name } );
21942199 }}
2195- if ({ output_name } != NULL) {{
2196- for (; i < nd; i++) {{
2197- if (shape[i] != PyArray_DIMS({ a_name } )[
2198- i-PyArray_NDIM({ i_name } )+1]) {{
2200+ else {{
2201+ int i;
2202+ npy_intp* shape = PyArray_DIMS({ output_name } );
2203+ for (i = 0; i < PyArray_NDIM({ i_name } ); i++) {{
2204+ if (shape[i] != PyArray_DIMS({ i_name } )[i]) {{
21992205 Py_CLEAR({ output_name } );
22002206 break;
22012207 }}
22022208 }}
2209+ if ({ output_name } != NULL) {{
2210+ for (; i < nd; i++) {{
2211+ if (shape[i] != PyArray_DIMS({ a_name } )[i-PyArray_NDIM({ i_name } )+1]) {{
2212+ Py_CLEAR({ output_name } );
2213+ break;
2214+ }}
2215+ }}
2216+ }}
22032217 }}
22042218 }}
2205- }}
2219+ """
2220+
2221+ return f"""
2222+ { manage_pre_allocated_out }
22062223 { output_name } = (PyArrayObject*)PyArray_TakeFrom(
2207- { a_name } , (PyObject*){ i_name } , 0, { output_name } , NPY_RAISE );
2224+ { a_name } , (PyObject*){ i_name } , 0, { output_name } , { mode } );
22082225 if ({ output_name } == NULL) { fail } ;
22092226 """
22102227
def c_code_cache_version(self):
    """Return the version tuple associated with this Op's C code."""
    version = (5,)
    return version
2230+
@staticmethod
def _idx_may_be_invalid(x, idx) -> bool:
    """Return whether indexing ``x`` with ``idx`` on axis 0 may be out of bounds.

    ``False`` is returned only when validity can be proven from static
    information: the index is statically empty, or the index is a constant
    whose values all lie in ``[-len(x), len(x))`` for a statically known
    ``len(x)``. In every other case the answer is conservatively ``True``.

    Parameters
    ----------
    x
        Variable being indexed; ``x.type.shape[0]`` is its static length
        along the first axis, or ``None`` when unknown.
    idx
        Index variable; ``idx.type.shape[0]`` is its static length.

    Returns
    -------
    bool
        ``True`` if an out-of-bounds index cannot be ruled out.
    """
    if idx.type.shape[0] == 0:
        # Empty index is always valid
        return False

    if x.type.shape[0] is None:
        # We can't know if an index is valid if we don't know the length of x
        return True

    if not isinstance(idx, Constant):
        # This is conservative, but we don't try to infer lower/upper bound symbolically
        return True

    shape0 = x.type.shape[0]
    # Convert to Python ints so the comparisons below cannot be affected by
    # the (possibly narrow or unsigned) dtype of the constant's data.
    min_idx, max_idx = int(idx.data.min()), int(idx.data.max())

    # Valid indices satisfy -shape0 <= idx < shape0. Both bounds must be
    # checked together: the previous `not (lower_ok) and (upper_ok)` only
    # negated the lower-bound test, so a too-large constant index was
    # reported as always valid.
    return min_idx < -shape0 or max_idx >= shape0
22132250
22142251
22152252advanced_subtensor1 = AdvancedSubtensor1 ()
0 commit comments