Skip to content

Commit 6bc45e4

Browse files
authored
Merge pull request #41 from stackhpc/negative-selection
Support negative values in selection
2 parents a94ecdc + 6cf7185 commit 6bc45e4

File tree

2 files changed

+282
-39
lines changed

2 files changed

+282
-39
lines changed

src/array.rs

Lines changed: 221 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,29 +59,61 @@ fn build_array_from_shape<T>(
5959
ArrayView::<T, _>::from_shape(shape, data).map_err(ActiveStorageError::ShapeInvalid)
6060
}
6161

62-
/// Returns an optional [ndarray] SliceInfo object corresponding to the selection.
62+
/// Returns an array index in numpy semantics to an index with ndarray semantics.
63+
///
64+
/// The resulting value will be clamped such that it is safe for indexing in ndarray.
65+
/// This allows us to accept selections with NumPy's less restrictive semantics.
66+
/// When the stride is negative (`reverse` is `true`), the result is offset by one to allow for
67+
/// Numpy's non-inclusive start and inclusive end in this scenario.
68+
///
69+
/// # Arguments
70+
///
71+
/// * `index`: Selection index
72+
/// * `length`: Length of corresponding axis
73+
/// * `reverse`: Whether the stride is negative
74+
fn to_ndarray_index(index: isize, length: usize, reverse: bool) -> isize {
75+
let length_isize = length.try_into().expect("Length too large!");
76+
let result = if reverse { index + 1 } else { index };
77+
if index < 0 {
78+
std::cmp::max(result + length_isize, 0)
79+
} else {
80+
std::cmp::min(result, length_isize)
81+
}
82+
}
83+
84+
/// Convert a [crate::models::Slice] object with indices in numpy semantics to an
85+
/// [ndarray::SliceInfoElem::Slice] with ndarray semantics.
86+
///
87+
/// See [ndarray docs](https://docs.rs/ndarray/0.15.6/ndarray/macro.s.html#negative-step) for
88+
/// information about ndarray's handling of negative strides.
89+
fn to_ndarray_slice(slice: &models::Slice, length: usize) -> ndarray::SliceInfoElem {
90+
let reverse = slice.stride < 0;
91+
let start = to_ndarray_index(slice.start, length, reverse);
92+
let end = to_ndarray_index(slice.end, length, reverse);
93+
let (start, end) = if reverse { (end, start) } else { (start, end) };
94+
ndarray::SliceInfoElem::Slice {
95+
start,
96+
end: Some(end),
97+
step: slice.stride,
98+
}
99+
}
100+
101+
/// Returns an [ndarray] SliceInfo object corresponding to the selection.
63102
pub fn build_slice_info<T>(
64103
selection: &Option<Vec<models::Slice>>,
65104
shape: &[usize],
66105
) -> ndarray::SliceInfo<Vec<ndarray::SliceInfoElem>, ndarray::IxDyn, ndarray::IxDyn> {
67106
match selection {
68107
Some(selection) => {
69-
let si: Vec<ndarray::SliceInfoElem> = selection
70-
.iter()
71-
.map(|slice| ndarray::SliceInfoElem::Slice {
72-
// FIXME: usize should be isize?
73-
start: slice.start as isize,
74-
end: Some(slice.end as isize),
75-
step: slice.stride as isize,
76-
})
108+
let si: Vec<ndarray::SliceInfoElem> = std::iter::zip(selection, shape)
109+
.map(|(slice, length)| to_ndarray_slice(slice, *length))
77110
.collect();
78111
ndarray::SliceInfo::try_from(si).expect("SliceInfo should not fail for IxDyn")
79112
}
80113
_ => {
81114
let si: Vec<ndarray::SliceInfoElem> = shape
82115
.iter()
83116
.map(|_| ndarray::SliceInfoElem::Slice {
84-
// FIXME: usize should be isize?
85117
start: 0,
86118
end: None,
87119
step: 1,
@@ -309,7 +341,7 @@ mod tests {
309341
#[test]
310342
fn build_slice_info_1d_selection() {
311343
let selection = Some(vec![models::Slice::new(0, 1, 1)]);
312-
let shape = [];
344+
let shape = [1];
313345
let slice_info = build_slice_info::<u32>(&selection, &shape);
314346
assert_eq!(
315347
[ndarray::SliceInfoElem::Slice {
@@ -321,6 +353,51 @@ mod tests {
321353
);
322354
}
323355

356+
#[test]
357+
fn build_slice_info_1d_selection_negative_stride() {
358+
let selection = Some(vec![models::Slice::new(1, 0, -1)]);
359+
let shape = [1];
360+
let slice_info = build_slice_info::<u32>(&selection, &shape);
361+
assert_eq!(
362+
[ndarray::SliceInfoElem::Slice {
363+
start: 1,
364+
end: Some(1),
365+
step: -1
366+
}],
367+
slice_info.as_ref()
368+
);
369+
}
370+
371+
#[test]
372+
fn build_slice_info_1d_selection_negative_start() {
373+
let selection = Some(vec![models::Slice::new(-1, 1, 1)]);
374+
let shape = [1];
375+
let slice_info = build_slice_info::<u32>(&selection, &shape);
376+
assert_eq!(
377+
[ndarray::SliceInfoElem::Slice {
378+
start: 0,
379+
end: Some(1),
380+
step: 1
381+
}],
382+
slice_info.as_ref()
383+
);
384+
}
385+
386+
#[test]
387+
fn build_slice_info_1d_selection_negative_end() {
388+
let selection = Some(vec![models::Slice::new(0, -1, 1)]);
389+
let shape = [1];
390+
let slice_info = build_slice_info::<u32>(&selection, &shape);
391+
assert_eq!(
392+
[ndarray::SliceInfoElem::Slice {
393+
start: 0,
394+
end: Some(0),
395+
step: 1
396+
}],
397+
slice_info.as_ref()
398+
);
399+
}
400+
324401
#[test]
325402
fn build_slice_info_2d_no_selection() {
326403
let selection = None;
@@ -349,7 +426,7 @@ mod tests {
349426
models::Slice::new(0, 1, 1),
350427
models::Slice::new(0, 1, 1),
351428
]);
352-
let shape = [];
429+
let shape = [1, 1];
353430
let slice_info = build_slice_info::<u32>(&selection, &shape);
354431
assert_eq!(
355432
[
@@ -405,4 +482,136 @@ mod tests {
405482
let array = build_array::<i64>(&request_data, &bytes).unwrap();
406483
assert_eq!(array![[0x04030201_i64], [0x08070605_i64]].into_dyn(), array);
407484
}
485+
486+
// Helper function for tests that slice an array using a selection.
487+
fn test_selection(slice: models::Slice, expected: Array1<u32>) {
488+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
489+
let request_data = models::RequestData {
490+
source: Url::parse("http://example.com").unwrap(),
491+
bucket: "bar".to_string(),
492+
object: "baz".to_string(),
493+
dtype: models::DType::Uint32,
494+
offset: None,
495+
size: None,
496+
shape: None,
497+
order: None,
498+
selection: None,
499+
};
500+
let bytes = Bytes::copy_from_slice(&data);
501+
let array = build_array::<u32>(&request_data, &bytes).unwrap();
502+
let shape = vec![2];
503+
let slice_info = build_slice_info::<u32>(&Some(vec![slice]), &shape);
504+
let sliced = array.slice(slice_info);
505+
assert_eq!(sliced, expected.into_dyn().view());
506+
}
507+
508+
#[test]
509+
fn build_array_with_selection_all() {
510+
test_selection(
511+
models::Slice::new(0, 2, 1),
512+
array![0x04030201_u32, 0x08070605_u32],
513+
)
514+
}
515+
516+
#[test]
517+
fn build_array_with_selection_negative_start() {
518+
test_selection(
519+
models::Slice::new(-2, 2, 1),
520+
array![0x04030201_u32, 0x08070605_u32],
521+
)
522+
}
523+
524+
#[test]
525+
fn build_array_with_selection_start_lt_negative_length() {
526+
test_selection(
527+
models::Slice::new(-3, 2, 1),
528+
array![0x04030201_u32, 0x08070605_u32],
529+
)
530+
}
531+
532+
#[test]
533+
fn build_array_with_selection_start_eq_length() {
534+
test_selection(models::Slice::new(2, 2, 1), array![])
535+
}
536+
537+
#[test]
538+
fn build_array_with_selection_start_gt_length() {
539+
test_selection(models::Slice::new(3, 2, 1), array![])
540+
}
541+
542+
#[test]
543+
fn build_array_with_selection_negative_end() {
544+
test_selection(models::Slice::new(0, -1, 1), array![0x04030201_u32])
545+
}
546+
547+
#[test]
548+
fn build_array_with_selection_end_lt_negative_length() {
549+
test_selection(models::Slice::new(0, -3, 1), array![])
550+
}
551+
552+
#[test]
553+
fn build_array_with_selection_end_gt_length() {
554+
test_selection(
555+
models::Slice::new(0, 3, 1),
556+
array![0x04030201_u32, 0x08070605_u32],
557+
)
558+
}
559+
560+
#[test]
561+
fn build_array_with_selection_all_negative_stride() {
562+
// Need to end at -3 to read first item.
563+
// translates to [0, 2]
564+
test_selection(
565+
models::Slice::new(1, -3, -1),
566+
array![0x08070605_u32, 0x04030201_u32],
567+
)
568+
}
569+
570+
#[test]
571+
fn build_array_with_selection_negative_start_negative_stride() {
572+
// translates to [0, 2]
573+
test_selection(
574+
models::Slice::new(-1, -3, -1),
575+
array![0x08070605_u32, 0x04030201_u32],
576+
)
577+
}
578+
579+
#[test]
580+
fn build_array_with_selection_start_lt_negative_length_negative_stride() {
581+
// translates to [1, 0]
582+
test_selection(models::Slice::new(-3, 0, -1), array![])
583+
}
584+
585+
#[test]
586+
fn build_array_with_selection_start_eq_length_negative_stride() {
587+
// translates to [2, 2]
588+
test_selection(models::Slice::new(2, 1, -1), array![])
589+
}
590+
591+
#[test]
592+
fn build_array_with_selection_start_gt_length_negative_stride() {
593+
// translates to [2, 2]
594+
test_selection(models::Slice::new(3, 1, -1), array![])
595+
}
596+
597+
#[test]
598+
fn build_array_with_selection_negative_end_negative_stride() {
599+
// translates to [2, 2]
600+
test_selection(models::Slice::new(2, -1, -1), array![])
601+
}
602+
603+
#[test]
604+
fn build_array_with_selection_end_lt_negative_length_negative_stride() {
605+
// translates to [0, 2]
606+
test_selection(
607+
models::Slice::new(1, -3, -1),
608+
array![0x08070605_u32, 0x04030201_u32],
609+
)
610+
}
611+
612+
#[test]
613+
fn build_array_with_selection_end_gt_length_negative_stride() {
614+
// translates to [1, 2]
615+
test_selection(models::Slice::new(3, 0, -1), array![0x08070605_u32])
616+
}
408617
}

0 commit comments

Comments
 (0)