Skip to content

Commit

Permalink
Add if-else tests using permute
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 13, 2024
1 parent 3e438e8 commit 0811b2b
Showing 1 changed file with 197 additions and 28 deletions.
225 changes: 197 additions & 28 deletions cpp/src/arrow/compute/kernels/vector_placement_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "arrow/compute/kernels/test_util.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/util/logging.h"

namespace arrow::compute {

Expand All @@ -34,6 +36,14 @@ static const std::vector<std::shared_ptr<DataType>> kSignedIntegerTypes = {
static const std::vector<std::shared_ptr<DataType>> kIntegerTypes = {
int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()};

static const std::vector<std::shared_ptr<DataType>> kNumericTypes = {
uint8(), int8(), uint16(), int16(), uint32(),
int32(), uint64(), int64(), float32(), float64()};

static const std::vector<std::shared_ptr<DataType>> kNumericAndBaseBinaryTypes = {
uint8(), int8(), uint16(), int16(), uint32(), int32(), uint64(),
int64(), float32(), float64(), binary(), utf8(), large_binary(), large_utf8()};

using SmallOutputTypes = ::testing::Types<UInt8Type, UInt16Type, Int8Type, Int16Type>;

} // namespace
Expand Down Expand Up @@ -490,33 +500,6 @@ void DoTestPermute(const std::shared_ptr<Array>& values,
DoTestPermuteForIndicesTypes(kIntegerTypes, values, indices, output_length, expected);
}

// void TestPermute(const std::shared_ptr<DataType>& values_type,
// const std::string& values_str, const std::string& indices_str,
// int64_t output_length, const std::string& expected_str) {
// auto values = ArrayFromJSON(values_type, values_str);
// auto expected = ArrayFromJSON(values_type, expected_str);
// for (const auto& indices_type : kIntegerTypes) {
// ARROW_SCOPED_TRACE("Indices type: " + indices_type->ToString());
// auto indices = ArrayFromJSON(indices_type, indices_str);
// {
// ARROW_SCOPED_TRACE("AAA");
// DoTestPermuteAAA(values, indices, output_length, expected);
// }
// {
// ARROW_SCOPED_TRACE("CAA");
// DoTestPermuteCACWithArrays(values, indices, output_length, expected);
// }
// {
// ARROW_SCOPED_TRACE("ACA");
// DoTestPermuteACCWithArrays(values, indices, output_length, expected);
// }
// {
// ARROW_SCOPED_TRACE("CCA");
// DoTestPermuteCCCWithArrays(values, indices, output_length, expected);
// }
// }
// }

} // namespace

TEST(Permute, Invalid) {
Expand Down Expand Up @@ -928,4 +911,190 @@ TYPED_TEST(TestPermuteString, Basic) {
}
}

}; // namespace arrow::compute
// ----------------------------------------------------------------------
// Test Permute using a hypothetical if-else special form.
// Also demonstrate how Permute can serve as a building block of implementing special
// forms.

