Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 64 additions & 22 deletions cpp/src/arrow/compute/kernels/vector_selection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,27 +484,66 @@ Comparator<CType>* GetComparator(CompareOperator op) {
}

// Test helper: builds the expected output of a filter by collecting every
// non-null value of `array` for which the predicate `fn` returns true, then
// materializing the collected values as an Array via ArrayFromVector.
//
// T     - Arrow type; TypeTraits<T> supplies the ArrayType and CType used here.
// array - input array; it is cast to TypeTraits<T>::ArrayType before reading.
// fn    - predicate invoked with each non-null CType value.
// Returns the Array built from the values that passed the predicate.
//
// NOTE(review): this span looks like a rendered diff whose +/- markers were
// stripped; the pointer-based signature and the reserve(length)/std::copy_if
// lines appear to be the removed (pre-change) version, the rest the added
// version. Confirm against the real file before compiling.
template <typename T, typename Fn, typename CType = typename TypeTraits<T>::CType>
// Presumably the pre-change signature (raw values pointer + length).
std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, Fn&& fn) {
// Presumably the post-change signature (takes the Array so nulls are visible).
std::shared_ptr<Array> CompareAndFilter(const std::shared_ptr<Array>& array, Fn&& fn) {
using ArrayType = typename TypeTraits<T>::ArrayType;
auto typed_array = checked_pointer_cast<ArrayType>(array);

std::vector<CType> filtered;
// Presumably pre-change lines: copied every value matching fn, with no null check.
filtered.reserve(length);
std::copy_if(data, data + length, std::back_inserter(filtered), std::forward<Fn>(fn));
// Upper bound on the result size: at most every slot passes the predicate.
filtered.reserve(array->length());

for (int64_t i = 0; i < array->length(); ++i) {
if (array->IsNull(i)) {
// Nulls are filtered out (comparison with null is false)
// NOTE(review): presumably the compare kernel yields null (not false) for a
// null input and Filter drops null selection slots by default - confirm.
continue;
}
CType value = typed_array->Value(i);
if (fn(value)) {
filtered.push_back(value);
}
}

// The result is built from the plain value vector, so it carries only the kept values.
std::shared_ptr<Array> filtered_array;
ArrayFromVector<T, CType>(filtered, &filtered_array);
return filtered_array;
}

// Test helper: expected result of filtering `array` by comparing each element
// against the scalar `val` with operator `op`. Looks up the comparator for `op`
// with GetComparator and delegates to the predicate-taking CompareAndFilter<T>.
//
// NOTE(review): this span looks like a rendered diff whose +/- markers were
// stripped; the first template header/signature and the `data, length` return
// appear to be the removed (pre-change) version. Confirm against the real file.
template <typename T, typename CType = typename TypeTraits<T>::CType>
std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, CType val,
// U is constrained to be exactly TypeTraits<T>::CType. Presumably this keeps a
// call whose second argument is an array pointer from selecting this overload
// instead of the array/array one - confirm.
template <typename T, typename U,
typename = std::enable_if_t<std::is_same_v<U, typename TypeTraits<T>::CType>>>
std::shared_ptr<Array> CompareAndFilter(const std::shared_ptr<Array>& array, U val,
CompareOperator op) {
using CType = typename TypeTraits<T>::CType;
auto cmp = GetComparator<CType>(op);
// Presumably the pre-change call (raw pointer + length).
return CompareAndFilter<T>(data, length, [&](CType e) { return cmp(e, val); });
// The lambda captures cmp and val by reference; it is only used during this call.
return CompareAndFilter<T>(array, [&](CType e) { return cmp(e, val); });
}

// Test helper: expected result of filtering `lhs` by the element-wise
// comparison `lhs[i] op rhs[i]`. A position is kept only when both slots are
// non-null and the comparator returns true; the kept value is the lhs value.
//
// The loop runs over lhs->length() and indexes rhs with the same index.
// NOTE(review): assumes rhs is at least as long as lhs - no length check is
// made here; confirm callers always pass equal-length arrays.
//
// NOTE(review): this span looks like a rendered diff whose +/- markers were
// stripped; the pointer-based header and the `*other++` return appear to be
// the removed (pre-change) version. Confirm against the real file.
template <typename T, typename CType = typename TypeTraits<T>::CType>
std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length,
const CType* other, CompareOperator op) {
template <typename T>
std::shared_ptr<Array> CompareAndFilter(const std::shared_ptr<Array>& lhs,
const std::shared_ptr<Array>& rhs,
CompareOperator op) {
using ArrayType = typename TypeTraits<T>::ArrayType;
using CType = typename TypeTraits<T>::CType;
auto lhs_typed = checked_pointer_cast<ArrayType>(lhs);
auto rhs_typed = checked_pointer_cast<ArrayType>(rhs);
auto cmp = GetComparator<CType>(op);
// Presumably the pre-change body: walked `other` in lockstep via post-increment.
return CompareAndFilter<T>(data, length, [&](CType e) { return cmp(e, *other++); });

std::vector<CType> filtered;
// Upper bound on the result size: at most every lhs slot is kept.
filtered.reserve(lhs->length());

for (int64_t i = 0; i < lhs->length(); ++i) {
// Skip if either element is null
if (lhs->IsNull(i) || rhs->IsNull(i)) {
continue;
}
CType lhs_value = lhs_typed->Value(i);
CType rhs_value = rhs_typed->Value(i);
if (cmp(lhs_value, rhs_value)) {
filtered.push_back(lhs_value);
}
}

// The result is built from the plain value vector, so it carries only the kept values.
std::shared_ptr<Array> filtered_array;
ArrayFromVector<T, CType>(filtered, &filtered_array);
return filtered_array;
}

TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
Expand All @@ -515,9 +554,10 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
auto rand = random::RandomArrayGenerator(kRandomSeed);
for (size_t i = 3; i < 10; i++) {
const int64_t length = static_cast<int64_t>(1ULL << i);
// TODO(bkietz) rewrite with some nulls
auto array =
checked_pointer_cast<ArrayType>(rand.Numeric<TypeParam>(length, 0, 100, 0));
// Use deterministic null probabilities: 0.0, 0.25, 0.4, 0.5, 0.571, 0.625, 0.667
double null_probability = static_cast<double>(i - 3) / i;
auto array = checked_pointer_cast<ArrayType>(
rand.Numeric<TypeParam>(length, 0, 100, null_probability));
CType c_fifty = 50;
auto fifty = std::make_shared<ScalarType>(c_fifty);
for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
Expand All @@ -527,8 +567,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(array, selection));
auto filtered_array = filtered.make_array();
ValidateOutput(*filtered_array);
auto expected =
CompareAndFilter<TypeParam>(array->raw_values(), array->length(), c_fifty, op);
auto expected = CompareAndFilter<TypeParam>(array, c_fifty, op);
ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
}
}
Expand All @@ -540,18 +579,20 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) {
auto rand = random::RandomArrayGenerator(kRandomSeed);
for (size_t i = 3; i < 10; i++) {
const int64_t length = static_cast<int64_t>(1ULL << i);
// Use deterministic null probabilities with different values for lhs and rhs
double null_probability_lhs = static_cast<double>(i - 3) / i;
double null_probability_rhs = static_cast<double>(i) / (i + 7);
auto lhs = checked_pointer_cast<ArrayType>(
rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
rand.Numeric<TypeParam>(length, 0, 100, null_probability_lhs));
auto rhs = checked_pointer_cast<ArrayType>(
rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
rand.Numeric<TypeParam>(length, 0, 100, null_probability_rhs));
for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
ASSERT_OK_AND_ASSIGN(Datum selection,
CallFunction(CompareOperatorToFunctionName(op), {lhs, rhs}));
ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(lhs, selection));
auto filtered_array = filtered.make_array();
ValidateOutput(*filtered_array);
auto expected = CompareAndFilter<TypeParam>(lhs->raw_values(), lhs->length(),
rhs->raw_values(), op);
auto expected = CompareAndFilter<TypeParam>(lhs, rhs, op);
ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
}
}
Expand All @@ -565,8 +606,10 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
auto rand = random::RandomArrayGenerator(kRandomSeed);
for (size_t i = 3; i < 10; i++) {
const int64_t length = static_cast<int64_t>(1ULL << i);
// Use deterministic null probabilities: 0.0, 0.25, 0.4, 0.5, 0.571, 0.625, 0.667
double null_probability = static_cast<double>(i - 3) / i;
auto array = checked_pointer_cast<ArrayType>(
rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
rand.Numeric<TypeParam>(length, 0, 100, null_probability));
CType c_fifty = 50, c_hundred = 100;
auto fifty = std::make_shared<ScalarType>(c_fifty);
auto hundred = std::make_shared<ScalarType>(c_hundred);
Expand All @@ -579,8 +622,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
auto filtered_array = filtered.make_array();
ValidateOutput(*filtered_array);
auto expected = CompareAndFilter<TypeParam>(
array->raw_values(), array->length(),
[&](CType e) { return (e > c_fifty) && (e < c_hundred); });
array, [&](CType e) { return (e > c_fifty) && (e < c_hundred); });
ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
}
}
Expand Down
Loading