Take: Use PrimitiveTakeExec for fixed-width (modulo nesting) arrays
felipecrv committed Apr 19, 2024
1 parent 27ca4ea commit dbf733d
Showing 4 changed files with 284 additions and 15 deletions.
27 changes: 27 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.cc
@@ -37,6 +37,7 @@
#include "arrow/util/bit_block_counter.h"
#include "arrow/util/bit_run_reader.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/fixed_width_internal.h"
#include "arrow/util/int_util.h"
#include "arrow/util/logging.h"
#include "arrow/util/ree_util.h"
@@ -950,6 +951,32 @@ Status LargeListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
}

Status FSLTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
const ArraySpan& values = batch[0].array;

// If a FixedSizeList wraps a fixed-width type, we can, in some cases, use
// PrimitiveTakeExec for a fixed-size list array.
if (util::IsFixedWidthModuloNesting(values,
/*force_null_count=*/true,
/*extra_predicate=*/[](auto& fixed_width_type) {
// DICTIONARY is fixed-width but not supported by
// PrimitiveTakeExec.
return fixed_width_type.id() != Type::DICTIONARY;
})) {
const auto byte_width = util::FixedWidthInBytes(*values.type);
// Additionally, PrimitiveTakeExec is only implemented for specific byte widths.
switch (byte_width) {
case 1:
case 2:
case 4:
case 8:
case 16:
case 32:
return PrimitiveTakeExec(ctx, batch, out);
default:
break; // fallback to TakeExec<FSLSelectionImpl>
}
}

return TakeExec<FSLSelectionImpl>(ctx, batch, out);
}

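As a sketch of the dispatch rule above (ours, not part of the commit): the take is forwarded to PrimitiveTakeExec exactly when the flattened byte width of the (possibly nested) fixed-size list lands on a width that kernel instantiates. Assuming C++17; FlatWidth and UsesPrimitiveTakePath are hypothetical helper names.

#include <cstdint>

// Byte width of fixed_size_list(...fixed_size_list(inner, sN)..., s1):
// the product of all list sizes times the inner byte width.
template <typename... Sizes>
constexpr int64_t FlatWidth(int64_t inner_byte_width, Sizes... sizes) {
  return (inner_byte_width * ... * sizes);  // C++17 fold expression
}

// Mirrors the switch in FSLTakeExec above.
constexpr bool UsesPrimitiveTakePath(int64_t byte_width) {
  switch (byte_width) {
    case 1: case 2: case 4: case 8: case 16: case 32:
      return true;   // a PrimitiveTakeImpl specialization exists
    default:
      return false;  // fall back to TakeExec<FSLSelectionImpl>
  }
}

static_assert(UsesPrimitiveTakePath(FlatWidth(4, 4)));     // fixed_size_list(int32, 4): 16 bytes
static_assert(UsesPrimitiveTakePath(FlatWidth(2, 8, 2)));  // fixed_size_list(fixed_size_list(int16, 8), 2): 32 bytes
static_assert(!UsesPrimitiveTakePath(FlatWidth(4, 3)));    // fixed_size_list(int32, 3): 12 bytes, fallback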
34 changes: 20 additions & 14 deletions cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
@@ -324,7 +324,7 @@ namespace {
using TakeState = OptionsWrapper<TakeOptions>;

// ----------------------------------------------------------------------
// Implement optimized take for primitive types from boolean to 1/2/4/8-byte
// Implement optimized take for primitive types from boolean to 1/2/4/8/16/32-byte
// C-type based types. Use common implementation for every byte width and only
// generate code for unsigned integer indices, since after boundschecking to
// check for negative numbers in the indices we can safely reinterpret_cast
@@ -334,33 +334,36 @@ using TakeState = OptionsWrapper<TakeOptions>;
/// use the logical Arrow type but rather the physical C type. This way we
/// only generate one take function for each byte width.
///
/// This function assumes that the indices have been boundschecked.
/// Also note that this function can also handle fixed-size-list arrays if
/// they fit the criteria described in fixed_width_internal.h, so use the
/// function defined in that file to access values and destination pointers
/// and DO NOT ASSUME `values.type()` is a primitive type.
///
/// \pre the indices have been boundschecked
template <typename IndexCType, typename ValueWidthConstant>
struct PrimitiveTakeImpl {
static constexpr int kValueWidth = ValueWidthConstant::value;

static void Exec(const ArraySpan& values, const ArraySpan& indices,
ArrayData* out_arr) {
DCHECK_EQ(values.type->byte_width(), kValueWidth);
const auto* values_data =
values.GetValues<uint8_t>(1, 0) + kValueWidth * values.offset;
DCHECK_EQ(util::FixedWidthInBytes(*values.type), kValueWidth);
const auto* values_data = util::OffsetPointerOfFixedWidthValues(values);
const uint8_t* values_is_valid = values.buffers[0].data;
auto values_offset = values.offset;

const auto* indices_data = indices.GetValues<IndexCType>(1);
const uint8_t* indices_is_valid = indices.buffers[0].data;
auto indices_offset = indices.offset;

auto out = out_arr->GetMutableValues<uint8_t>(1, 0) + kValueWidth * out_arr->offset;
DCHECK_EQ(out_arr->offset, 0);
auto* out = util::MutableFixedWidthValuesPointer(out_arr);
auto out_is_valid = out_arr->buffers[0]->mutable_data();
auto out_offset = out_arr->offset;
DCHECK_EQ(out_offset, 0);

// If either the values or indices have nulls, we preemptively zero out the
// out validity bitmap so that we don't have to use ClearBit in each
// iteration for nulls.
if (values.null_count != 0 || indices.null_count != 0) {
bit_util::SetBitsTo(out_is_valid, out_offset, indices.length, false);
bit_util::SetBitsTo(out_is_valid, 0, indices.length, false);
}

auto WriteValue = [&](int64_t position) {
@@ -387,7 +390,7 @@ struct PrimitiveTakeImpl {
valid_count += block.popcount;
if (block.popcount == block.length) {
// Fastest path: neither values nor index nulls
bit_util::SetBitsTo(out_is_valid, out_offset + position, block.length, true);
bit_util::SetBitsTo(out_is_valid, position, block.length, true);
for (int64_t i = 0; i < block.length; ++i) {
WriteValue(position);
++position;
@@ -397,7 +400,7 @@
for (int64_t i = 0; i < block.length; ++i) {
if (bit_util::GetBit(indices_is_valid, indices_offset + position)) {
// index is not null
bit_util::SetBit(out_is_valid, out_offset + position);
bit_util::SetBit(out_is_valid, position);
WriteValue(position);
} else {
WriteZero(position);
@@ -417,7 +420,7 @@
values_offset + indices_data[position])) {
// value is not null
WriteValue(position);
bit_util::SetBit(out_is_valid, out_offset + position);
bit_util::SetBit(out_is_valid, position);
++valid_count;
} else {
WriteZero(position);
@@ -434,7 +437,7 @@
values_offset + indices_data[position])) {
// index is not null && value is not null
WriteValue(position);
bit_util::SetBit(out_is_valid, out_offset + position);
bit_util::SetBit(out_is_valid, position);
++valid_count;
} else {
WriteZero(position);
@@ -585,7 +588,10 @@ Status PrimitiveTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*

ArrayData* out_arr = out->array_data().get();

const int bit_width = values.type->bit_width();
DCHECK(util::IsFixedWidthModuloNesting(
values, /*force_null_count=*/false,
[](const auto& type) { return type.id() != Type::DICTIONARY; }));
const int64_t bit_width = util::FixedWidthInBits(*values.type);

// TODO: When neither values nor indices contain nulls, we can skip
// allocating the validity bitmap altogether and save time and space. A
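For orientation (our simplification, not part of the diff): stripped of all validity-bitmap handling, the gather that PrimitiveTakeImpl performs over fixed-width slots reduces to the sketch below. The real kernel additionally walks the bitmaps with BitBlockCounter so the all-valid case avoids per-element bit tests, and, after this commit, obtains the values and output pointers through the fixed_width_internal.h helpers so a nested fixed-size list can be moved as one wide value.

#include <cstdint>
#include <cstring>

// Null-free core of the take gather: copy one kValueWidth-byte slot per index.
// As in the real kernel, indices are assumed to have been boundschecked.
template <typename IndexCType, int kValueWidth>
void TakeFixedWidth(const uint8_t* values, const IndexCType* indices,
                    int64_t n_indices, uint8_t* out) {
  for (int64_t i = 0; i < n_indices; ++i) {
    std::memcpy(out + i * kValueWidth, values + indices[i] * kValueWidth,
                static_cast<size_t>(kValueWidth));
  }
}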
96 changes: 95 additions & 1 deletion cpp/src/arrow/compute/kernels/vector_selection_test.cc
@@ -23,6 +23,7 @@
#include <utility>
#include <vector>

#include "arrow/array/builder_nested.h"
#include "arrow/array/concatenate.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/api.h"
@@ -32,6 +33,7 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/testing/util.h"
#include "arrow/util/fixed_width_test_util.h"
#include "arrow/util/logging.h"

namespace arrow {
@@ -1432,7 +1434,25 @@ TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) {
CheckTake(large_list(int32()), list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]");
}

class TestTakeKernelWithFixedSizeList : public TestTakeKernelTyped<FixedSizeListType> {};
class TestTakeKernelWithFixedSizeList : public TestTakeKernelTyped<FixedSizeListType> {
protected:
void CheckTakeOnNestedLists(const std::shared_ptr<DataType>& inner_type,
const std::vector<int>& list_sizes, int64_t length) {
using NLG = ::arrow::util::internal::NestedListGenerator;
// Create two equivalent lists: one as a FixedSizeList and another as a List.
ASSERT_OK_AND_ASSIGN(auto fsl_list,
NLG::NestedFSLArray(inner_type, list_sizes, length));
ASSERT_OK_AND_ASSIGN(auto list, NLG::NestedListArray(inner_type, list_sizes, length));

ARROW_SCOPED_TRACE("CheckTakeOnNestedLists of type `", *fsl_list->type(), "`");

auto indices = ArrayFromJSON(int64(), "[1, 2, 4]");
// Use the Take on ListType as the reference implementation.
ASSERT_OK_AND_ASSIGN(auto expected_list, Take(*list, *indices));
ASSERT_OK_AND_ASSIGN(auto expected_fsl, Cast(*expected_list, fsl_list->type()));
DoCheckTake(fsl_list, indices, expected_fsl);
}
};

TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) {
std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
@@ -1454,6 +1474,80 @@ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) {
"[0, 1, 0]");
}

TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListVarWidth) {
std::string list_json =
R"([null, ["one", null, "three"], ["four", "five", "six"], ["seven", "eight", null]])";
CheckTake(fixed_size_list(utf8(), 3), list_json, "[]", "[]");
CheckTake(
fixed_size_list(utf8(), 3), list_json, "[3, 2, 1]",
R"([["seven", "eight", null], ["four", "five", "six"], ["one", null, "three"]])");
CheckTake(fixed_size_list(utf8(), 3), list_json, "[null, 2, 0]",
R"([null, ["four", "five", "six"], null])");
CheckTake(fixed_size_list(utf8(), 3), list_json, "[null, null]", "[null, null]");
CheckTake(fixed_size_list(utf8(), 3), list_json, "[3, 0, 0, 3]",
R"([["seven", "eight", null], null, null, ["seven", "eight", null]])");
CheckTake(fixed_size_list(utf8(), 3), list_json, "[0, 1, 2, 3]", list_json);
CheckTake(fixed_size_list(utf8(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]",
R"([
["four", "five", "six"], ["four", "five", "six"],
["four", "five", "six"], ["four", "five", "six"],
["four", "five", "six"], ["four", "five", "six"],
["one", null, "three"]
])");
}
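Note: utf8 is not a fixed-width type, so fixed_size_list(utf8(), 3) fails the IsFixedWidthModuloNesting check in FSLTakeExec; these cases exercise the TakeExec<FSLSelectionImpl> fallback rather than the new PrimitiveTakeExec path.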

TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListModuloNesting) {
const std::vector<std::shared_ptr<DataType>> value_types = {
int16(),
int32(),
int64(),
};
for (size_t desired_depth = 1; desired_depth <= 3; desired_depth++) {
ARROW_SCOPED_TRACE("nesting-depth = ", desired_depth);
for (auto& type : value_types) {
ARROW_SCOPED_TRACE("inner-type = ", *type);
int value_width = type->byte_width();

std::vector<int> list_sizes; // stack of list sizes
auto pop = [&]() { // pop the list_sizes stack
DCHECK(!list_sizes.empty());
value_width /= list_sizes.back();
list_sizes.pop_back();
};
auto next = [&]() { // double the top of the stack
DCHECK(!list_sizes.empty());
value_width *= 2;
list_sizes.back() *= 2;
return value_width;
};
auto push_1s = [&]() { // fill the stack with 1s
while (list_sizes.size() < desired_depth) {
list_sizes.push_back(1);
}
};

// Loop invariants:
// value_width == product(list_sizes) * type->byte_width()
// value_width is a power-of-2 (1, 2, 4, 8, 16, 32)
push_1s();
do {
// for (auto x : list_sizes) printf("%d * ", x);
// printf("(%s) %d = %2d\n", type->name().c_str(), type->byte_width(),
// value_width);
this->CheckTakeOnNestedLists(type, list_sizes, /*length=*/5);
// Advance to the next test case
while (!list_sizes.empty()) {
if (next() <= 32) {
push_1s();
break;
}
pop();
}
} while (!list_sizes.empty());
}
}
}

class TestTakeKernelWithMap : public TestTakeKernelTyped<MapType> {};

TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) {
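To make the TakeFixedSizeListModuloNesting traversal concrete, this is the case list it visits for int16 at depth 2, by our reading of the loop (the commit does not spell it out): every pair of power-of-2 list sizes whose flattened width stays within the 32-byte fast-path limit.

// int16 (2 bytes), desired_depth = 2; {list_sizes} -> flattened width:
//   {1,1}->2B   {1,2}->4B   {1,4}->8B   {1,8}->16B  {1,16}->32B
//   {2,1}->4B   {2,2}->8B   {2,4}->16B  {2,8}->32B
//   {4,1}->8B   {4,2}->16B  {4,4}->32B
//   {8,1}->16B  {8,2}->32B
//   {16,1}->32B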
