Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
felipecrv committed Apr 18, 2024
1 parent 8b08716 commit eebb793
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,9 @@ struct PrimitiveTakeImpl {
const uint8_t* indices_is_valid = indices.buffers[0].data;
auto indices_offset = indices.offset;

DCHECK_EQ(out_arr->offset, 0);
auto* out = util::MutableFixedWidthValuesPointer(out_arr);
auto out_is_valid = out_arr->buffers[0]->mutable_data();
DCHECK_EQ(out_arr->offset, 0);

// If either the values or indices have nulls, we preemptively zero out the
// out validity bitmap so that we don't have to use ClearBit in each
Expand Down Expand Up @@ -600,6 +600,7 @@ Status PrimitiveTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
RETURN_NOT_OK(util::internal::PreallocateFixedWidthArrayData(
ctx, indices.length, /*source=*/values,
/*allocate_validity=*/true, out_arr));
DCHECK(util::MutableFixedWidthValuesPointer(out_arr));
switch (bit_width) {
case 1:
TakeIndexDispatch<BooleanTakeImpl>(values, indices, out_arr);
Expand Down
87 changes: 87 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_selection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <utility>
#include <vector>

#include "arrow/array/builder_nested.h"
#include "arrow/array/concatenate.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/api.h"
Expand All @@ -32,6 +33,7 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/testing/util.h"
#include "arrow/util/fixed_width_test_util.h"
#include "arrow/util/logging.h"

namespace arrow {
Expand Down Expand Up @@ -1454,6 +1456,91 @@ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) {
"[0, 1, 0]");
}

TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListVarWidth) {
  // Take on fixed_size_list<utf8>[3]: the child values are variable-width
  // strings, and the input has nulls both at the list level and inside lists.
  const auto fsl_type = fixed_size_list(utf8(), 3);
  const std::string values_json =
      R"([null, ["one", null, "three"], ["four", "five", "six"], ["seven", "eight", null]])";

  // Empty selection.
  CheckTake(fsl_type, values_json, "[]", "[]");

  // Reversed selection of the three non-null lists.
  CheckTake(
      fsl_type, values_json, "[3, 2, 1]",
      R"([["seven", "eight", null], ["four", "five", "six"], ["one", null, "three"]])");

  // Null indices, and an index pointing at the null list.
  CheckTake(fsl_type, values_json, "[null, 2, 0]",
            R"([null, ["four", "five", "six"], null])");
  CheckTake(fsl_type, values_json, "[null, null]", "[null, null]");

  // Repeated indices.
  CheckTake(fsl_type, values_json, "[3, 0, 0,3]",
            R"([["seven", "eight", null], null, null, ["seven", "eight", null]])");

  // Identity selection returns the input unchanged.
  CheckTake(fsl_type, values_json, "[0, 1, 2, 3]", values_json);

  // Output longer than the input.
  CheckTake(fsl_type, values_json, "[2, 2, 2, 2, 2, 2, 1]",
            R"([
["four", "five", "six"], ["four", "five", "six"],
["four", "five", "six"], ["four", "five", "six"],
["four", "five", "six"], ["four", "five", "six"],
["one", null, "three"]
])");
}

TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListModuloNesting) {
  using NLG = ::arrow::util::internal::NestedListGenerator;

  // Largest value width (in bytes) of the outermost fixed_size_list that this
  // test enumerates.
  constexpr int kMaxValueWidth = 32;

  // Builds the same nested data twice (as FixedSizeList and as List), takes the
  // same indices from both and checks that the FixedSizeList result matches the
  // List result after casting it to the FixedSizeList type.
  auto CheckTakeOnNestedLists = [](const std::shared_ptr<DataType>& inner_type,
                                   const std::vector<int>& list_sizes) -> void {
    constexpr int64_t kLength = 5;
    // Create two equivalent lists: one as a FixedSizeList and another as a List.
    ASSERT_OK_AND_ASSIGN(auto fsl_list,
                         NLG::NestedFSLArray(inner_type, list_sizes, kLength));
    ASSERT_OK_AND_ASSIGN(auto list,
                         NLG::NestedListArray(inner_type, list_sizes, kLength));

    auto indices = ArrayFromJSON(int64(), "[1, 2, 4]");
    // Use the Take on ListType as the reference implementation.
    ASSERT_OK_AND_ASSIGN(auto expected_list, Take(*list, *indices));
    ASSERT_OK_AND_ASSIGN(auto expected_fsl, Cast(*expected_list, fsl_list->type()));
    DoCheckTake(fsl_list, indices, expected_fsl);
  };

  const std::vector<std::shared_ptr<DataType>> value_types = {
      int8(),
      int16(),
      int32(),
      int64(),
  };
  for (const auto& value_type : value_types) {
    ARROW_SCOPED_TRACE("Nested fixed_size_list: inner-type = ", *value_type);
    for (size_t desired_depth = 1; desired_depth <= 3; desired_depth++) {
      ARROW_SCOPED_TRACE("desired nesting depth = ", desired_depth);
      int value_width = value_type->byte_width();
      std::vector<int> list_sizes;
      // Add one more nesting level with the given list size.
      auto push = [&](int list_size) {
        value_width *= list_size;
        list_sizes.push_back(list_size);
      };
      // Drop the last nesting level.
      auto pop = [&]() {
        value_width /= list_sizes.back();
        list_sizes.pop_back();
      };
      // Double the size of the last nesting level and return the new width.
      // Precondition: list_sizes is not empty.
      auto next = [&]() {
        value_width *= 2;
        list_sizes.back() *= 2;
        return value_width;
      };
      // invariant: value_width == product(list_sizes) * value_type->byte_width()
      // invariant: value_width is a power of 2 (1, 2, 4, 8, 16, 32)
      do {
        while (list_sizes.size() < desired_depth) {
          push(1);
        }
        {
          ARROW_SCOPED_TRACE("value width in bytes = ", value_width);
          CheckTakeOnNestedLists(value_type, list_sizes);
        }
        // Advance to the next combination of list sizes. The emptiness check
        // must come first: next() touches list_sizes.back(), which is
        // undefined once pop() has removed the last level.
        while (!list_sizes.empty() && next() > kMaxValueWidth) {
          pop();
        }
      } while (!list_sizes.empty());
    }
  }
}

