@@ -63,38 +63,6 @@ static void numerical_reduce_axes(ndarray_obj_t *ndarray, int8_t axis, size_t *s
6363 }
6464}
6565
66- static shape_strides numerical_reduce_axes_ (ndarray_obj_t * ndarray , mp_obj_t axis ) {
67- // TODO: replace numerical_reduce_axes with this function, wherever applicable
68- int8_t ax = mp_obj_get_int (axis );
69- if (ax < 0 ) ax += ndarray -> ndim ;
70- if ((ax < 0 ) || (ax > ndarray -> ndim - 1 )) {
71- mp_raise_ValueError (translate ("index out of range" ));
72- }
73- shape_strides _shape_strides ;
74- _shape_strides .index = ULAB_MAX_DIMS - ndarray -> ndim + ax ;
75- size_t * shape = m_new (size_t , ULAB_MAX_DIMS );
76- memset (shape , 0 , sizeof (size_t )* ULAB_MAX_DIMS );
77- _shape_strides .shape = shape ;
78- int32_t * strides = m_new (int32_t , ULAB_MAX_DIMS );
79- memset (strides , 0 , sizeof (uint32_t )* ULAB_MAX_DIMS );
80- _shape_strides .strides = strides ;
81- if ((ndarray -> ndim == 1 ) && (_shape_strides .axis == 0 )) {
82- _shape_strides .index = 0 ;
83- _shape_strides .shape [ULAB_MAX_DIMS - 1 ] = 1 ;
84- } else {
85- for (uint8_t i = ULAB_MAX_DIMS - 1 ; i > 0 ; i -- ) {
86- if (i > _shape_strides .index ) {
87- _shape_strides .shape [i ] = ndarray -> shape [i ];
88- _shape_strides .strides [i ] = ndarray -> strides [i ];
89- } else {
90- _shape_strides .shape [i ] = ndarray -> shape [i - 1 ];
91- _shape_strides .strides [i ] = ndarray -> strides [i - 1 ];
92- }
93- }
94- }
95- return _shape_strides ;
96- }
97-
9866#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
9967static mp_obj_t numerical_all_any (mp_obj_t oin , mp_obj_t axis , uint8_t optype ) {
10068 bool anytype = optype == NUMERICAL_ALL ? 1 : 0 ;
@@ -130,25 +98,25 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
13098 l ++ ;
13199 } while (l < ndarray -> shape [ULAB_MAX_DIMS - 1 ]);
132100 #if ULAB_MAX_DIMS > 1
133- array -= ndarray -> strides [ULAB_MAX_DIMS - 1 ] * ndarray -> shape [ULAB_MAX_DIMS - 1 ];
101+ array -= ndarray -> strides [ULAB_MAX_DIMS - 1 ] * ndarray -> shape [ULAB_MAX_DIMS - 1 ];
134102 array += ndarray -> strides [ULAB_MAX_DIMS - 2 ];
135103 k ++ ;
136104 } while (k < ndarray -> shape [ULAB_MAX_DIMS - 2 ]);
137105 #endif
138106 #if ULAB_MAX_DIMS > 2
139- array -= ndarray -> strides [ULAB_MAX_DIMS - 2 ] * ndarray -> shape [ULAB_MAX_DIMS - 2 ];
107+ array -= ndarray -> strides [ULAB_MAX_DIMS - 2 ] * ndarray -> shape [ULAB_MAX_DIMS - 2 ];
140108 array += ndarray -> strides [ULAB_MAX_DIMS - 3 ];
141109 j ++ ;
142110 } while (j < ndarray -> shape [ULAB_MAX_DIMS - 3 ]);
143111 #endif
144112 #if ULAB_MAX_DIMS > 3
145- array -= ndarray -> strides [ULAB_MAX_DIMS - 3 ] * ndarray -> shape [ULAB_MAX_DIMS - 3 ];
113+ array -= ndarray -> strides [ULAB_MAX_DIMS - 3 ] * ndarray -> shape [ULAB_MAX_DIMS - 3 ];
146114 array += ndarray -> strides [ULAB_MAX_DIMS - 4 ];
147115 i ++ ;
148116 } while (i < ndarray -> shape [ULAB_MAX_DIMS - 4 ]);
149117 #endif
150118 } else {
151- shape_strides _shape_strides = numerical_reduce_axes_ (ndarray , axis );
119+ shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
152120 ndarray_obj_t * results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), _shape_strides .shape , NDARRAY_BOOL );
153121 uint8_t * rarray = (uint8_t * )results -> array ;
154122 if (optype == NUMERICAL_ALL ) {
@@ -173,33 +141,33 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
173141 // optype == NUMERICAL_ANY
174142 * rarray = 1 ;
175143 // since we are breaking out of the loop, move the pointer forward
176- array += ndarray -> strides [ _shape_strides .index ] * (ndarray -> shape [ _shape_strides .index ] - l );
144+ array += _shape_strides .strides [ 0 ] * (_shape_strides .shape [ 0 ] - l );
177145 break ;
178146 } else if ((value == MICROPY_FLOAT_CONST (0.0 )) & anytype ) {
179147 // optype == NUMERICAL_ALL
180148 * rarray = 0 ;
181149 // since we are breaking out of the loop, move the pointer forward
182- array += ndarray -> strides [ _shape_strides .index ] * (ndarray -> shape [ _shape_strides .index ] - l );
150+ array += _shape_strides .strides [ 0 ] * (_shape_strides .shape [ 0 ] - l );
183151 break ;
184152 }
185- array += ndarray -> strides [ _shape_strides .index ];
153+ array += _shape_strides .strides [ 0 ];
186154 l ++ ;
187- } while (l < ndarray -> shape [ _shape_strides .index ]);
155+ } while (l < _shape_strides .shape [ 0 ]);
188156 #if ULAB_MAX_DIMS > 1
189157 rarray ++ ;
190- array -= ndarray -> strides [ _shape_strides .index ] * ndarray -> shape [ _shape_strides .index ];
158+ array -= _shape_strides .strides [ 0 ] * _shape_strides .shape [ 0 ];
191159 array += _shape_strides .strides [ULAB_MAX_DIMS - 1 ];
192160 k ++ ;
193161 } while (k < _shape_strides .shape [ULAB_MAX_DIMS - 1 ]);
194162 #endif
195163 #if ULAB_MAX_DIMS > 2
196- array -= _shape_strides .strides [ULAB_MAX_DIMS - 1 ] * _shape_strides .shape [ULAB_MAX_DIMS - 1 ];
164+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 1 ] * _shape_strides .shape [ULAB_MAX_DIMS - 1 ];
197165 array += _shape_strides .strides [ULAB_MAX_DIMS - 2 ];
198166 j ++ ;
199167 } while (j < _shape_strides .shape [ULAB_MAX_DIMS - 2 ]);
200168 #endif
201169 #if ULAB_MAX_DIMS > 3
202- array -= _shape_strides .strides [ULAB_MAX_DIMS - 2 ] * _shape_strides .shape [ULAB_MAX_DIMS - 2 ];
170+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 2 ] * _shape_strides .shape [ULAB_MAX_DIMS - 2 ];
203171 array += _shape_strides .strides [ULAB_MAX_DIMS - 3 ];
204172 i ++ ;
205173 } while (i < _shape_strides .shape [ULAB_MAX_DIMS - 3 ])
0 commit comments