diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc index 00310d6bec908..63e41c70a7f09 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc @@ -37,6 +37,7 @@ #include "arrow/util/bit_block_counter.h" #include "arrow/util/bit_run_reader.h" #include "arrow/util/bit_util.h" +#include "arrow/util/fixed_width_internal.h" #include "arrow/util/int_util.h" #include "arrow/util/logging.h" #include "arrow/util/ree_util.h" @@ -950,6 +951,32 @@ Status LargeListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* } Status FSLTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + const ArraySpan& values = batch[0].array; + + // If a FixedSizeList wraps a fixed-width type we can, in some cases, use + // PrimitiveTakeExec for a fixed-size list array. + if (util::IsFixedWidthModuloNesting(values, + /*force_null_count=*/true, + /*extra_predicate=*/[](auto& fixed_width_type) { + // DICTIONARY is fixed-width but not supported by + // PrimitiveTakeExec. + return fixed_width_type.id() != Type::DICTIONARY; + })) { + const auto byte_width = util::FixedWidthInBytes(*values.type); + // Additionally, PrimitiveTakeExec is only implemented for specific byte widths. + switch (byte_width) { + case 1: + case 2: + case 4: + case 8: + case 16: + case 32: + return PrimitiveTakeExec(ctx, batch, out); + default: + break; // fallback to TakeExec + } + } + return TakeExec(ctx, batch, out); } diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc index 4b58681817b08..e907bda54d4ff 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc @@ -324,7 +324,7 @@ namespace { using TakeState = OptionsWrapper; // ---------------------------------------------------------------------- -// Implement optimized take for primitive types from boolean to 1/2/4/8-byte +// Implement optimized take for primitive types from boolean to 1/2/4/8/16/32-byte // C-type based types. Use common implementation for every byte width and only // generate code for unsigned integer indices, since after boundschecking to // check for negative numbers in the indices we can safely reinterpret_cast @@ -334,16 +334,20 @@ using TakeState = OptionsWrapper; /// use the logical Arrow type but rather the physical C type. This way we /// only generate one take function for each byte width. /// -/// This function assumes that the indices have been boundschecked. +/// Also note that this function can also handle fixed-size-list arrays if +/// they fit the criteria described in fixed_width_internal.h, so use the +/// function defined in that file to access values and destination pointers +/// and DO NOT ASSUME `values.type()` is a primitive type. +/// +/// \pre the indices have been boundschecked template struct PrimitiveTakeImpl { static constexpr int kValueWidth = ValueWidthConstant::value; static void Exec(const ArraySpan& values, const ArraySpan& indices, ArrayData* out_arr) { - DCHECK_EQ(values.type->byte_width(), kValueWidth); - const auto* values_data = - values.GetValues(1, 0) + kValueWidth * values.offset; + DCHECK_EQ(util::FixedWidthInBytes(*values.type), kValueWidth); + const auto* values_data = util::OffsetPointerOfFixedWidthValues(values); const uint8_t* values_is_valid = values.buffers[0].data; auto values_offset = values.offset; @@ -351,16 +355,15 @@ struct PrimitiveTakeImpl { const uint8_t* indices_is_valid = indices.buffers[0].data; auto indices_offset = indices.offset; - auto out = out_arr->GetMutableValues(1, 0) + kValueWidth * out_arr->offset; + DCHECK_EQ(out_arr->offset, 0); + auto* out = util::MutableFixedWidthValuesPointer(out_arr); auto out_is_valid = out_arr->buffers[0]->mutable_data(); - auto out_offset = out_arr->offset; - DCHECK_EQ(out_offset, 0); // If either the values or indices have nulls, we preemptively zero out the // out validity bitmap so that we don't have to use ClearBit in each // iteration for nulls. if (values.null_count != 0 || indices.null_count != 0) { - bit_util::SetBitsTo(out_is_valid, out_offset, indices.length, false); + bit_util::SetBitsTo(out_is_valid, 0, indices.length, false); } auto WriteValue = [&](int64_t position) { @@ -387,7 +390,7 @@ struct PrimitiveTakeImpl { valid_count += block.popcount; if (block.popcount == block.length) { // Fastest path: neither values nor index nulls - bit_util::SetBitsTo(out_is_valid, out_offset + position, block.length, true); + bit_util::SetBitsTo(out_is_valid, position, block.length, true); for (int64_t i = 0; i < block.length; ++i) { WriteValue(position); ++position; @@ -397,7 +400,7 @@ struct PrimitiveTakeImpl { for (int64_t i = 0; i < block.length; ++i) { if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { // index is not null - bit_util::SetBit(out_is_valid, out_offset + position); + bit_util::SetBit(out_is_valid, position); WriteValue(position); } else { WriteZero(position); @@ -417,7 +420,7 @@ struct PrimitiveTakeImpl { values_offset + indices_data[position])) { // value is not null WriteValue(position); - bit_util::SetBit(out_is_valid, out_offset + position); + bit_util::SetBit(out_is_valid, position); ++valid_count; } else { WriteZero(position); @@ -434,7 +437,7 @@ struct PrimitiveTakeImpl { values_offset + indices_data[position])) { // index is not null && value is not null WriteValue(position); - bit_util::SetBit(out_is_valid, out_offset + position); + bit_util::SetBit(out_is_valid, position); ++valid_count; } else { WriteZero(position); @@ -585,7 +588,10 @@ Status PrimitiveTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* ArrayData* out_arr = out->array_data().get(); - const int bit_width = values.type->bit_width(); + DCHECK(util::IsFixedWidthModuloNesting( + values, /*force_null_count=*/false, + [](const auto& type) { return type.id() != Type::DICTIONARY; })); + const int64_t bit_width = util::FixedWidthInBits(*values.type); // TODO: When neither values nor indices contain nulls, we can skip // allocating the validity bitmap altogether and save time and space. A diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index aca13e782a705..b7dde4b02c957 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -23,6 +23,7 @@ #include #include +#include "arrow/array/builder_nested.h" #include "arrow/array/concatenate.h" #include "arrow/chunked_array.h" #include "arrow/compute/api.h" @@ -32,6 +33,7 @@ #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/testing/util.h" +#include "arrow/util/fixed_width_test_util.h" #include "arrow/util/logging.h" namespace arrow { @@ -1432,7 +1434,25 @@ TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) { CheckTake(large_list(int32()), list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]"); } -class TestTakeKernelWithFixedSizeList : public TestTakeKernelTyped {}; +class TestTakeKernelWithFixedSizeList : public TestTakeKernelTyped { + protected: + void CheckTakeOnNestedLists(const std::shared_ptr& inner_type, + const std::vector& list_sizes, int64_t length) { + using NLG = ::arrow::util::internal::NestedListGenerator; + // Create two equivalent lists: one as a FixedSizeList and another as a List. + ASSERT_OK_AND_ASSIGN(auto fsl_list, + NLG::NestedFSLArray(inner_type, list_sizes, length)); + ASSERT_OK_AND_ASSIGN(auto list, NLG::NestedListArray(inner_type, list_sizes, length)); + + ARROW_SCOPED_TRACE("CheckTakeOnNestedLists of type `", *fsl_list->type(), "`"); + + auto indices = ArrayFromJSON(int64(), "[1, 2, 4]"); + // Use the Take on ListType as the reference implementation. + ASSERT_OK_AND_ASSIGN(auto expected_list, Take(*list, *indices)); + ASSERT_OK_AND_ASSIGN(auto expected_fsl, Cast(*expected_list, fsl_list->type())); + DoCheckTake(fsl_list, indices, expected_fsl); + } +}; TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) { std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]"; @@ -1454,6 +1474,80 @@ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) { "[0, 1, 0]"); } +TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListVarWidth) { + std::string list_json = + R"([null, ["one", null, "three"], ["four", "five", "six"], ["seven", "eight", null]])"; + CheckTake(fixed_size_list(utf8(), 3), list_json, "[]", "[]"); + CheckTake( + fixed_size_list(utf8(), 3), list_json, "[3, 2, 1]", + R"([["seven", "eight", null], ["four", "five", "six"], ["one", null, "three"]])"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[null, 2, 0]", + R"([null, ["four", "five", "six"], null])"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[null, null]", "[null, null]"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[3, 0, 0,3]", + R"([["seven", "eight", null], null, null, ["seven", "eight", null]])"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[0, 1, 2, 3]", list_json); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]", + R"([ + ["four", "five", "six"], ["four", "five", "six"], + ["four", "five", "six"], ["four", "five", "six"], + ["four", "five", "six"], ["four", "five", "six"], + ["one", null, "three"] + ])"); +} + +TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListModuloNesting) { + const std::vector> value_types = { + int16(), + int32(), + int64(), + }; + for (size_t desired_depth = 1; desired_depth <= 3; desired_depth++) { + ARROW_SCOPED_TRACE("nesting-depth = ", desired_depth); + for (auto& type : value_types) { + ARROW_SCOPED_TRACE("inner-type = ", *type); + int value_width = type->byte_width(); + + std::vector list_sizes; // stack of list sizes + auto pop = [&]() { // pop the list_sizes stack + DCHECK(!list_sizes.empty()); + value_width /= list_sizes.back(); + list_sizes.pop_back(); + }; + auto next = [&]() { // double the top of the stack + DCHECK(!list_sizes.empty()); + value_width *= 2; + list_sizes.back() *= 2; + return value_width; + }; + auto push_1s = [&]() { // fill the stack with 1s + while (list_sizes.size() < desired_depth) { + list_sizes.push_back(1); + } + }; + + // Loop invariants: + // value_width == product(list_sizes) * type->byte_width() + // value_width is a power-of-2 (1, 2, 4, 8, 16, 32) + push_1s(); + do { + // for (auto x : list_sizes) printf("%d * ", x); + // printf("(%s) %d = %2d\n", type->name().c_str(), type->byte_width(), + // value_width); + this->CheckTakeOnNestedLists(type, list_sizes, /*length=*/5); + // Advance to the next test case + while (!list_sizes.empty()) { + if (next() <= 32) { + push_1s(); + break; + } + pop(); + } + } while (!list_sizes.empty()); + } + } +} + class TestTakeKernelWithMap : public TestTakeKernelTyped {}; TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) { diff --git a/cpp/src/arrow/util/fixed_width_test_util.h b/cpp/src/arrow/util/fixed_width_test_util.h new file mode 100644 index 0000000000000..74e653ace3bcd --- /dev/null +++ b/cpp/src/arrow/util/fixed_width_test_util.h @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/array/builder_primitive.h" +#include "arrow/builder.h" +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" + +namespace arrow::util::internal { + +class NestedListGenerator { + public: + /// \brief Create a nested FixedSizeListType. + /// + /// \return `fixed_size_list(fixed_size_list(..., sizes[1]), sizes[0])` + static std::shared_ptr NestedFSLType( + const std::shared_ptr& inner_type, const std::vector& sizes) { + auto type = inner_type; + for (auto it = sizes.rbegin(); it != sizes.rend(); it++) { + type = fixed_size_list(std::move(type), *it); + } + return type; + } + + /// \brief Create a nested FixedListType. + /// + /// \return `list(list(...))` + static std::shared_ptr NestedListType( + const std::shared_ptr& inner_type, size_t depth) { + auto list_type = list(inner_type); + for (size_t i = 1; i < depth; i++) { + list_type = list(std::move(list_type)); + } + return list_type; + } + + private: + template + static Status AppendNumeric(ArrayBuilder* builder, int64_t* next_value) { + using NumericBuilder = ::arrow::NumericBuilder; + using value_type = typename NumericBuilder::value_type; + auto* numeric_builder = ::arrow::internal::checked_cast(builder); + auto cast_next_value = + static_cast(*next_value % std::numeric_limits::max()); + RETURN_NOT_OK(numeric_builder->Append(cast_next_value)); + *next_value += 1; + return Status::OK(); + } + + // Append([...[[*next_inner_value++, *next_inner_value++, ...]]...]) + static Status AppendNestedList(ArrayBuilder* nested_builder, const int* list_sizes, + int64_t* next_inner_value) { + using ::arrow::internal::checked_cast; + ArrayBuilder* builder = nested_builder; + auto type = builder->type(); + if (type->id() == Type::FIXED_SIZE_LIST || type->id() == Type::LIST) { + const int list_size = *list_sizes; + if (type->id() == Type::FIXED_SIZE_LIST) { + auto* fsl_builder = checked_cast(builder); + assert(list_size == checked_cast(*type).list_size()); + RETURN_NOT_OK(fsl_builder->Append()); + builder = fsl_builder->value_builder(); + } else { // type->id() == Type::LIST) + auto* list_builder = checked_cast(builder); + RETURN_NOT_OK(list_builder->Append(/*is_valid=*/true, list_size)); + builder = list_builder->value_builder(); + } + list_sizes++; + for (int i = 0; i < list_size; i++) { + RETURN_NOT_OK(AppendNestedList(builder, list_sizes, next_inner_value)); + } + } else { + switch (type->id()) { + case Type::INT8: + RETURN_NOT_OK(AppendNumeric(builder, next_inner_value)); + break; + case Type::INT16: + RETURN_NOT_OK(AppendNumeric(builder, next_inner_value)); + break; + case Type::INT32: + RETURN_NOT_OK(AppendNumeric(builder, next_inner_value)); + break; + case Type::INT64: + RETURN_NOT_OK(AppendNumeric(builder, next_inner_value)); + break; + default: + return Status::NotImplemented("Unsupported type: ", *type); + } + } + return Status::OK(); + } + + static Result> NestedListArray( + ArrayBuilder* nested_builder, const std::vector& list_sizes, int64_t length) { + int64_t next_inner_value = 0; + for (int64_t i = 0; i < length; i++) { + RETURN_NOT_OK( + AppendNestedList(nested_builder, list_sizes.data(), &next_inner_value)); + } + return nested_builder->Finish(); + } + + public: + static Result> NestedFSLArray( + const std::shared_ptr& inner_type, const std::vector& list_sizes, + int64_t length) { + auto nested_type = NestedFSLType(inner_type, list_sizes); + ARROW_ASSIGN_OR_RAISE(auto builder, MakeBuilder(nested_type)); + return NestedListArray(builder.get(), list_sizes, length); + } + + static Result> NestedListArray( + const std::shared_ptr& inner_type, const std::vector& list_sizes, + int64_t length) { + auto nested_type = NestedListType(inner_type, list_sizes.size()); + ARROW_ASSIGN_OR_RAISE(auto builder, MakeBuilder(nested_type)); + return NestedListArray(builder.get(), list_sizes, length); + } +}; + +} // namespace arrow::util::internal