namespace {

/// Execute an if-else expression using regular expression evaluation, as a reference.
Result<Datum> ExecuteIfElseByExpr(const Expression& cond, const Expression& if_true,
const Expression& if_false,
const std::shared_ptr<Schema>& schema,
const ExecBatch& input) {
auto if_else = call("if_else", {cond, if_true, if_false});
ARROW_ASSIGN_OR_RAISE(auto bound, if_else.Bind(*schema));
return ExecuteScalarExpression(bound, input);
}

/// Execute an if-else expression in a special form fashion, in which Permute is used as a
/// building block.
Result<Datum> ExecuteIfElseByPermute(const Expression& cond, const Expression& if_true,
const Expression& if_false,
const std::shared_ptr<Schema>& schema,
const ExecBatch& input) {
for (const auto& column : input.values) {
DCHECK(column.is_array());
}

ARROW_ASSIGN_OR_RAISE(auto input_rb, input.ToRecordBatch(schema));

// 1. Evaluate "cond", getting a boolean array as a mask to branches.
ARROW_ASSIGN_OR_RAISE(auto bound_cond, cond.Bind(*schema));
ARROW_ASSIGN_OR_RAISE(auto cond_datum, ExecuteScalarExpression(bound_cond, input));

// 2. Get indices of "true"s from the mask as the selection vector.
ARROW_ASSIGN_OR_RAISE(auto sel_if_true_datum,
CallFunction("indices_nonzero", {cond_datum}));
DCHECK(sel_if_true_datum.is_array());
auto sel_if_true_array = sel_if_true_datum.make_array();

// 3. Take the "true" rows from input.
ARROW_ASSIGN_OR_RAISE(auto if_true_input_datum,
CallFunction("take", {input_rb, sel_if_true_datum}));

// 4. Get indices of "false"es form the mas as the selection vector - by first inverting
// the mask and then getting the non-zero's indices.
ARROW_ASSIGN_OR_RAISE(auto invert_cond_datum, CallFunction("invert", {cond_datum}));
ARROW_ASSIGN_OR_RAISE(auto sel_if_false_datum,
CallFunction("indices_nonzero", {invert_cond_datum}));
DCHECK(sel_if_false_datum.is_array());
auto sel_if_false_array = sel_if_false_datum.make_array();

// 5. Take the "false" rows from input.
ARROW_ASSIGN_OR_RAISE(auto if_false_input_datum,
CallFunction("take", {input_rb, sel_if_false_datum}));

DCHECK_EQ(if_true_input_datum.kind(), Datum::RECORD_BATCH);
auto if_true_input_batch = ExecBatch(*if_true_input_datum.record_batch());

DCHECK_EQ(if_false_input_datum.kind(), Datum::RECORD_BATCH);
auto if_false_input_batch = ExecBatch(*if_false_input_datum.record_batch());

// 6. Evaluate "true" branch on the "true" rows.
ARROW_ASSIGN_OR_RAISE(auto bound_if_true, if_true.Bind(*schema));
ARROW_ASSIGN_OR_RAISE(auto if_true_result_datum,
ExecuteScalarExpression(bound_if_true, if_true_input_batch));
DCHECK(if_true_result_datum.is_array());
auto if_true_result_array = if_true_result_datum.make_array();

// 7. Evaluate "false" branch on the "false" rows.
ARROW_ASSIGN_OR_RAISE(auto bound_if_false, if_false.Bind(*schema));
ARROW_ASSIGN_OR_RAISE(auto if_false_result_datum,
ExecuteScalarExpression(bound_if_false, if_false_input_batch));
DCHECK(if_false_result_datum.is_array());
auto if_false_result_array = if_false_result_datum.make_array();

// 8. Combine the "true"/"false" results/selection vectors into chunked arrays.
auto result_ca = std::make_shared<ChunkedArray>(
ArrayVector{if_true_result_array, if_false_result_array});
auto sel_ca =
std::make_shared<ChunkedArray>(ArrayVector{sel_if_true_array, sel_if_false_array});

// 9. Finally, permute the "true"/"false" results to their original positions in the
// input (according to the selection vectors). Note we didn't handle the rows with nulls
// in the mask, because Permute will fill nulls for these rows and this is equal to the
// null handling policy of if-else, which is pretty nice.
return Permute(/*values=*/result_ca, /*indices=*/sel_ca,
/*output_length=*/input.length);
}

void DoTestIfElse(const Expression& cond, const Expression& if_true,
const Expression& if_false, const std::shared_ptr<Schema>& schema,
const ExecBatch& input) {
ASSERT_OK_AND_ASSIGN(Datum result_by_expr,
ExecuteIfElseByExpr(cond, if_true, if_false, schema, input));
ASSERT_TRUE(result_by_expr.is_array());
ASSERT_OK_AND_ASSIGN(Datum result_by_permute,
ExecuteIfElseByPermute(cond, if_true, if_false, schema, input));
ASSERT_TRUE(result_by_permute.is_chunked_array());
ASSERT_OK_AND_ASSIGN(auto result_by_permute_concat,
Concatenate(result_by_permute.chunked_array()->chunks()));

AssertDatumsEqual(result_by_expr, result_by_permute_concat);
}

void DoTestIfElse(const Expression& cond, const Expression& if_true,
const Expression& if_false, const std::shared_ptr<Schema>& schema,
const ExecBatch& input, const std::shared_ptr<Array>& expected) {
ASSERT_OK_AND_ASSIGN(Datum result,
ExecuteIfElseByPermute(cond, if_true, if_false, schema, input));
ASSERT_TRUE(result.is_chunked_array());
ASSERT_OK_AND_ASSIGN(auto result_concat, Concatenate(result.chunked_array()->chunks()));

AssertDatumsEqual(expected, result_concat);
}

} // namespace

