@@ -354,33 +354,46 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_inv_obj, linalg_inv);
354354//| ...
355355//|
356356
357- static mp_obj_t linalg_norm (mp_obj_t _x ) {
358- if (!MP_OBJ_IS_TYPE (_x , & ulab_ndarray_type )) {
359- mp_raise_TypeError (translate ("argument must be ndarray" ));
360- }
361- ndarray_obj_t * ndarray = MP_OBJ_TO_PTR (_x );
362- if ((ndarray -> ndim != 1 ) && (ndarray -> ndim != 2 )) {
363- mp_raise_ValueError (translate ("norm is defined for 1D and 2D arrays" ));
364- }
365- mp_float_t dot = 0.0 ;
366- uint8_t * array = (uint8_t * )ndarray -> array ;
357+ static mp_obj_t linalg_norm (mp_obj_t x ) {
358+ mp_float_t dot = 0.0 , value ;
359+ size_t count = 1 ;
360+
361+ if (MP_OBJ_IS_TYPE (x , & mp_type_tuple ) || MP_OBJ_IS_TYPE (x , & mp_type_list ) || MP_OBJ_IS_TYPE (x , & mp_type_range )) {
362+ mp_obj_iter_buf_t iter_buf ;
363+ mp_obj_t item , iterable = mp_getiter (x , & iter_buf );
364+ while ((item = mp_iternext (iterable )) != MP_OBJ_STOP_ITERATION ) {
365+ value = mp_obj_get_float (item );
366+ // we could simply take the sum of value ** 2,
367+ // but this method is numerically stable
368+ dot = dot + (value * value - dot ) / count ++ ;
369+ }
370+ return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
371+ } else if (MP_OBJ_IS_TYPE (x , & ulab_ndarray_type )) {
372+ ndarray_obj_t * ndarray = MP_OBJ_TO_PTR (x );
373+ if ((ndarray -> ndim != 1 ) && (ndarray -> ndim != 2 )) {
374+ mp_raise_ValueError (translate ("norm is defined for 1D and 2D arrays" ));
375+ }
376+ uint8_t * array = (uint8_t * )ndarray -> array ;
367377
368- mp_float_t (* func )(void * ) = ndarray_get_float_function (ndarray -> dtype );
378+ mp_float_t (* func )(void * ) = ndarray_get_float_function (ndarray -> dtype );
369379
370- size_t k = 0 ;
371- do {
372- size_t l = 0 ;
380+ size_t k = 0 ;
373381 do {
374- mp_float_t v = func (array );
375- array += ndarray -> strides [ULAB_MAX_DIMS - 1 ];
376- dot += v * v ;
377- l ++ ;
378- } while (l < ndarray -> shape [ULAB_MAX_DIMS - 1 ]);
379- array -= ndarray -> strides [ULAB_MAX_DIMS - 1 ] * ndarray -> shape [ULAB_MAX_DIMS - 1 ];
380- array += ndarray -> strides [ULAB_MAX_DIMS - 2 ];
381- k ++ ;
382- } while (k < ndarray -> shape [ULAB_MAX_DIMS - 2 ]);
383- return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot ));
382+ size_t l = 0 ;
383+ do {
384+ value = func (array );
385+ dot = dot + (value * value - dot ) / count ++ ;
386+ array += ndarray -> strides [ULAB_MAX_DIMS - 1 ];
387+ l ++ ;
388+ } while (l < ndarray -> shape [ULAB_MAX_DIMS - 1 ]);
389+ array -= ndarray -> strides [ULAB_MAX_DIMS - 1 ] * ndarray -> shape [ULAB_MAX_DIMS - 1 ];
390+ array += ndarray -> strides [ULAB_MAX_DIMS - 2 ];
391+ k ++ ;
392+ } while (k < ndarray -> shape [ULAB_MAX_DIMS - 2 ]);
393+ return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
394+ } else {
395+ mp_raise_TypeError (translate ("argument must be an interable or ndarray" ));
396+ }
384397}
385398
386399MP_DEFINE_CONST_FUN_OBJ_1 (linalg_norm_obj , linalg_norm );
0 commit comments