class TestTakeKernelWithMap : public TestTakeKernelTyped<MapType> {};

TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) {
Expand Down
15 changes: 7 additions & 8 deletions cpp/src/arrow/util/fixed_width_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Status PreallocateFixedWidthArrayData(::arrow::compute::KernelContext* ctx,
}
if (type->id() == Type::FIXED_SIZE_LIST) {
auto& fsl_type = checked_cast<const FixedSizeListType&>(*type);
auto* values = &source.child_data[0];
auto& value_type = fsl_type.value_type();
if (is_fixed_width(value_type->id())) {
if (value_type->id() == Type::BOOL) {
Expand All @@ -98,20 +99,18 @@ Status PreallocateFixedWidthArrayData(::arrow::compute::KernelContext* ctx,
return Status::NotImplemented(
"PreallocateFixedWidthArrayData: DICTIONARY type allocation: ", *type);
}
auto* values = &source.child_data[0];
if (values->MayHaveNulls()) {
return Status::Invalid(
"PreallocateFixedWidthArrayData: "
"FixedSizeList may have null values in child array: ",
fsl_type);
}
auto allocated_values = std::make_shared<ArrayData>();
allocated_values->type = fsl_type.value_type();
RETURN_NOT_OK(PreallocateFixedWidthArrayData(
ctx, length * fsl_type.list_size(), *values,
/*allocate_validity=*/false, allocated_values.get()));
out->child_data.resize(1);
out->child_data[0] = std::move(allocated_values);
auto* children = out->child_data.emplace_back(std::make_shared<ArrayData>()).get();
children->type = value_type;
return PreallocateFixedWidthArrayData(ctx, length * fsl_type.list_size(),
/*source=*/*values,
/*allocate_validity=*/false,
/*out=*/children);
}
return Status::OK();
}
Expand Down
141 changes: 141 additions & 0 deletions cpp/src/arrow/util/fixed_width_test_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>

#include "arrow/array/builder_primitive.h"
#include "arrow/builder.h"
#include "arrow/type.h"
#include "arrow/util/checked_cast.h"

namespace arrow::util::internal {

/// \brief Test helper that generates nested list arrays (FixedSizeList or List)
/// filled with consecutive integer values.
///
/// In every function taking list sizes, the sizes are ordered from the
/// outermost nesting level (index 0) to the innermost one.
class NestedListGenerator {
 public:
  /// \brief Create a nested FixedSizeListType.
  ///
  /// \param inner_type type of the innermost values
  /// \param sizes list size of each nesting level, outermost first
  /// \return `fixed_size_list(fixed_size_list(..., sizes[1]), sizes[0])`, or
  ///         `inner_type` itself if `sizes` is empty
  static std::shared_ptr<DataType> NestedFSLType(
      const std::shared_ptr<DataType>& inner_type, const std::vector<int>& sizes) {
    auto type = inner_type;
    // Wrap from the innermost level outwards so that sizes[0] ends up as the
    // list size of the outermost type. This is the order in which
    // AppendNestedList() consumes (and asserts on) the list sizes.
    for (auto it = sizes.rbegin(); it != sizes.rend(); ++it) {
      type = fixed_size_list(std::move(type), *it);
    }
    return type;
  }

  /// \brief Create a nested ListType.
  ///
  /// \param inner_type type of the innermost values
  /// \param depth number of nesting levels; at least one `list(...)` is always
  ///        applied, so 0 and 1 both produce `list(inner_type)`
  /// \return `list(list(...))`
  static std::shared_ptr<DataType> NestedListType(
      const std::shared_ptr<DataType>& inner_type, size_t depth) {
    auto list_type = list(inner_type);
    for (size_t i = 1; i < depth; i++) {
      list_type = list(std::move(list_type));
    }
    return list_type;
  }

 private:
  /// \brief Append `*next_value` (reduced modulo the maximum of the builder's
  /// value type so that it always fits) to a numeric builder and increment it.
  template <typename ArrowType>
  static Status AppendNumeric(ArrayBuilder* builder, int64_t* next_value) {
    using NumericBuilder = ::arrow::NumericBuilder<ArrowType>;
    using value_type = typename NumericBuilder::value_type;
    auto* numeric_builder = ::arrow::internal::checked_cast<NumericBuilder*>(builder);
    auto cast_next_value =
        static_cast<value_type>(*next_value % std::numeric_limits<value_type>::max());
    RETURN_NOT_OK(numeric_builder->Append(cast_next_value));
    *next_value += 1;
    return Status::OK();
  }

  /// \brief Append one (possibly nested) list value to the builder.
  ///
  /// Append([...[[*next_inner_value++, *next_inner_value++, ...]]...])
  ///
  /// \param nested_builder builder of a FIXED_SIZE_LIST, LIST or supported
  ///        integer (INT8/16/32/64) type
  /// \param list_sizes pointer to the list size of the current nesting level;
  ///        the sizes of the deeper levels follow it in memory
  /// \param next_inner_value counter providing the innermost values
  /// \return NotImplemented if the innermost type is not a supported integer
  static Status AppendNestedList(ArrayBuilder* nested_builder, const int* list_sizes,
                                 int64_t* next_inner_value) {
    using ::arrow::internal::checked_cast;
    ArrayBuilder* builder = nested_builder;
    auto type = builder->type();
    if (type->id() == Type::FIXED_SIZE_LIST || type->id() == Type::LIST) {
      const int list_size = *list_sizes;
      if (type->id() == Type::FIXED_SIZE_LIST) {
        auto* fsl_builder = checked_cast<FixedSizeListBuilder*>(builder);
        // The size requested for this level must agree with the type.
        assert(list_size == checked_cast<FixedSizeListType&>(*type).list_size());
        RETURN_NOT_OK(fsl_builder->Append());
        builder = fsl_builder->value_builder();
      } else {  // type->id() == Type::LIST
        auto* list_builder = checked_cast<ListBuilder*>(builder);
        RETURN_NOT_OK(list_builder->Append(/*is_valid=*/true, list_size));
        builder = list_builder->value_builder();
      }
      // Fill the list that was just opened with `list_size` child values,
      // each built from the sizes of the remaining (deeper) levels.
      for (int i = 0; i < list_size; i++) {
        RETURN_NOT_OK(AppendNestedList(builder, list_sizes + 1, next_inner_value));
      }
    } else {
      switch (type->id()) {
        case Type::INT8:
          RETURN_NOT_OK(AppendNumeric<Int8Type>(builder, next_inner_value));
          break;
        case Type::INT16:
          RETURN_NOT_OK(AppendNumeric<Int16Type>(builder, next_inner_value));
          break;
        case Type::INT32:
          RETURN_NOT_OK(AppendNumeric<Int32Type>(builder, next_inner_value));
          break;
        case Type::INT64:
          RETURN_NOT_OK(AppendNumeric<Int64Type>(builder, next_inner_value));
          break;
        default:
          return Status::NotImplemented("Unsupported type: ", *type);
      }
    }
    return Status::OK();
  }

  /// \brief Append `length` nested list values to the builder and finish it.
  ///
  /// The innermost values are consecutive integers starting at 0.
  static Result<std::shared_ptr<Array>> NestedListArray(
      ArrayBuilder* nested_builder, const std::vector<int>& list_sizes, int64_t length) {
    int64_t next_inner_value = 0;
    for (int64_t i = 0; i < length; i++) {
      RETURN_NOT_OK(
          AppendNestedList(nested_builder, list_sizes.data(), &next_inner_value));
    }
    return nested_builder->Finish();
  }

 public:
  /// \brief Generate a nested FixedSizeList array with `length` elements.
  ///
  /// \param inner_type type of the innermost values (INT8/16/32/64)
  /// \param list_sizes list size of each nesting level, outermost first
  /// \param length number of elements of the outermost array
  static Result<std::shared_ptr<Array>> NestedFSLArray(
      const std::shared_ptr<DataType>& inner_type, const std::vector<int>& list_sizes,
      int64_t length) {
    auto nested_type = NestedFSLType(inner_type, list_sizes);
    ARROW_ASSIGN_OR_RAISE(auto builder, MakeBuilder(nested_type));
    return NestedListArray(builder.get(), list_sizes, length);
  }

  /// \brief Generate a nested List array with `length` elements in which every
  /// list at nesting level `i` has exactly `list_sizes[i]` values.
  ///
  /// \param inner_type type of the innermost values (INT8/16/32/64)
  /// \param list_sizes list size of each nesting level, outermost first
  /// \param length number of elements of the outermost array
  static Result<std::shared_ptr<Array>> NestedListArray(
      const std::shared_ptr<DataType>& inner_type, const std::vector<int>& list_sizes,
      int64_t length) {
    auto nested_type = NestedListType(inner_type, list_sizes.size());
    ARROW_ASSIGN_OR_RAISE(auto builder, MakeBuilder(nested_type));
    return NestedListArray(builder.get(), list_sizes, length);
  }
};

} // namespace arrow::util::internal

0 comments on commit eebb793

Please sign in to comment.