@@ -385,14 +385,8 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
385385 // always get a float, so that we don't have to resolve the dtype later
386386 mp_float_t (* func )(void * ) = ndarray_get_float_function (ndarray -> dtype );
387387 shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
388- mp_float_t * rarray = NULL ;
389- ndarray_obj_t * results = NULL ;
390- if ((axis != mp_const_none ) && (ndarray -> ndim > 1 )) {
391- results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), _shape_strides .shape , NDARRAY_FLOAT );
392- rarray = results -> array ;
393- } else {
394- rarray = m_new (mp_float_t , 1 );
395- }
388+ ndarray_obj_t * results = ndarray_new_dense_ndarray (_shape_strides .ndim , _shape_strides .shape , NDARRAY_FLOAT );
389+ mp_float_t * rarray = (mp_float_t * )results -> array ;
396390
397391 #if ULAB_MAX_DIMS > 3
398392 size_t i = 0 ;
@@ -418,28 +412,26 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
418412 l ++ ;
419413 } while (l < _shape_strides .shape [0 ]);
420414 * rarray = MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 ));
421- if (results != NULL ) {
422- rarray ++ ;
423- }
424415 #if ULAB_MAX_DIMS > 1
416+ rarray += _shape_strides .increment ;
425417 array -= _shape_strides .strides [0 ] * _shape_strides .shape [0 ];
426418 array += _shape_strides .strides [ULAB_MAX_DIMS - 1 ];
427419 k ++ ;
428420 } while (k < _shape_strides .shape [ULAB_MAX_DIMS - 1 ]);
429421 #endif
430422 #if ULAB_MAX_DIMS > 2
431- array -= _shape_strides .strides [ULAB_MAX_DIMS - 1 ] * _shape_strides .shape [ULAB_MAX_DIMS - 1 ];
423+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 1 ] * _shape_strides .shape [ULAB_MAX_DIMS - 1 ];
432424 array += _shape_strides .strides [ULAB_MAX_DIMS - 2 ];
433425 j ++ ;
434426 } while (j < _shape_strides .shape [ULAB_MAX_DIMS - 2 ]);
435427 #endif
436428 #if ULAB_MAX_DIMS > 3
437- array -= _shape_strides .strides [ULAB_MAX_DIMS - 2 ] * _shape_strides .shape [ULAB_MAX_DIMS - 2 ];
429+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 2 ] * _shape_strides .shape [ULAB_MAX_DIMS - 2 ];
438430 array += _shape_strides .strides [ULAB_MAX_DIMS - 3 ];
439431 i ++ ;
440432 } while (i < _shape_strides .shape [ULAB_MAX_DIMS - 3 ]);
441433 #endif
442- if (results == NULL ) {
434+ if (results -> ndim == 0 ) {
443435 return mp_obj_new_float (* rarray );
444436 }
445437 return results ;
0 commit comments