@@ -354,7 +354,23 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_inv_obj, linalg_inv);
354354//| ...
355355//|
356356
357- static mp_obj_t linalg_norm (mp_obj_t x ) {
357+ static mp_obj_t linalg_norm (size_t n_args , const mp_obj_t * pos_args , mp_map_t * kw_args ) {
358+ static const mp_arg_t allowed_args [] = {
359+ { MP_QSTR_ , MP_ARG_REQUIRED | MP_ARG_OBJ , { .u_rom_obj = mp_const_none } } ,
360+ { MP_QSTR_axis , MP_ARG_OBJ , { .u_rom_obj = mp_const_none } },
361+ };
362+
363+ mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
364+ mp_arg_parse_all (n_args , pos_args , kw_args , MP_ARRAY_SIZE (allowed_args ), allowed_args , args );
365+
366+ mp_obj_t x = args [0 ].u_obj ;
367+ mp_obj_t axis = args [1 ].u_obj ;
368+ if ((axis != mp_const_none ) && (!MP_OBJ_IS_INT (axis ))) {
369+ mp_raise_TypeError (translate ("axis must be None, or an integer" ));
370+ }
371+
372+
373+ // static mp_obj_t linalg_norm(mp_obj_t x) {
358374 mp_float_t dot = 0.0 , value ;
359375 size_t count = 1 ;
360376
@@ -370,33 +386,71 @@ static mp_obj_t linalg_norm(mp_obj_t x) {
370386 return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
371387 } else if (MP_OBJ_IS_TYPE (x , & ulab_ndarray_type )) {
372388 ndarray_obj_t * ndarray = MP_OBJ_TO_PTR (x );
373- if ((ndarray -> ndim != 1 ) && (ndarray -> ndim != 2 )) {
374- mp_raise_ValueError (translate ("norm is defined for 1D and 2D arrays" ));
375- }
376389 uint8_t * array = (uint8_t * )ndarray -> array ;
377-
390+ // always get a float, so that we don't have to resolve the dtype later
378391 mp_float_t (* func )(void * ) = ndarray_get_float_function (ndarray -> dtype );
392+ shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
393+ mp_float_t * rarray = NULL ;
394+ ndarray_obj_t * results ;
395+ if (axis != mp_const_none ) {
396+ results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), _shape_strides .shape , NDARRAY_FLOAT );
397+ rarray = results -> array ;
398+ }
379399
380- size_t k = 0 ;
400+ #if ULAB_MAX_DIMS > 3
401+ size_t i = 0 ;
381402 do {
382- size_t l = 0 ;
403+ #endif
404+ #if ULAB_MAX_DIMS > 2
405+ size_t j = 0 ;
383406 do {
384- value = func (array );
385- dot = dot + (value * value - dot ) / count ++ ;
386- array += ndarray -> strides [ULAB_MAX_DIMS - 1 ];
387- l ++ ;
388- } while (l < ndarray -> shape [ULAB_MAX_DIMS - 1 ]);
389- array -= ndarray -> strides [ULAB_MAX_DIMS - 1 ] * ndarray -> shape [ULAB_MAX_DIMS - 1 ];
390- array += ndarray -> strides [ULAB_MAX_DIMS - 2 ];
391- k ++ ;
392- } while (k < ndarray -> shape [ULAB_MAX_DIMS - 2 ]);
393- return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
394- } else {
395- mp_raise_TypeError (translate ("argument must be an interable or ndarray" ));
407+ #endif
408+ #if ULAB_MAX_DIMS > 1
409+ size_t k = 0 ;
410+ do {
411+ #endif
412+ size_t l = 0 ;
413+ if (axis != mp_const_none ) {
414+ count = 1 ;
415+ dot = 0.0 ;
416+ }
417+ do {
418+ value = func (array );
419+ dot = dot + (value * value - dot ) / count ++ ;
420+ array += _shape_strides .strides [ULAB_MAX_DIMS - 1 ];
421+ l ++ ;
422+ } while (l < _shape_strides .shape [ULAB_MAX_DIMS - 1 ]);
423+ if (axis != mp_const_none ) {
424+ * rarray ++ = MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 ));
425+ }
426+ #if ULAB_MAX_DIMS > 1
427+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 1 ] * _shape_strides .shape [ULAB_MAX_DIMS - 1 ];
428+ array += _shape_strides .strides [ULAB_MAX_DIMS - 2 ];
429+ k ++ ;
430+ } while (k < _shape_strides .shape [ULAB_MAX_DIMS - 2 ]);
431+ #endif
432+ #if ULAB_MAX_DIMS > 2
433+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 2 ] * _shape_strides .shape [ULAB_MAX_DIMS - 2 ];
434+ array += _shape_strides .strides [ULAB_MAX_DIMS - 3 ];
435+ j ++ ;
436+ } while (j < _shape_strides .shape [ULAB_MAX_DIMS - 3 ]);
437+ #endif
438+ #if ULAB_MAX_DIMS > 3
439+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 3 ] * _shape_strides .shape [ULAB_MAX_DIMS - 3 ];
440+ array += _shape_strides .strides [ULAB_MAX_DIMS - 4 ];
441+ i ++ ;
442+ } while (i < _shape_strides .shape [ULAB_MAX_DIMS - 4 ]);
443+ #endif
444+ if (axis == mp_const_none ) {
445+ return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
446+ }
447+ return results ;
396448 }
449+ return mp_const_none ; // we should never reach this point
397450}
398451
399- MP_DEFINE_CONST_FUN_OBJ_1 (linalg_norm_obj , linalg_norm );
452+ MP_DEFINE_CONST_FUN_OBJ_KW (linalg_norm_obj , 1 , linalg_norm );
453+ // MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);
400454
401455#if ULAB_MAX_DIMS > 1
402456#if ULAB_LINALG_HAS_TRACE
0 commit comments