@@ -190,6 +190,7 @@ static mp_obj_t numerical_sum_mean_std_iterable(mp_obj_t oin, uint8_t optype, si
190190
191191static mp_obj_t numerical_sum_mean_std_ndarray (ndarray_obj_t * ndarray , mp_obj_t axis , uint8_t optype , size_t ddof ) {
192192 uint8_t * array = (uint8_t * )ndarray -> array ;
193+ shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
193194
194195 if (axis == mp_const_none ) {
195196 // work with the flattened array
@@ -223,26 +224,26 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
223224 S = s ;
224225 }
225226 M = m ;
226- array += ndarray -> strides [ULAB_MAX_DIMS - 1 ];
227+ array += _shape_strides . strides [ULAB_MAX_DIMS - 1 ];
227228 l ++ ;
228- } while (l < ndarray -> shape [ULAB_MAX_DIMS - 1 ]);
229+ } while (l < _shape_strides . shape [ULAB_MAX_DIMS - 1 ]);
229230 #if ULAB_MAX_DIMS > 1
230- array -= ndarray -> strides [ULAB_MAX_DIMS - 1 ] * ndarray -> shape [ULAB_MAX_DIMS - 1 ];
231- array += ndarray -> strides [ULAB_MAX_DIMS - 2 ];
231+ array -= _shape_strides . strides [ULAB_MAX_DIMS - 1 ] * _shape_strides . shape [ULAB_MAX_DIMS - 1 ];
232+ array += _shape_strides . strides [ULAB_MAX_DIMS - 2 ];
232233 k ++ ;
233- } while (k < ndarray -> shape [ULAB_MAX_DIMS - 2 ]);
234+ } while (k < _shape_strides . shape [ULAB_MAX_DIMS - 2 ]);
234235 #endif
235236 #if ULAB_MAX_DIMS > 2
236- array -= ndarray -> strides [ULAB_MAX_DIMS - 2 ] * ndarray -> shape [ULAB_MAX_DIMS - 2 ];
237- array += ndarray -> strides [ULAB_MAX_DIMS - 3 ];
237+ array -= _shape_strides . strides [ULAB_MAX_DIMS - 2 ] * _shape_strides . shape [ULAB_MAX_DIMS - 2 ];
238+ array += _shape_strides . strides [ULAB_MAX_DIMS - 3 ];
238239 j ++ ;
239- } while (j < ndarray -> shape [ULAB_MAX_DIMS - 3 ]);
240+ } while (j < _shape_strides . shape [ULAB_MAX_DIMS - 3 ]);
240241 #endif
241242 #if ULAB_MAX_DIMS > 3
242- array -= ndarray -> strides [ULAB_MAX_DIMS - 3 ] * ndarray -> shape [ULAB_MAX_DIMS - 3 ];
243- array += ndarray -> strides [ULAB_MAX_DIMS - 4 ];
243+ array -= _shape_strides . strides [ULAB_MAX_DIMS - 3 ] * _shape_strides . shape [ULAB_MAX_DIMS - 3 ];
244+ array += _shape_strides . strides [ULAB_MAX_DIMS - 4 ];
244245 i ++ ;
245- } while (i < ndarray -> shape [ULAB_MAX_DIMS - 4 ]);
246+ } while (i < _shape_strides . shape [ULAB_MAX_DIMS - 4 ]);
246247 #endif
247248 if (optype == NUMERICAL_SUM ) {
248249 // numpy returns an integer for integer input types
@@ -258,18 +259,12 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
258259 return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(S / (ndarray -> len - ddof )));
259260 }
260261 } else {
261- shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
262- // if(ndarray->ndim == 1) {
263- // // if we have the single dimension, axis = 0 is equivalent to axis = None
264- // // the call to tools_reduce_axes() has made sure that axis = 0
265- // return numerical_sum_mean_std_ndarray(ndarray, mp_const_none, optype, ddof);
266- // }
267262 ndarray_obj_t * results = NULL ;
268263 uint8_t * rarray = NULL ;
269-
264+ mp_float_t * farray = NULL ;
270265 if (optype == NUMERICAL_SUM ) {
271- results = ndarray_new_dense_ndarray (MAX ( 1 , ndarray -> ndim - 1 ) , _shape_strides .shape , ndarray -> dtype );
272- rarray = (uint8_t * )results -> array ;
266+ results = ndarray_new_dense_ndarray (_shape_strides . ndim , _shape_strides .shape , ndarray -> dtype );
267+ rarray = (uint8_t * )results -> array ;
273268 // TODO: numpy promotes the output to the highest integer type
274269 if (ndarray -> dtype == NDARRAY_UINT8 ) {
275270 RUN_SUM (uint8_t , array , results , rarray , _shape_strides );
@@ -282,37 +277,37 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
282277 } else {
283278 // for floats, the sum might be inaccurate with the naive summation
284279 // call mean, and multiply with the number of samples
285- mp_float_t * r = (mp_float_t * )results -> array ;
286- RUN_MEAN_STD (mp_float_t , array , r , _shape_strides , 0.0 , 0 );
280+ farray = (mp_float_t * )results -> array ;
281+ RUN_MEAN_STD (mp_float_t , array , farray , _shape_strides , 0.0 , 0 );
287282 mp_float_t norm = (mp_float_t )_shape_strides .shape [0 ];
288283 // re-wind the array here
289- r = (mp_float_t * )results -> array ;
284+ farray = (mp_float_t * )results -> array ;
290285 for (size_t i = 0 ; i < results -> len ; i ++ ) {
291- * r ++ *= norm ;
286+ * farray ++ *= norm ;
292287 }
293288 }
294289 } else {
295290 bool isStd = optype == NUMERICAL_STD ? 1 : 0 ;
296- results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), _shape_strides .shape , NDARRAY_FLOAT );
291+ results = ndarray_new_dense_ndarray (_shape_strides .ndim , _shape_strides .shape , NDARRAY_FLOAT );
292+ farray = (mp_float_t * )results -> array ;
297293 // we can return the 0 array here, if the degrees of freedom is larger than the length of the axis
298294 if ((optype == NUMERICAL_STD ) && (_shape_strides .shape [0 ] <= ddof )) {
299295 return MP_OBJ_FROM_PTR (results );
300296 }
301297 mp_float_t div = optype == NUMERICAL_STD ? (mp_float_t )(_shape_strides .shape [0 ] - ddof ) : 0.0 ;
302- mp_float_t * rarray = (mp_float_t * )results -> array ;
303298 if (ndarray -> dtype == NDARRAY_UINT8 ) {
304- RUN_MEAN_STD (uint8_t , array , rarray , _shape_strides , div , isStd );
299+ RUN_MEAN_STD (uint8_t , array , farray , _shape_strides , div , isStd );
305300 } else if (ndarray -> dtype == NDARRAY_INT8 ) {
306- RUN_MEAN_STD (int8_t , array , rarray , _shape_strides , div , isStd );
301+ RUN_MEAN_STD (int8_t , array , farray , _shape_strides , div , isStd );
307302 } else if (ndarray -> dtype == NDARRAY_UINT16 ) {
308- RUN_MEAN_STD (uint16_t , array , rarray , _shape_strides , div , isStd );
303+ RUN_MEAN_STD (uint16_t , array , farray , _shape_strides , div , isStd );
309304 } else if (ndarray -> dtype == NDARRAY_INT16 ) {
310- RUN_MEAN_STD (int16_t , array , rarray , _shape_strides , div , isStd );
305+ RUN_MEAN_STD (int16_t , array , farray , _shape_strides , div , isStd );
311306 } else {
312- RUN_MEAN_STD (mp_float_t , array , rarray , _shape_strides , div , isStd );
307+ RUN_MEAN_STD (mp_float_t , array , farray , _shape_strides , div , isStd );
313308 }
314309 }
315- if (ndarray -> ndim == 1 ) { // return a scalar here
310+ if (results -> ndim == 0 ) { // return a scalar here
316311 return mp_binary_get_val_array (results -> dtype , results -> array , 0 );
317312 }
318313 return MP_OBJ_FROM_PTR (results );
0 commit comments