@@ -2262,6 +2262,12 @@ class AdvancedIncSubtensor1(COp):
22622262 check_input = False
22632263 params_type = ParamsType (inplace = ps .bool , set_instead_of_inc = ps .bool )
22642264
# Error text used when `y` would have to be broadcast along a dimension that
# is not statically marked as broadcastable. It is raised as a ValueError by
# the Python check and interpolated into the generated C code.
_runtime_broadcast_error_msg = " ".join(
    (
        "Runtime broadcasting not allowed.",
        "AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable.",
        "If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s).",
    )
)
2270+
22652271 def __init__ (self , inplace = False , set_instead_of_inc = False ):
22662272 self .inplace = bool (inplace )
22672273 self .set_instead_of_inc = bool (set_instead_of_inc )
@@ -2333,6 +2339,9 @@ def copy_of_x(self, x):
23332339 NPY_ARRAY_ENSURECOPY, NULL)"""
23342340
23352341 def c_support_code (self , ** kwargs ):
2342+ if numpy_version < "1.8.0" or using_numpy_2 :
2343+ return None
2344+
23362345 types = [
23372346 "npy_" + t
23382347 for t in [
@@ -2523,15 +2532,117 @@ def gen_num(typen):
25232532 return code
25242533
25252534 def c_code (self , node , name , input_names , output_names , sub ):
2526- if numpy_version < "1.8.0" or using_numpy_2 :
2527- raise NotImplementedError
2528-
25292535 x , y , idx = input_names
2530- out = output_names [ 0 ]
2536+ [ out ] = output_names
25312537 copy_of_x = self .copy_of_x (x )
25322538 params = sub ["params" ]
25332539 fail = sub ["fail" ]
25342540
2541+ x_ , y_ , idx_ = node .inputs
2542+ y_cdtype = y_ .type .dtype_specs ()[1 ]
2543+ idx_cdtype = idx_ .type .dtype_specs ()[1 ]
2544+ out_cdtype = node .outputs [0 ].type .dtype_specs ()[1 ]
2545+ y_bcast = y_ .type .broadcastable != idx_ .type .broadcastable
2546+ if (
2547+ x_ .type .ndim == 1
2548+ and y_ .type .ndim == 1
2549+ and not y_bcast
2550+ and x_ .type .dtype not in complex_dtypes
2551+ and y_ .type .dtype not in complex_dtypes
2552+ ):
2553+ # Simple implementation for vector x, y cases
2554+ idx_may_be_neg = not (isinstance (idx_ , Constant ) and idx_ .data .min () >= 0 )
2555+ idx_may_be_invalid = AdvancedSubtensor1 ._idx_may_be_invalid (x_ , idx_ )
2556+ shape0 = x_ .type .shape [0 ]
2557+ # This is used to make sure that when we trust the indices to be valid
2558+ # we are not fooled by a wrong static shape
2559+ # We mention x to the user in error messages but we work (and make checks) on out,
2560+ # which should be x or a copy of it
2561+ unexpected_shape0 = (
2562+ f"PyArray_SHAPE({ out } )[0] != { shape0 } " if shape0 is not None else "0"
2563+ )
2564+
2565+ op = "=" if self .set_instead_of_inc else "+="
2566+ code = f"""
2567+ if ({ params } ->inplace)
2568+ {{
2569+ if ({ x } != { out } )
2570+ {{
2571+ Py_XDECREF({ out } );
2572+ Py_INCREF({ x } );
2573+ { out } = { x } ;
2574+ }}
2575+ }}
2576+ else
2577+ {{
2578+ Py_XDECREF({ out } );
2579+ { out } = { copy_of_x } ;
2580+ if (!{ out } ) {{
2581+ // Exception already set
2582+ { fail }
2583+ }}
2584+ }}
2585+
2586+ if (PyArray_NDIM({ out } ) != 1) {{
2587+ PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) ndim should be 1, got %d", PyArray_NDIM({ out } ));
2588+ { fail }
2589+ }}
2590+ if ({ unexpected_shape0 } ) {{
2591+ PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) shape should be { shape0 } , got %d", PyArray_SHAPE({ out } )[0]);
2592+ { fail }
2593+ }}
2594+ if (PyArray_NDIM({ idx } ) != 1) {{
2595+ PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim should be 1, got %d", PyArray_NDIM({ idx } ));
2596+ { fail }
2597+ }}
2598+ if (PyArray_NDIM({ y } ) != 1) {{
2599+ PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: second input (y) ndim should be 1, got %d", PyArray_NDIM({ y } ));
2600+ { fail }
2601+ }}
2602+ if (PyArray_SHAPE({ y } )[0] != PyArray_SHAPE({ idx } )[0]) {{
2603+ if ((PyArray_NDIM({ y } ) == 1) && (PyArray_SHAPE({ y } )[0] == 1)){{
2604+ PyErr_Format(PyExc_ValueError, "{ self ._runtime_broadcast_error_msg } ");
2605+ }} else {{
2606+ PyErr_Format(PyExc_ValueError,
2607+ "AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match: %d, %d",
2608+ PyArray_SHAPE({ y } )[0], PyArray_SHAPE({ idx } )[0]);
2609+ }}
2610+ { fail }
2611+ }}
2612+
2613+ {{
2614+ npy_intp out_shape0 = PyArray_SHAPE({ out } )[0];
2615+ { out_cdtype } * out_data = ({ out_cdtype } *)PyArray_DATA({ out } );
2616+ { y_cdtype } * y_data = ({ y_cdtype } *)PyArray_DATA({ y } );
2617+ { idx_cdtype } * idx_data = ({ idx_cdtype } *)PyArray_DATA({ idx } );
2618+ npy_intp n = PyArray_SHAPE({ idx } )[0];
2619+ npy_intp out_jump = PyArray_STRIDES({ out } )[0] / PyArray_ITEMSIZE({ out } );
2620+ npy_intp y_jump = PyArray_STRIDES({ y } )[0] / PyArray_ITEMSIZE({ y } );
2621+ npy_intp idx_jump = PyArray_STRIDES({ idx } )[0] / PyArray_ITEMSIZE({ idx } );
2622+
2623+ for(int i = 0; i < n; i++){{
2624+ { idx_cdtype } idx = idx_data[i * idx_jump];
2625+ if ({ int (idx_may_be_neg )} ){{
2626+ if (idx < 0) {{
2627+ idx += out_shape0;
2628+ }}
2629+ }}
2630+ if ({ int (idx_may_be_invalid )} ){{
2631+ if ((idx < 0) || (idx >= out_shape0)) {{
2632+ PyErr_Format(PyExc_IndexError,"index %d out of bounds for array with shape %d", idx_data[i * idx_jump], out_shape0);
2633+ { fail }
2634+ }}
2635+ }}
2636+ out_data[idx * out_jump] { op } y_data[i * y_jump];
2637+ }}
2638+
2639+ }}
2640+ """
2641+ return code
2642+
2643+ if numpy_version < "1.8.0" or using_numpy_2 :
2644+ raise NotImplementedError
2645+
25352646 return f"""
25362647 PyObject* rval = NULL;
25372648 if ({ params } ->inplace)
@@ -2559,22 +2670,45 @@ def c_code(self, node, name, input_names, output_names, sub):
25592670 """
25602671
def c_code_cache_version(self):
    """Return the version tuple of this Op's C implementation."""
    # NOTE(review): presumably this must be bumped whenever the code
    # generated by `c_code` / `c_support_code` changes -- confirm.
    version = 9
    return (version,)
