From 0220c42e22974cd6c578e5ad3ae7a094ceb011ec Mon Sep 17 00:00:00 2001
From: s-ol
Date: Fri, 8 Dec 2023 20:55:08 +0100
Subject: [PATCH] implement integer array indexing

Allow reading from a dense, one-dimensional ndarray through an index
ndarray of dtype uint8, int8, uint16 or int16. The result has the shape
of the index array and the dtype of the indexed array. Assignment through
an integer index array is not implemented and raises NotImplementedError.

Known limitations (see the TODOs in the code): index values are not
range-checked against the length of the indexed array, and negative
indices are not normalized.

---
 code/ndarray.c                          | 118 ++++++++++++++++++++++--
 code/ndarray.h                          |  92 ++++++++++++++++++
 tests/1d/numpy/advanced_indexing.py     |  21 +++++
 tests/1d/numpy/advanced_indexing.py.exp |   6 ++
 tests/2d/numpy/advanced_indexing.py     |  10 ++
 tests/2d/numpy/advanced_indexing.py.exp |   4 +
 6 files changed, 242 insertions(+), 9 deletions(-)
 create mode 100644 tests/1d/numpy/advanced_indexing.py
 create mode 100644 tests/1d/numpy/advanced_indexing.py.exp
 create mode 100644 tests/2d/numpy/advanced_indexing.py
 create mode 100644 tests/2d/numpy/advanced_indexing.py.exp

diff --git a/code/ndarray.c b/code/ndarray.c
index af881478..c41c01a2 100644
--- a/code/ndarray.c
+++ b/code/ndarray.c
@@ -1197,6 +1197,94 @@ static mp_obj_t ndarray_from_boolean_index(ndarray_obj_t *ndarray, ndarray_obj_t
     return MP_OBJ_FROM_PTR(results);
 }
 
