Skip to content

Commit

Permalink
[opt](assert_cast) Make assert cast do type check in release build by…
Browse files Browse the repository at this point in the history
… default (#39030)

* Problem to solve

We encountered many issues that ultimately proved to be caused by memory
insecurity. These issues are hard to solve, and the final crash log
maybe not related to the root problem.

* Fix

We make `assert_cast` do type check in release build by default. And we
can use a template arg `TypeCheckOnRelease::DISABLE` to disable type
check in release build.
`TypeCheckOnRelease::DISABLE` should be used when user agrees that this
function will be called many many times (eg. add method of
AggregatedData, which will be called by rows) or you think type safe has
already been guaranteed (eg. `assert_cast<const Derived*,
TypeCheckOnRelease::DISABLE>(this)`.
  • Loading branch information
zhiqiang-hhhh authored Aug 13, 2024
1 parent 7a3f3b6 commit aa2929e
Show file tree
Hide file tree
Showing 67 changed files with 632 additions and 325 deletions.
4 changes: 3 additions & 1 deletion be/src/olap/base_tablet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "util/crc32c.h"
#include "util/debug_points.h"
#include "util/doris_metrics.h"
#include "vec/common/assert_cast.h"
#include "vec/common/schema_util.h"
#include "vec/data_types/data_type_factory.hpp"
#include "vec/jsonb/serialize.h"
Expand Down Expand Up @@ -1030,7 +1031,8 @@ Status BaseTablet::generate_new_block_for_partial_update(
if (rs_column.has_default_value()) {
mutable_column->insert_from(*mutable_default_value_columns[i].get(), 0);
} else if (rs_column.is_nullable()) {
assert_cast<vectorized::ColumnNullable*>(mutable_column.get())
assert_cast<vectorized::ColumnNullable*, TypeCheckOnRelease::DISABLE>(
mutable_column.get())
->insert_null_elements(1);
} else {
mutable_column->insert_default();
Expand Down
1 change: 1 addition & 0 deletions be/src/olap/bloom_filter_predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
#include "vec/columns/predicate_column.h"
#include "vec/common/assert_cast.h"
#include "vec/exprs/vruntimefilter_wrapper.h"

namespace doris {
Expand Down
109 changes: 63 additions & 46 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,16 @@ class IAggregateFunctionHelper : public IAggregateFunction {
void destroy_vec(AggregateDataPtr __restrict place,
const size_t num_rows) const noexcept override {
const size_t size_of_data_ = size_of_data();
const Derived* derived = assert_cast<const Derived*>(this);
for (size_t i = 0; i != num_rows; ++i) {
assert_cast<const Derived*>(this)->destroy(place + size_of_data_ * i);
derived->destroy(place + size_of_data_ * i);
}
}

void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
const IColumn** columns, Arena* arena, bool agg_many) const override {
const Derived* derived = assert_cast<const Derived*>(this);

if constexpr (std::is_same_v<Derived, AggregateFunctionBitmapCount<false, ColumnBitmap>> ||
std::is_same_v<Derived, AggregateFunctionBitmapCount<true, ColumnBitmap>> ||
std::is_same_v<Derived,
Expand All @@ -262,64 +265,69 @@ class IAggregateFunctionHelper : public IAggregateFunction {
}
auto iter = place_rows.begin();
while (iter != place_rows.end()) {
assert_cast<const Derived*>(this)->add_many(iter->first, columns, iter->second,
arena);
derived->add_many(iter->first, columns, iter->second, arena);
iter++;
}
return;
}
}

for (size_t i = 0; i < batch_size; ++i) {
assert_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena);
derived->add(places[i] + place_offset, columns, i, arena);
}
}

void add_batch_selected(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
const IColumn** columns, Arena* arena) const override {
const Derived* derived = assert_cast<const Derived*>(this);
for (size_t i = 0; i < batch_size; ++i) {
if (places[i]) {
assert_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena);
derived->add(places[i] + place_offset, columns, i, arena);
}
}
}

void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
Arena* arena) const override {
const Derived* derived = assert_cast<const Derived*>(this);
for (size_t i = 0; i < batch_size; ++i) {
assert_cast<const Derived*>(this)->add(place, columns, i, arena);
derived->add(place, columns, i, arena);
}
}
//now this is use for sum/count/avg/min/max win function, other win function should override this function in class
//stddev_pop/stddev_samp/variance_pop/variance_samp/hll_union_agg/group_concat
void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
Arena* arena) const override {
const Derived* derived = assert_cast<const Derived*>(this);
frame_start = std::max<int64_t>(frame_start, partition_start);
frame_end = std::min<int64_t>(frame_end, partition_end);
for (int64_t i = frame_start; i < frame_end; ++i) {
assert_cast<const Derived*>(this)->add(place, columns, i, arena);
derived->add(place, columns, i, arena);
}
}

