Skip to content

Commit

Permalink
vector_selection_test: Extract DoCheckTake
Browse files Browse the repository at this point in the history
  • Loading branch information
felipecrv committed Apr 19, 2024
1 parent 4d0856e commit 27ca4ea
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions cpp/src/arrow/compute/kernels/vector_selection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1034,29 +1034,34 @@ Status TakeJSON(const std::shared_ptr<DataType>& type, const std::string& values
.Value(out);
}

void DoCheckTake(const std::shared_ptr<Array>& values,
const std::shared_ptr<Array>& indices,
const std::shared_ptr<Array>& expected) {
AssertTakeArrays(values, indices, expected);

// Check sliced values
ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(values->type(), 2));
ASSERT_OK_AND_ASSIGN(auto values_sliced,
Concatenate({values_filler, values, values_filler}));
values_sliced = values_sliced->Slice(2, values->length());
AssertTakeArrays(values_sliced, indices, expected);

// Check sliced indices
ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(indices->type(), int8_t{0}));
ASSERT_OK_AND_ASSIGN(auto indices_filler, MakeArrayFromScalar(*zero, 3));
ASSERT_OK_AND_ASSIGN(auto indices_sliced,
Concatenate({indices_filler, indices, indices_filler}));
indices_sliced = indices_sliced->Slice(3, indices->length());
AssertTakeArrays(values, indices_sliced, expected);
}

void CheckTake(const std::shared_ptr<DataType>& type, const std::string& values_json,
const std::string& indices_json, const std::string& expected_json) {
auto values = ArrayFromJSON(type, values_json);
auto expected = ArrayFromJSON(type, expected_json);

for (auto index_type : {int8(), uint32()}) {
auto indices = ArrayFromJSON(index_type, indices_json);
AssertTakeArrays(values, indices, expected);

// Check sliced values
ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2));
ASSERT_OK_AND_ASSIGN(auto values_sliced,
Concatenate({values_filler, values, values_filler}));
values_sliced = values_sliced->Slice(2, values->length());
AssertTakeArrays(values_sliced, indices, expected);

// Check sliced indices
ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(index_type, int8_t{0}));
ASSERT_OK_AND_ASSIGN(auto indices_filler, MakeArrayFromScalar(*zero, 3));
ASSERT_OK_AND_ASSIGN(auto indices_sliced,
Concatenate({indices_filler, indices, indices_filler}));
indices_sliced = indices_sliced->Slice(3, indices->length());
AssertTakeArrays(values, indices_sliced, expected);
DoCheckTake(values, indices, expected);
}
}

Expand Down

0 comments on commit 27ca4ea

Please sign in to comment.