Skip to content

Commit

Permalink
[refactor](function) improve compoundPred optimization work with chil…
Browse files Browse the repository at this point in the history
…dren is nullable
  • Loading branch information
zhangstar333 committed Oct 31, 2023
1 parent f883d1a commit d3d1a36
Showing 1 changed file with 165 additions and 66 deletions.
231 changes: 165 additions & 66 deletions be/src/vec/exprs/vcompound_pred.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "vec/columns/column.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/data_types/data_type_number.h"
#include "vec/exprs/vectorized_fn_call.h"
#include "vec/exprs/vexpr.h"

Expand Down Expand Up @@ -52,86 +53,190 @@ class VCompoundPred : public VectorizedFnCall {
const std::string& expr_name() const override { return _expr_name; }

Status execute(VExprContext* context, Block* block, int* result_column_id) override {
if (children().size() == 1 || !_all_child_is_compound_and_not_const() ||
_children[0]->is_nullable() || _children[1]->is_nullable()) {
// TODO:
// When the child is nullable, make the optimization also take effect, and the processing of this piece may be more complicated
// https://dev.mysql.com/doc/refman/8.0/en/logical-operators.html
if (children().size() == 1 || !_all_child_is_compound_and_not_const()) {
return VectorizedFnCall::execute(context, block, result_column_id);
}

int lhs_id = -1;
int rhs_id = -1;
RETURN_IF_ERROR(_children[0]->execute(context, block, &lhs_id));
ColumnPtr lhs_column =
block->get_by_position(lhs_id).column->convert_to_full_column_if_const();

//should deal with const column ?
ColumnPtr lhs_column = block->get_by_position(lhs_id).column;
size_t size = lhs_column->size();
uint8* __restrict data = _get_raw_data(lhs_column);
int filted = simd::count_zero_num((int8_t*)data, size);
bool full = filted == 0;
bool empty = filted == size;
bool lhs_is_nullable = lhs_column->is_nullable();
auto [lhs_data_column, lhs_null_map] =
_get_raw_data_and_null_map(lhs_column, lhs_is_nullable);
int filted = simd::count_zero_num((int8_t*)lhs_data_column, size);
bool lhs_all_true = (filted == 0);
bool lhs_all_false = (filted == size);

bool lhs_all_is_not_null = false;
if (lhs_is_nullable) {
filted = simd::count_zero_num((int8_t*)lhs_null_map, size);
lhs_all_is_not_null = (filted == size);
}

ColumnPtr rhs_column = nullptr;
uint8* __restrict data_rhs = nullptr;
bool full_rhs = false;
bool empty_rhs = false;
uint8* __restrict rhs_data_column = nullptr;
uint8* __restrict rhs_null_map = nullptr;
bool rhs_is_nullable = false;
bool rhs_all_true = false;
bool rhs_all_false = false;
bool rhs_all_is_not_null = false;

auto get_rhs_colum = [&]() {
if (rhs_id == -1) {
RETURN_IF_ERROR(_children[1]->execute(context, block, &rhs_id));
rhs_column =
block->get_by_position(rhs_id).column->convert_to_full_column_if_const();
data_rhs = _get_raw_data(rhs_column);
int filted = simd::count_zero_num((int8_t*)data_rhs, size);
full_rhs = filted == 0;
empty_rhs = filted == size;
rhs_column = block->get_by_position(rhs_id).column;
rhs_is_nullable = rhs_column->is_nullable();
auto rhs_nullable_column = _get_raw_data_and_null_map(rhs_column, rhs_is_nullable);
rhs_data_column = rhs_nullable_column.first;
rhs_null_map = rhs_nullable_column.second;
int filted = simd::count_zero_num((int8_t*)rhs_data_column, size);
rhs_all_true = (filted == 0);
rhs_all_false = (filted == size);
if (rhs_is_nullable) {
filted = simd::count_zero_num((int8_t*)rhs_null_map, size);
rhs_all_is_not_null = (filted == size);
}
}
return Status::OK();
};

// false and NULL ----> 0
// true and NULL ----> NULL
if (_op == TExprOpcode::COMPOUND_AND) {
if (empty) {
// empty and any = empty, return lhs
//not null column: all data is false
//nullable column: null map all is not null
if ((lhs_all_false && !lhs_is_nullable) || (lhs_all_false && lhs_all_is_not_null)) {
// false and any = false, return lhs
*result_column_id = lhs_id;
} else {
RETURN_IF_ERROR(get_rhs_colum());

if (full) {
// full and any = any, return rhs
if ((lhs_all_true && !lhs_is_nullable) || //not null column
(lhs_all_true && lhs_all_is_not_null)) { //nullable column
// true and any = any, return rhs
*result_column_id = rhs_id;
} else if (empty_rhs) {
// any and empty = empty, return rhs
} else if ((rhs_all_false && !rhs_is_nullable) ||
(rhs_all_false && rhs_all_is_not_null)) {
// any and false = false, return rhs
*result_column_id = rhs_id;
} else if (full_rhs) {
// any and full = any, return lhs
} else if ((rhs_all_true && !rhs_is_nullable) ||
(rhs_all_true && rhs_all_is_not_null)) {
// any and true = any, return lhs
*result_column_id = lhs_id;
} else {
*result_column_id = lhs_id;
for (size_t i = 0; i < size; i++) {
data[i] &= data_rhs[i];
bool res_nullable = (lhs_is_nullable || rhs_is_nullable);
if (!res_nullable) {
*result_column_id = lhs_id;
for (size_t i = 0; i < size; i++) {
lhs_data_column[i] &= rhs_data_column[i];
}
} else {
auto col_res = ColumnUInt8::create(size);
auto col_nulls = ColumnUInt8::create(size);
auto* __restrict res_datas =
assert_cast<ColumnUInt8*>(col_res)->get_data().data();
auto* __restrict res_nulls =
assert_cast<ColumnUInt8*>(col_nulls)->get_data().data();
ColumnPtr temp_null_map = nullptr;
if ((lhs_is_nullable && !rhs_is_nullable) ||
(!lhs_is_nullable || rhs_is_nullable)) { // one of children is nullable
if (lhs_null_map == nullptr) {
temp_null_map = ColumnUInt8::create(size, 0);
lhs_null_map = assert_cast<ColumnUInt8*>(
temp_null_map->assume_mutable().get())
->get_data()
.data();
}
if (rhs_null_map == nullptr) {
temp_null_map = ColumnUInt8::create(size, 0);
rhs_null_map = assert_cast<ColumnUInt8*>(
temp_null_map->assume_mutable().get())
->get_data()
.data();
}
}

for (size_t i = 0; i < size; ++i) {
res_nulls[i] =
(lhs_null_map[i] & rhs_null_map[i]) |
(rhs_null_map[i] & (lhs_null_map[i] ^ lhs_data_column[i])) |
(lhs_null_map[i] & (rhs_null_map[i] ^ rhs_data_column[i]));
res_datas[i] = lhs_data_column[i] & rhs_data_column[i];
}
auto result_column =
ColumnNullable::create(std::move(col_res), std::move(col_nulls));
auto result_type = make_nullable(std::make_shared<DataTypeUInt8>());
*result_column_id = block->columns();
block->insert({std::move(result_column), result_type, _expr_name});
}
}
}
} else if (_op == TExprOpcode::COMPOUND_OR) {
if (full) {
// full or any = full, return lhs
// true or NULL ----> 1
// false or NULL ----> NULL
if ((lhs_all_true && !lhs_is_nullable) || (lhs_all_true && lhs_all_is_not_null)) {
// true or any = true, return lhs
*result_column_id = lhs_id;
} else {
RETURN_IF_ERROR(get_rhs_colum());
if (empty) {
// empty or any = any, return rhs
if ((lhs_all_false && !lhs_is_nullable) || (lhs_all_false && lhs_all_is_not_null)) {
// false or any = any, return rhs
*result_column_id = rhs_id;
} else if (full_rhs) {
// any or full = full, return rhs
} else if ((rhs_all_true && !rhs_is_nullable) ||
(rhs_all_true && rhs_all_is_not_null)) {
// any or true = true, return rhs
*result_column_id = rhs_id;
} else if (empty_rhs) {
// any or empty = any, return lhs
} else if ((rhs_all_false && !rhs_is_nullable) ||
(rhs_all_false && rhs_all_is_not_null)) {
// any or false = any, return lhs
*result_column_id = lhs_id;
} else {
*result_column_id = lhs_id;
for (size_t i = 0; i < size; i++) {
data[i] |= data_rhs[i];
bool res_nullable = (lhs_is_nullable || rhs_is_nullable);
if (!res_nullable) {
*result_column_id = lhs_id;
for (size_t i = 0; i < size; i++) {
lhs_data_column[i] |= rhs_data_column[i];
}
} else {
auto col_res = ColumnUInt8::create(size);
auto col_nulls = ColumnUInt8::create(size);
auto* __restrict res_datas =
assert_cast<ColumnUInt8*>(col_res)->get_data().data();
auto* __restrict res_nulls =
assert_cast<ColumnUInt8*>(col_nulls)->get_data().data();
ColumnPtr temp_null_map = nullptr;
if ((lhs_is_nullable && !rhs_is_nullable) ||
(!lhs_is_nullable || rhs_is_nullable)) { // one of children is nullable
if (lhs_null_map == nullptr) {
temp_null_map = ColumnUInt8::create(size, 0);
lhs_null_map = assert_cast<ColumnUInt8*>(
temp_null_map->assume_mutable().get())
->get_data()
.data();
}
if (rhs_null_map == nullptr) {
temp_null_map = ColumnUInt8::create(size, 0);
rhs_null_map = assert_cast<ColumnUInt8*>(
temp_null_map->assume_mutable().get())
->get_data()
.data();
}
}

for (size_t i = 0; i < size; ++i) {
res_nulls[i] =
(lhs_null_map[i] & rhs_null_map[i]) |
(rhs_null_map[i] & (rhs_null_map[i] ^ lhs_data_column[i])) |
(lhs_null_map[i] & (lhs_null_map[i] ^ rhs_data_column[i]));
res_datas[i] = lhs_data_column[i] | rhs_data_column[i];
}
auto result_column =
ColumnNullable::create(std::move(col_res), std::move(col_nulls));
auto result_type = make_nullable(std::make_shared<DataTypeUInt8>());
*result_column_id = block->columns();
block->insert({std::move(result_column), result_type, _expr_name});
}
}
}
Expand All @@ -155,29 +260,23 @@ class VCompoundPred : public VectorizedFnCall {
return true;
}

// Returns a writable pointer to the UInt8 values backing `column`.
// For a nullable column the pointer refers to the nested (value) column,
// not to the null map; for any other column it refers to the column's own data.
// The column is accessed through assume_mutable(), so callers can write
// through the returned pointer in place.
uint8* _get_raw_data(ColumnPtr column) const {
    const bool wrapped_in_nullable = column->is_nullable();
    auto* raw_column = column->assume_mutable().get();

    // Plain column: the data buffer is directly on the column itself.
    if (!wrapped_in_nullable) {
        return assert_cast<ColumnUInt8*>(raw_column)->get_data().data();
    }

    // Nullable column: unwrap to reach the nested UInt8 value column.
    auto* nullable_wrapper = assert_cast<ColumnNullable*>(raw_column);
    auto* nested_values =
            assert_cast<ColumnUInt8*>(nullable_wrapper->get_nested_column_ptr().get());
    return nested_values->get_data().data();
}

uint8* _get_null_map(ColumnPtr column) const {
if (column->is_nullable()) {
return assert_cast<ColumnUInt8*>(
assert_cast<ColumnNullable*>(column->assume_mutable().get())
->get_null_map_column_ptr()
.get())
->get_data()
.data();
// Returns writable raw pointers into `column` as {values, null_map}.
//
// column:          the column to inspect; accessed via assume_mutable() so the
//                  returned pointers can be written through in place.
// nullable_column: true when `column` is a ColumnNullable. The caller passes
//                  the result of column->is_nullable().
//
// When nullable_column is true, `first` points at the nested UInt8 value
// buffer and `second` at the UInt8 null-map buffer. Otherwise `first` points
// at the column's own UInt8 buffer and `second` is nullptr, so callers must
// check `second` before dereferencing it.
std::pair<uint8*, uint8*> _get_raw_data_and_null_map(ColumnPtr column,
                                                     bool nullable_column) const {
    if (nullable_column) {
        // Named `nullable` (not `nullable_column`) to avoid shadowing the parameter.
        auto* nullable = assert_cast<ColumnNullable*>(column->assume_mutable().get());
        auto* data_column = assert_cast<ColumnUInt8*>(nullable->get_nested_column_ptr().get())
                                    ->get_data()
                                    .data();
        auto* null_map = assert_cast<ColumnUInt8*>(nullable->get_null_map_column_ptr().get())
                                 ->get_data()
                                 .data();
        return std::pair<uint8*, uint8*>(data_column, null_map);
    } else {
        // Non-nullable column: there is no null map, signalled by a null second pointer.
        auto* data_column =
                assert_cast<ColumnUInt8*>(column->assume_mutable().get())->get_data().data();
        return std::pair<uint8*, uint8*>(data_column, nullptr);
    }
}

Expand Down

0 comments on commit d3d1a36

Please sign in to comment.