TEST(Permute, IfElse) {
{
ARROW_SCOPED_TRACE("if (b != 0) then a / b else b");
auto cond = call("not_equal", {field_ref("b"), literal(0)});
auto if_true = call("divide", {field_ref("a"), field_ref("b")});
auto if_false = field_ref("b");
auto schema = arrow::schema({field("a", int32()), field("b", int32())});
{
auto rb = RecordBatchFromJSON(schema, R"([
[1, 1],
[2, 1],
[3, 0],
[4, 1],
[5, 1]
])");
auto input = ExecBatch(*rb);

ASSERT_RAISES_WITH_MESSAGE(
Invalid, "Invalid: divide by zero",
ExecuteIfElseByExpr(cond, if_true, if_false, schema, input));

auto expected = ArrayFromJSON(int32(), "[1, 2, 0, 4, 5]");
DoTestIfElse(cond, if_true, if_false, schema, input, expected);
}
}
{
ARROW_SCOPED_TRACE("if (a > b) then a else b");
auto cond = call("greater", {field_ref("a"), field_ref("b")});
auto if_true = field_ref("a");
auto if_false = field_ref("b");
constexpr int64_t length = 5;
for (const auto& type : kNumericTypes) {
ARROW_SCOPED_TRACE("Type " + type->ToString());
auto schema = arrow::schema({field("a", type), field("b", type)});
auto big = ArrayFromJSON(type, "[1, 2, 3, 4, 5]");
auto small = ArrayFromJSON(type, "[0, 1, 2, 3, 4]");
{
ARROW_SCOPED_TRACE("All true");
auto input =
ExecBatch(*RecordBatch::Make(schema, length, {/*a=*/big, /*b=*/small}));
DoTestIfElse(cond, if_true, if_false, schema, input);
}
{
ARROW_SCOPED_TRACE("All false");
auto input =
ExecBatch(*RecordBatch::Make(schema, length, {/*a=*/small, /*b=*/big}));
DoTestIfElse(cond, if_true, if_false, schema, input);
}
}
{
ARROW_SCOPED_TRACE("Random");
auto rng = random::RandomArrayGenerator(42);
constexpr int64_t length = 1024;
constexpr int repeat = 10;
for (const auto& type : kNumericAndBaseBinaryTypes) {
ARROW_SCOPED_TRACE("Type " + type->ToString());
auto schema = arrow::schema({field("a", type), field("b", type)});
for (int i = 0; i < repeat; ++i) {
auto a = rng.ArrayOf(type, length, /*null_probability=*/0.2);
auto b = rng.ArrayOf(type, length, /*null_probability=*/0.2);
auto input =
ExecBatch(*RecordBatch::Make(schema, length, {std::move(a), std::move(b)}));
DoTestIfElse(cond, if_true, if_false, schema, input);
}
}
}
}
}

} // namespace arrow::compute

0 comments on commit 0811b2b

Please sign in to comment.