diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc index 62b3a0400fabc..b9846516496a2 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc @@ -355,9 +355,9 @@ struct PrimitiveTakeImpl { const uint8_t* indices_is_valid = indices.buffers[0].data; auto indices_offset = indices.offset; + DCHECK_EQ(out_arr->offset, 0); auto* out = util::MutableFixedWidthValuesPointer(out_arr); auto out_is_valid = out_arr->buffers[0]->mutable_data(); - DCHECK_EQ(out_arr->offset, 0); // If either the values or indices have nulls, we preemptively zero out the // out validity bitmap so that we don't have to use ClearBit in each @@ -600,6 +600,7 @@ Status PrimitiveTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* RETURN_NOT_OK(util::internal::PreallocateFixedWidthArrayData( ctx, indices.length, /*source=*/values, /*allocate_validity=*/true, out_arr)); + DCHECK(util::MutableFixedWidthValuesPointer(out_arr)); switch (bit_width) { case 1: TakeIndexDispatch(values, indices, out_arr); diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index aca13e782a705..e7cef08064db8 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -23,6 +23,7 @@ #include #include +#include "arrow/array/builder_nested.h" #include "arrow/array/concatenate.h" #include "arrow/chunked_array.h" #include "arrow/compute/api.h" @@ -32,6 +33,7 @@ #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/testing/util.h" +#include "arrow/util/fixed_width_test_util.h" #include "arrow/util/logging.h" namespace arrow { @@ -1454,6 +1456,91 @@ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) { "[0, 1, 0]"); } +TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListVarWidth) { + std::string list_json = + R"([null, ["one", null, "three"], ["four", "five", "six"], ["seven", "eight", null]])"; + CheckTake(fixed_size_list(utf8(), 3), list_json, "[]", "[]"); + CheckTake( + fixed_size_list(utf8(), 3), list_json, "[3, 2, 1]", + R"([["seven", "eight", null], ["four", "five", "six"], ["one", null, "three"]])"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[null, 2, 0]", + R"([null, ["four", "five", "six"], null])"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[null, null]", "[null, null]"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[3, 0, 0,3]", + R"([["seven", "eight", null], null, null, ["seven", "eight", null]])"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[0, 1, 2, 3]", list_json); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]", + R"([ + ["four", "five", "six"], ["four", "five", "six"], + ["four", "five", "six"], ["four", "five", "six"], + ["four", "five", "six"], ["four", "five", "six"], + ["one", null, "three"] + ])"); +} + +TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListModuloNesting) { + using NLG = ::arrow::util::internal::NestedListGenerator; + + auto CheckTakeOnNestedLists = [](const std::shared_ptr& inner_type, + const std::vector& list_sizes) -> void { + constexpr int64_t kLength = 5; + // Create two equivalent lists: one as a FixedSizeList and another as a List. + ASSERT_OK_AND_ASSIGN(auto fsl_list, + NLG::NestedFSLArray(inner_type, list_sizes, kLength)); + ASSERT_OK_AND_ASSIGN(auto list, + NLG::NestedListArray(inner_type, list_sizes, kLength)); + + auto indices = ArrayFromJSON(int64(), "[1, 2, 4]"); + // Use the Take on ListType as the reference implementation. + ASSERT_OK_AND_ASSIGN(auto expected_list, Take(*list, *indices)); + ASSERT_OK_AND_ASSIGN(auto expected_fsl, Cast(*expected_list, fsl_list->type())); + DoCheckTake(fsl_list, indices, expected_fsl); + }; + + const std::vector> value_types = { + int8(), + int16(), + int32(), + int64(), + }; + for (auto& value_type : value_types) { + ARROW_SCOPED_TRACE("Nested fixed_size_list: inner-type = ", *value_type); + for (size_t desired_depth = 1; desired_depth <= 3; desired_depth++) { + ARROW_SCOPED_TRACE("desired nesting depth = ", desired_depth); + int value_width = value_type->byte_width(); + std::vector list_sizes; + auto push = [&](int list_size) { + value_width *= list_size; + list_sizes.push_back(list_size); + }; + auto pop = [&]() { + value_width /= list_sizes.back(); + list_sizes.pop_back(); + }; + auto next = [&]() { + value_width *= 2; + list_sizes.back() *= 2; + return value_width; + }; + // invariant: value_width == product(list_sizes) * value_type->byte_width() + // invariant: value_width is a power of 2 (1, 2, 4, 8, 16, 32) + do { + while (list_sizes.size() < desired_depth) { + push(1); + } + for (auto x : list_sizes) { + printf("%d ", x); + } + printf("= %2d\n", value_width); + CheckTakeOnNestedLists(value_type, list_sizes); + while (next() > 32 && !list_sizes.empty()) { + pop(); + } + } while (!list_sizes.empty()); + } + } +} + class TestTakeKernelWithMap : public TestTakeKernelTyped {}; TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) { diff --git a/cpp/src/arrow/util/fixed_width_internal.cc b/cpp/src/arrow/util/fixed_width_internal.cc index 1cd766e396102..8ca6cecc46eb6 100644 --- a/cpp/src/arrow/util/fixed_width_internal.cc +++ b/cpp/src/arrow/util/fixed_width_internal.cc @@ -88,6 +88,7 @@ Status PreallocateFixedWidthArrayData(::arrow::compute::KernelContext* ctx, } if (type->id() == Type::FIXED_SIZE_LIST) { auto& fsl_type = checked_cast(*type); + auto* values = &source.child_data[0]; auto& value_type = fsl_type.value_type(); if (is_fixed_width(value_type->id())) { if (value_type->id() == Type::BOOL) { @@ -98,20 +99,18 @@ Status PreallocateFixedWidthArrayData(::arrow::compute::KernelContext* ctx, return Status::NotImplemented( "PreallocateFixedWidthArrayData: DICTIONARY type allocation: ", *type); } - auto* values = &source.child_data[0]; if (values->MayHaveNulls()) { return Status::Invalid( "PreallocateFixedWidthArrayData: " "FixedSizeList may have null values in child array: ", fsl_type); } - auto allocated_values = std::make_shared(); - allocated_values->type = fsl_type.value_type(); - RETURN_NOT_OK(PreallocateFixedWidthArrayData( - ctx, length * fsl_type.list_size(), *values, - /*allocate_validity=*/false, allocated_values.get())); - out->child_data.resize(1); - out->child_data[0] = std::move(allocated_values); + auto* children = out->child_data.emplace_back(std::make_shared()).get(); + children->type = value_type; + return PreallocateFixedWidthArrayData(ctx, length * fsl_type.list_size(), + /*source=*/*values, + /*allocate_validity=*/false, + /*out=*/children); } return Status::OK(); } diff --git a/cpp/src/arrow/util/fixed_width_test_util.h b/cpp/src/arrow/util/fixed_width_test_util.h new file mode 100644 index 0000000000000..0f9a05a094ab2 --- /dev/null +++ b/cpp/src/arrow/util/fixed_width_test_util.h @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/array/builder_primitive.h" +#include "arrow/builder.h" +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" + +namespace arrow::util::internal { + +class NestedListGenerator { + public: + /// \brief Create a nested FixedSizeListType. + /// + /// \return `fixed_size_list(fixed_size_list(..., sizes[1]), sizes[0])` + static std::shared_ptr NestedFSLType( + const std::shared_ptr& inner_type, const std::vector& sizes) { + auto fsl_type = fixed_size_list(inner_type, sizes[0]); + for (size_t i = 1; i < sizes.size(); i++) { + fsl_type = fixed_size_list(std::move(fsl_type), sizes[i]); + } + return fsl_type; + } + + /// \brief Create a nested FixedListType. + /// + /// \return `list(list(...))` + static std::shared_ptr NestedListType( + const std::shared_ptr& inner_type, size_t depth) { + auto list_type = list(inner_type); + for (size_t i = 1; i < depth; i++) { + list_type = list(std::move(list_type)); + } + return list_type; + } + + private: + template + static Status AppendNumeric(ArrayBuilder* builder, int64_t* next_value) { + using NumericBuilder = ::arrow::NumericBuilder; + using value_type = typename NumericBuilder::value_type; + auto* numeric_builder = ::arrow::internal::checked_cast(builder); + auto cast_next_value = + static_cast(*next_value % std::numeric_limits::max()); + RETURN_NOT_OK(numeric_builder->Append(cast_next_value)); + *next_value += 1; + return Status::OK(); + } + + // Append([...[[*next_inner_value++, *next_inner_value++, ...]]...]) + static Status AppendNestedList(ArrayBuilder* nested_builder, const int* list_sizes, + int64_t* next_inner_value) { + using ::arrow::internal::checked_cast; + ArrayBuilder* builder = nested_builder; + auto type = builder->type(); + if (type->id() == Type::FIXED_SIZE_LIST || type->id() == Type::LIST) { + const int list_size = *list_sizes; + if (type->id() == Type::FIXED_SIZE_LIST) { + auto* fsl_builder = checked_cast(builder); + assert(list_size == checked_cast(*type).list_size()); + RETURN_NOT_OK(fsl_builder->Append()); + builder = fsl_builder->value_builder(); + } else { // type->id() == Type::LIST) + auto* list_builder = checked_cast(builder); + RETURN_NOT_OK(list_builder->Append(/*is_valid=*/true, list_size)); + builder = list_builder->value_builder(); + } + for (int i = 0; i < list_size; i++) { + RETURN_NOT_OK(AppendNestedList(builder, list_sizes + 1, next_inner_value)); + } + } else { + switch (type->id()) { + case Type::INT8: + RETURN_NOT_OK(AppendNumeric(builder, next_inner_value)); + break; + case Type::INT16: + RETURN_NOT_OK(AppendNumeric(builder, next_inner_value)); + break; + case Type::INT32: + RETURN_NOT_OK(AppendNumeric(builder, next_inner_value)); + break; + case Type::INT64: + RETURN_NOT_OK(AppendNumeric(builder, next_inner_value)); + break; + default: + return Status::NotImplemented("Unsupported type: ", *type); + } + } + return Status::OK(); + } + + static Result> NestedListArray( + ArrayBuilder* nested_builder, const std::vector& list_sizes, int64_t length) { + int64_t next_inner_value = 0; + for (int64_t i = 0; i < length; i++) { + RETURN_NOT_OK( + AppendNestedList(nested_builder, list_sizes.data(), &next_inner_value)); + } + return nested_builder->Finish(); + } + + public: + static Result> NestedFSLArray( + const std::shared_ptr& inner_type, const std::vector& list_sizes, + int64_t length) { + auto nested_type = NestedFSLType(inner_type, list_sizes); + ARROW_ASSIGN_OR_RAISE(auto builder, MakeBuilder(nested_type)); + return NestedListArray(builder.get(), list_sizes, length); + } + + static Result> NestedListArray( + const std::shared_ptr& inner_type, const std::vector& list_sizes, + int64_t length) { + auto nested_type = NestedListType(inner_type, list_sizes.size()); + ARROW_ASSIGN_OR_RAISE(auto builder, MakeBuilder(nested_type)); + return NestedListArray(builder.get(), list_sizes, length); + } +}; + +} // namespace arrow::util::internal