2674+
2675+ def _check_runtime_broadcasting (
2676+ self , node : Apply , x : np .ndarray , y : np .ndarray , idx : np .ndarray
2677+ ) -> None :
2678+ if y .ndim > 0 :
2679+ y_pt_bcast = node .inputs [1 ].broadcastable # type: ignore
2680+
2681+ if not y_pt_bcast [0 ] and y .shape [0 ] == 1 and y .shape [0 ] != idx .shape [0 ]:
2682+ # Attempting to broadcast with index
2683+ raise ValueError (self ._runtime_broadcast_error_msg )
2684+ if any (
2685+ not y_bcast and y_dim == 1 and y_dim != x_dim
2686+ for y_bcast , y_dim , x_dim in zip (
2687+ reversed (y_pt_bcast ),
2688+ reversed (y .shape ),
2689+ reversed (x .shape ),
2690+ strict = False ,
2691+ )
2692+ ):
2693+ # Attempting to broadcast with buffer
2694+ raise ValueError (self ._runtime_broadcast_error_msg )
2695+
def perform(self, node, inputs, output_storage):
    """Evaluate the Op in Python: set or increment ``x[idx]`` with ``y``.

    Parameters
    ----------
    node
        The Apply node being evaluated; forwarded to the broadcasting check.
    inputs
        Sequence ``(x, y, idx)``: the buffer, the values, and the indices.
    output_storage
        The result array is stored in ``output_storage[0][0]``.

    Raises
    ------
    ValueError
        Propagated from ``_check_runtime_broadcasting``.
    """
    x, y, idx = inputs

    # The check only reads shapes, so run it before copying `x`: a failing
    # call then does not pay for an allocation that is thrown away.
    self._check_runtime_broadcasting(node, x, y, idx)

    if not self.inplace:
        x = x.copy()

    if self.set_instead_of_inc:
        x[idx] = y
    else:
        # In Numpy, `x[idx] += y` doesn't work if the same index is present
        # many times: it does it only once.
        np.add.at(x, idx, y)

    output_storage[0][0] = x
25782712
25792713 def infer_shape (self , fgraph , node , ishapes ):
25802714 x , y , ilist = ishapes
0 commit comments