+// returns ndarray[index] for an integer index array: result has index's shape and ndarray's dtype
+static mp_obj_t ndarray_from_integer_index(ndarray_obj_t *ndarray, ndarray_obj_t *index) {
+    if(ndarray->ndim > 1) {
+        mp_raise_ValueError(MP_ERROR_TEXT("only supports 1-dim target arrays"));
+    }
+
+    if(!ndarray_is_dense(ndarray)) {
+        mp_raise_ValueError(MP_ERROR_TEXT("only supports dense target arrays"));
+    }
+
+    // TODO: range-check index values against ndarray->shape[ULAB_MAX_DIMS-1]
+    // TODO: normalize or handle negative indices in loop (without modifying index)
+
+    int32_t *strides = strides_from_shape(index->shape, ndarray->dtype);
+    ndarray_obj_t *results = ndarray_new_ndarray(index->ndim, index->shape, strides, ndarray->dtype);
+
+    uint8_t *larray = (uint8_t *)ndarray->array;
+    uint8_t *iarray = (uint8_t *)index->array;
+
+    if (ndarray->dtype == NDARRAY_UINT8) {
+        if (index->dtype == NDARRAY_UINT8) {
+            INDEX_LOOP(results, uint8_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT8) {
+            INDEX_LOOP(results, uint8_t, int8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_UINT16) {
+            INDEX_LOOP(results, uint8_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT16) {
+            INDEX_LOOP(results, uint8_t, int16_t, larray, ndarray->strides, iarray, index->strides);
+        }
+    } else if (ndarray->dtype == NDARRAY_INT8) {
+        if (index->dtype == NDARRAY_UINT8) {
+            INDEX_LOOP(results, int8_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT8) {
+            INDEX_LOOP(results, int8_t, int8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_UINT16) {
+            INDEX_LOOP(results, int8_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT16) {
+            INDEX_LOOP(results, int8_t, int16_t, larray, ndarray->strides, iarray, index->strides);
+        }
+    } else if (ndarray->dtype == NDARRAY_UINT16) {
+        if (index->dtype == NDARRAY_UINT8) {
+            INDEX_LOOP(results, uint16_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT8) {
+            INDEX_LOOP(results, uint16_t, int8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_UINT16) {
+            INDEX_LOOP(results, uint16_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT16) {
+            INDEX_LOOP(results, uint16_t, int16_t, larray, ndarray->strides, iarray, index->strides);
+        }
+    } else if (ndarray->dtype == NDARRAY_INT16) {
+        if (index->dtype == NDARRAY_UINT8) {
+            INDEX_LOOP(results, int16_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT8) {
+            INDEX_LOOP(results, int16_t, int8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_UINT16) {
+            INDEX_LOOP(results, int16_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT16) {
+            INDEX_LOOP(results, int16_t, int16_t, larray, ndarray->strides, iarray, index->strides);
+        }
+    } else if (ndarray->dtype == NDARRAY_FLOAT) {
+        if (index->dtype == NDARRAY_UINT8) {
+            INDEX_LOOP(results, mp_float_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT8) {
+            INDEX_LOOP(results, mp_float_t, int8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_UINT16) {
+            INDEX_LOOP(results, mp_float_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT16) {
+            INDEX_LOOP(results, mp_float_t, int16_t, larray, ndarray->strides, iarray, index->strides);
+        }
+    #if ULAB_SUPPORTS_COMPLEX
+    } else if (ndarray->dtype == NDARRAY_COMPLEX) {
+        struct complex_t { mp_float_t a; mp_float_t b; }; // copy a whole (real, imag) element at once
+
+        if (index->dtype == NDARRAY_UINT8) {
+            INDEX_LOOP(results, struct complex_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT8) {
+            INDEX_LOOP(results, struct complex_t, int8_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_UINT16) {
+            INDEX_LOOP(results, struct complex_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
+        } else if (index->dtype == NDARRAY_INT16) {
+            INDEX_LOOP(results, struct complex_t, int16_t, larray, ndarray->strides, iarray, index->strides);
+        }
+    #endif
+    }
+
+    return MP_OBJ_FROM_PTR(results);
+}
+
 static mp_obj_t ndarray_assign_from_boolean_index(ndarray_obj_t *ndarray, ndarray_obj_t *index, ndarray_obj_t *values) {
     // assigns values to a Boolean-indexed array
     // first we have to find out how many trues there are
@@ -1313,16 +1401,28 @@ static mp_obj_t ndarray_assign_from_boolean_index(ndarray_obj_t *ndarray, ndarra
 static mp_obj_t ndarray_get_slice(ndarray_obj_t *ndarray, mp_obj_t index, ndarray_obj_t *values) {
     if(mp_obj_is_type(index, &ulab_ndarray_type)) {
         ndarray_obj_t *nindex = MP_OBJ_TO_PTR(index);
-        if((nindex->ndim > 1) || (nindex->boolean == false)) {
-            mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is implemented for 1D Boolean arrays only"));
-        }
-        if(values == NULL) { // return value(s)
-            return ndarray_from_boolean_index(ndarray, nindex);
-        } else { // assign value(s)
-            ndarray_assign_from_boolean_index(ndarray, nindex, values);
+
+        if(nindex->boolean) {
+            if(nindex->ndim > 1) {
+                mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is only implemented for integer arrays or 1D Boolean arrays"));
+            }
+
+            if(values == NULL) { // return value(s)
+                return ndarray_from_boolean_index(ndarray, nindex);
+            } else { // assign value(s)
+                ndarray_assign_from_boolean_index(ndarray, nindex, values);
+            }
+        } else if ((nindex->dtype == NDARRAY_UINT8) || (nindex->dtype == NDARRAY_INT8) ||
+            (nindex->dtype == NDARRAY_UINT16) || (nindex->dtype == NDARRAY_INT16)) {
+            if(values == NULL) { // return value(s)
+                return ndarray_from_integer_index(ndarray, nindex);
+            } else { // assign value(s)
+                mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is only implemented for integer arrays or 1D Boolean arrays"));
+            }
+        } else {
+            mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is only implemented for integer arrays or 1D Boolean arrays"));
         }
-    }
-    if(mp_obj_is_type(index, &mp_type_tuple) || mp_obj_is_int(index) || mp_obj_is_type(index, &mp_type_slice)) {
+    } else if(mp_obj_is_type(index, &mp_type_tuple) || mp_obj_is_int(index) || mp_obj_is_type(index, &mp_type_slice)) {
         mp_obj_tuple_t *tuple;
         if(mp_obj_is_type(index, &mp_type_tuple)) {
             tuple = MP_OBJ_TO_PTR(index);
diff --git a/code/ndarray.h b/code/ndarray.h
index ec8b3ee7..742be526 100644
--- a/code/ndarray.h
+++ b/code/ndarray.h
@@ -391,6 +391,17 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
         l++;\
     } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
 
+#define INDEX_LOOP(results, type_left, type_index, larray, lstrides, iarray, istrides)\
+    uint8_t *array = (uint8_t *)results->array;\
+    size_t l = 0;\
+    do {\
+        size_t offset = lstrides[ULAB_MAX_DIMS - 1] * *((type_index *)(iarray));\
+        *((type_left *)array) = *((type_left *)(larray + offset));\
+        array += results->strides[ULAB_MAX_DIMS - 1];\
+        iarray += istrides[ULAB_MAX_DIMS - 1];\
+        l++;\
+    } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+
 #endif /* ULAB_MAX_DIMS == 1 */
 
 #if ULAB_MAX_DIMS == 2
@@ -464,6 +475,25 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
         k++;\
     } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
 
+#define INDEX_LOOP(results, type_left, type_index, larray, lstrides, iarray, istrides)\
+    uint8_t *array = (uint8_t *)results->array;\
+    size_t k = 0;\
+    do {\
+        size_t l = 0;\
+        do {\
+            size_t offset = lstrides[ULAB_MAX_DIMS - 1] * *((type_index *)(iarray));\
+            *((type_left *)array) = *((type_left *)(larray + offset));\
+            array += results->strides[ULAB_MAX_DIMS - 1];\
+            iarray += istrides[ULAB_MAX_DIMS - 1];\
+            l++;\
+        } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+        (array) -= (results->strides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+        (array) += (results->strides)[ULAB_MAX_DIMS - 2];\
+        (iarray) -= (istrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+        (iarray) += (istrides)[ULAB_MAX_DIMS - 2];\
+        k++;\
+    } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
+
 #endif /* ULAB_MAX_DIMS == 2 */
 
 #if ULAB_MAX_DIMS == 3
@@ -569,6 +599,33 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
         j++;\
     } while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
 
+#define INDEX_LOOP(results, type_left, type_index, larray, lstrides, iarray, istrides)\
+    uint8_t *array = (uint8_t *)results->array;\
+    size_t j = 0;\
+    do {\
+        size_t k = 0;\
+        do {\
+            size_t l = 0;\
+            do {\
+                size_t offset = lstrides[ULAB_MAX_DIMS - 1] * *((type_index *)(iarray));\
+                *((type_left *)array) = *((type_left *)(larray + offset));\
+                array += results->strides[ULAB_MAX_DIMS - 1];\
+                iarray += istrides[ULAB_MAX_DIMS - 1];\
+                l++;\
+            } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+            (array) -= (results->strides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+            (array) += (results->strides)[ULAB_MAX_DIMS - 2];\
+            (iarray) -= (istrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+            (iarray) += (istrides)[ULAB_MAX_DIMS - 2];\
+            k++;\
+        } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
+        (array) -= (results->strides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+        (array) += (results->strides)[ULAB_MAX_DIMS - 3];\
+        (iarray) -= (istrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+        (iarray) += (istrides)[ULAB_MAX_DIMS - 3];\
+        j++;\
+    } while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
+
 #endif /* ULAB_MAX_DIMS == 3 */
 
 #if ULAB_MAX_DIMS == 4
@@ -706,6 +763,41 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
         i++;\
     } while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\
 
+#define INDEX_LOOP(results, type_left, type_index, larray, lstrides, iarray, istrides)\
+    uint8_t *array = (uint8_t *)results->array;\
+    size_t i = 0;\
+    do {\
+        size_t j = 0;\
+        do {\
+            size_t k = 0;\
+            do {\
+                size_t l = 0;\
+                do {\
+                    size_t offset = lstrides[ULAB_MAX_DIMS - 1] * *((type_index *)(iarray));\
+                    *((type_left *)array) = *((type_left *)(larray + offset));\
+                    array += results->strides[ULAB_MAX_DIMS - 1];\
+                    iarray += istrides[ULAB_MAX_DIMS - 1];\
+                    l++;\
+                } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+                (array) -= (results->strides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+                (array) += (results->strides)[ULAB_MAX_DIMS - 2];\
+                (iarray) -= (istrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+                (iarray) += (istrides)[ULAB_MAX_DIMS - 2];\
+                k++;\
+            } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
+            (array) -= (results->strides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+            (array) += (results->strides)[ULAB_MAX_DIMS - 3];\
+            (iarray) -= (istrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+            (iarray) += (istrides)[ULAB_MAX_DIMS - 3];\
+            j++;\
+        } while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
+        (array) -= (results->strides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
+        (array) += (results->strides)[ULAB_MAX_DIMS - 4];\
+        (iarray) -= (istrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
+        (iarray) += (istrides)[ULAB_MAX_DIMS - 4];\
+        i++;\
+    } while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\
+
 #endif /* ULAB_MAX_DIMS == 4 */
 
 #endif /* ULAB_HAS_FUNCTION_ITERATOR */
diff --git a/tests/1d/numpy/advanced_indexing.py b/tests/1d/numpy/advanced_indexing.py
new file mode 100644
index 00000000..c20684ba
--- /dev/null
+++ b/tests/1d/numpy/advanced_indexing.py
@@ -0,0 +1,21 @@
+from ulab import numpy as np
+
+a = np.array(range(0, 100, 10), dtype=np.uint8)
+b = np.array([0.5, 1.5, 0.2, 4.3], dtype=np.float)
+
+# integer array indexing
+print(a[np.array([0, 4, 2], dtype=np.uint8)])
+print(b[np.array([3, 2, 2, 3], dtype=np.int16)])
+# TODO: test negative indices
+# TODO: check range checking
+
+# boolean array indexing
+print(a[a >= 50])
+print(b[b > 1])
+
+# boolean array index assignment
+a[a > 1] = 0
+print(a)
+
+b[b > 50] += 5
+print(b)
diff --git a/tests/1d/numpy/advanced_indexing.py.exp b/tests/1d/numpy/advanced_indexing.py.exp
new file mode 100644
index 00000000..ee5f94f8
--- /dev/null
+++ b/tests/1d/numpy/advanced_indexing.py.exp
@@ -0,0 +1,6 @@
+array([0, 40, 20], dtype=uint8)
+array([4.3, 0.2, 0.2, 4.3], dtype=float64)
+array([50, 60, 70, 80, 90], dtype=uint8)
+array([1.5, 4.3], dtype=float64)
+array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=uint8)
+array([0.5, 1.5, 0.2, 4.3], dtype=float64)
diff --git a/tests/2d/numpy/advanced_indexing.py b/tests/2d/numpy/advanced_indexing.py
new file mode 100644
index 00000000..2be36359
--- /dev/null
+++ b/tests/2d/numpy/advanced_indexing.py
@@ -0,0 +1,10 @@
+from ulab import numpy as np
+
+a = np.array(range(0, 100, 10), dtype=np.uint8)
+b = np.array([0.5, 1.5, 0.2, 4.3], dtype=np.float)
+
+# integer array indexing
+print(a[np.array([[0, 4], [1, 2]], dtype=np.uint8)])
+print(b[np.array([[3, 2], [2, 3]], dtype=np.uint8)])
+# TODO: test negative indices
+# TODO: check range checking
diff --git a/tests/2d/numpy/advanced_indexing.py.exp b/tests/2d/numpy/advanced_indexing.py.exp
new file mode 100644
index 00000000..ebeb5729
--- /dev/null
+++ b/tests/2d/numpy/advanced_indexing.py.exp
@@ -0,0 +1,4 @@
+array([[0, 40],
+       [10, 20]], dtype=uint8)
+array([[4.3, 0.2],
+       [0.2, 4.3]], dtype=float64)