void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place,
const IColumn** columns, Arena* arena, bool has_null) override {
const Derived* derived = assert_cast<const Derived*>(this);
for (size_t i = batch_begin; i <= batch_end; ++i) {
assert_cast<const Derived*>(this)->add(place, columns, i, arena);
derived->add(place, columns, i, arena);
}
}

void insert_result_into_vec(const std::vector<AggregateDataPtr>& places, const size_t offset,
IColumn& to, const size_t num_rows) const override {
const Derived* derived = assert_cast<const Derived*>(this);
for (size_t i = 0; i != num_rows; ++i) {
assert_cast<const Derived*>(this)->insert_result_into(places[i] + offset, to);
derived->insert_result_into(places[i] + offset, to);
}
}

void serialize_vec(const std::vector<AggregateDataPtr>& places, size_t offset,
BufferWritable& buf, const size_t num_rows) const override {
const Derived* derived = assert_cast<const Derived*>(this);
for (size_t i = 0; i != num_rows; ++i) {
assert_cast<const Derived*>(this)->serialize(places[i] + offset, buf);
derived->serialize(places[i] + offset, buf);
buf.commit();
}
}
Expand All @@ -333,11 +341,12 @@ class IAggregateFunctionHelper : public IAggregateFunction {
void streaming_agg_serialize(const IColumn** columns, BufferWritable& buf,
const size_t num_rows, Arena* arena) const override {
std::vector<char> place(size_of_data());
const Derived* derived = assert_cast<const Derived*>(this);
for (size_t i = 0; i != num_rows; ++i) {
assert_cast<const Derived*>(this)->create(place.data());
DEFER({ assert_cast<const Derived*>(this)->destroy(place.data()); });
assert_cast<const Derived*>(this)->add(place.data(), columns, i, arena);
assert_cast<const Derived*>(this)->serialize(place.data(), buf);
derived->create(place.data());
DEFER({ derived->destroy(place.data()); });
derived->add(place.data(), columns, i, arena);
derived->serialize(place.data(), buf);
buf.commit();
}
}
Expand All @@ -357,17 +366,18 @@ class IAggregateFunctionHelper : public IAggregateFunction {

void deserialize_vec(AggregateDataPtr places, const ColumnString* column, Arena* arena,
size_t num_rows) const override {
const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data();
const Derived* derived = assert_cast<const Derived*>(this);
const auto size_of_data = derived->size_of_data();
for (size_t i = 0; i != num_rows; ++i) {
try {
auto place = places + size_of_data * i;
VectorBufferReader buffer_reader(column->get_data_at(i));
assert_cast<const Derived*>(this)->create(place);
assert_cast<const Derived*>(this)->deserialize(place, buffer_reader, arena);
derived->create(place);
derived->deserialize(place, buffer_reader, arena);
} catch (...) {
for (int j = 0; j < i; ++j) {
auto place = places + size_of_data * j;
assert_cast<const Derived*>(this)->destroy(place);
derived->destroy(place);
}
throw;
}
Expand All @@ -377,49 +387,52 @@ class IAggregateFunctionHelper : public IAggregateFunction {
void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
AggregateDataPtr rhs, const IColumn* column, Arena* arena,
const size_t num_rows) const override {
const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data();
const Derived* derived = assert_cast<const Derived*>(this);
const auto size_of_data = derived->size_of_data();
const auto* column_string = assert_cast<const ColumnString*>(column);

for (size_t i = 0; i != num_rows; ++i) {
try {
auto rhs_place = rhs + size_of_data * i;
VectorBufferReader buffer_reader(column_string->get_data_at(i));
assert_cast<const Derived*>(this)->create(rhs_place);
assert_cast<const Derived*>(this)->deserialize_and_merge(
places[i] + offset, rhs_place, buffer_reader, arena);
derived->create(rhs_place);
derived->deserialize_and_merge(places[i] + offset, rhs_place, buffer_reader, arena);
} catch (...) {
for (int j = 0; j < i; ++j) {
auto place = rhs + size_of_data * j;
assert_cast<const Derived*>(this)->destroy(place);
derived->destroy(place);
}
throw;
}
}
assert_cast<const Derived*>(this)->destroy_vec(rhs, num_rows);

derived->destroy_vec(rhs, num_rows);
}

void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
AggregateDataPtr rhs, const IColumn* column,
Arena* arena, const size_t num_rows) const override {
const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data();
const auto* derived = assert_cast<const Derived*>(this);
const auto size_of_data = derived->size_of_data();
const auto* column_string = assert_cast<const ColumnString*>(column);
for (size_t i = 0; i != num_rows; ++i) {
try {
auto rhs_place = rhs + size_of_data * i;
VectorBufferReader buffer_reader(column_string->get_data_at(i));
assert_cast<const Derived*>(this)->create(rhs_place);
derived->create(rhs_place);
if (places[i]) {
assert_cast<const Derived*>(this)->deserialize_and_merge(
places[i] + offset, rhs_place, buffer_reader, arena);
derived->deserialize_and_merge(places[i] + offset, rhs_place, buffer_reader,
arena);
}
} catch (...) {
for (int j = 0; j < i; ++j) {
auto place = rhs + size_of_data * j;
assert_cast<const Derived*>(this)->destroy(place);
derived->destroy(place);
}
throw;
}
}
assert_cast<const Derived*>(this)->destroy_vec(rhs, num_rows);
derived->destroy_vec(rhs, num_rows);
}

void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena,
Expand All @@ -429,21 +442,21 @@ class IAggregateFunctionHelper : public IAggregateFunction {

void merge_vec(const AggregateDataPtr* places, size_t offset, ConstAggregateDataPtr rhs,
Arena* arena, const size_t num_rows) const override {
const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data();
const auto* derived = assert_cast<const Derived*>(this);
const auto size_of_data = derived->size_of_data();
for (size_t i = 0; i != num_rows; ++i) {
assert_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i,
arena);
derived->merge(places[i] + offset, rhs + size_of_data * i, arena);
}
}

void merge_vec_selected(const AggregateDataPtr* places, size_t offset,
ConstAggregateDataPtr rhs, Arena* arena,
const size_t num_rows) const override {
const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data();
const auto* derived = assert_cast<const Derived*>(this);
const auto size_of_data = derived->size_of_data();
for (size_t i = 0; i != num_rows; ++i) {
if (places[i]) {
assert_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i,
arena);
derived->merge(places[i] + offset, rhs + size_of_data * i, arena);
}
}
}
Expand All @@ -455,13 +468,15 @@ class IAggregateFunctionHelper : public IAggregateFunction {
<< ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
std::vector<char> deserialized_data(size_of_data());
auto* deserialized_place = (AggregateDataPtr)deserialized_data.data();
const ColumnString& column_string = assert_cast<const ColumnString&>(column);
const Derived* derived = assert_cast<const Derived*>(this);
for (size_t i = begin; i <= end; ++i) {
VectorBufferReader buffer_reader(
(assert_cast<const ColumnString&>(column)).get_data_at(i));
assert_cast<const Derived*>(this)->create(deserialized_place);
DEFER({ assert_cast<const Derived*>(this)->destroy(deserialized_place); });
assert_cast<const Derived*>(this)->deserialize_and_merge(place, deserialized_place,
buffer_reader, arena);
VectorBufferReader buffer_reader(column_string.get_data_at(i));
derived->create(deserialized_place);

DEFER({ derived->destroy(deserialized_place); });

derived->deserialize_and_merge(place, deserialized_place, buffer_reader, arena);
}
}

Expand All @@ -475,8 +490,9 @@ class IAggregateFunctionHelper : public IAggregateFunction {

void deserialize_and_merge(AggregateDataPtr __restrict place, AggregateDataPtr __restrict rhs,
BufferReadable& buf, Arena* arena) const override {
assert_cast<const Derived*>(this)->deserialize(rhs, buf, arena);
assert_cast<const Derived*>(this)->merge(place, rhs, arena);
assert_cast<const Derived*, TypeCheckOnRelease::DISABLE>(this)->deserialize(rhs, buf,
arena);
assert_cast<const Derived*, TypeCheckOnRelease::DISABLE>(this)->merge(place, rhs, arena);
}
};

Expand Down Expand Up @@ -513,8 +529,9 @@ class IAggregateFunctionDataHelper : public IAggregateFunctionHelper<Derived> {

void deserialize_and_merge(AggregateDataPtr __restrict place, AggregateDataPtr __restrict rhs,
BufferReadable& buf, Arena* arena) const override {
assert_cast<const Derived*>(this)->deserialize(rhs, buf, arena);
assert_cast<const Derived*>(this)->merge(place, rhs, arena);
assert_cast<const Derived*, TypeCheckOnRelease::DISABLE>(this)->deserialize(rhs, buf,
arena);
assert_cast<const Derived*, TypeCheckOnRelease::DISABLE>(this)->merge(place, rhs, arena);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/common/string_ref.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_number.h"
Expand Down Expand Up @@ -98,12 +99,14 @@ class AggregateFunctionApproxCountDistinct final
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
if constexpr (IsFixLenColumnType<ColumnDataType>::value) {
auto column = assert_cast<const ColumnDataType*>(columns[0]);
auto column =
assert_cast<const ColumnDataType*, TypeCheckOnRelease::DISABLE>(columns[0]);
auto value = column->get_element(row_num);
this->data(place).add(
HashUtil::murmur_hash64A((char*)&value, sizeof(value), HashUtil::MURMUR_SEED));
} else {
auto value = assert_cast<const ColumnDataType*>(columns[0])->get_data_at(row_num);
auto value = assert_cast<const ColumnDataType*, TypeCheckOnRelease::DISABLE>(columns[0])
->get_data_at(row_num);
uint64_t hash_value =
HashUtil::murmur_hash64A(value.data, value.size, HashUtil::MURMUR_SEED);
this->data(place).add(hash_value);
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ class AggregateFunctionAvg final
#ifdef __clang__
#pragma clang fp reassociate(on)
#endif
const auto& column = assert_cast<const ColVecType&>(*columns[0]);
const auto& column =
assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
if constexpr (IsDecimalNumber<T>) {
this->data(place).sum += column.get_data()[row_num].value;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ class AggregateFunctionAvgWeight final

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColVecType&>(*columns[0]);
const auto& weight = assert_cast<const ColumnFloat64&>(*columns[1]);
const auto& column =
assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
const auto& weight =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
this->data(place).add(column.get_data()[row_num], weight.get_element(row_num));
}

Expand Down
4 changes: 3 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_bit.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <memory>

#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/common/assert_cast.h"
#include "vec/core/types.h"
#include "vec/io/io_helper.h"

Expand Down Expand Up @@ -114,7 +115,8 @@ class AggregateFunctionBitwise final

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColumnVector<T>&>(*columns[0]);
const auto& column =
assert_cast<const ColumnVector<T>&, TypeCheckOnRelease::DISABLE>(*columns[0]);
this->data(place).add(column.get_data()[row_num]);
}

Expand Down
Loading

0 comments on commit aa2929e

Please sign in to comment.