Skip to content

Commit

Permalink
X
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqiang-hhhh committed Sep 11, 2024
1 parent eb4673f commit a30c982
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 20 deletions.
2 changes: 1 addition & 1 deletion be/src/pipeline/exec/aggregation_sink_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ Status AggSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
RETURN_IF_ERROR(vectorized::AggFnEvaluator::create(
_pool, tnode.agg_node.aggregate_functions[i],
tnode.agg_node.__isset.agg_sort_infos ? tnode.agg_node.agg_sort_infos[i] : dummy,
&evaluator));
tnode.agg_node.grouping_exprs.empty(), &evaluator));
_aggregate_evaluators.push_back(evaluator);
}

Expand Down
4 changes: 2 additions & 2 deletions be/src/pipeline/exec/analytic_source_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,11 +497,11 @@ Status AnalyticSourceOperatorX::init(const TPlanNode& tnode, RuntimeState* state
RETURN_IF_ERROR(OperatorX<AnalyticLocalState>::init(tnode, state));
const TAnalyticNode& analytic_node = tnode.analytic_node;
size_t agg_size = analytic_node.analytic_functions.size();

for (int i = 0; i < agg_size; ++i) {
vectorized::AggFnEvaluator* evaluator = nullptr;
RETURN_IF_ERROR(vectorized::AggFnEvaluator::create(
_pool, analytic_node.analytic_functions[i], {}, &evaluator));
//TODO: need to check the without_key flag
_pool, analytic_node.analytic_functions[i], {}, /*wihout_key*/ false, &evaluator));
_agg_functions.emplace_back(evaluator);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ Status DistinctStreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState*
RETURN_IF_ERROR(vectorized::AggFnEvaluator::create(
_pool, tnode.agg_node.aggregate_functions[i],
tnode.agg_node.__isset.agg_sort_infos ? tnode.agg_node.agg_sort_infos[i] : dummy,
&evaluator));
/*wihout_key=*/false, &evaluator));
_aggregate_evaluators.push_back(evaluator);
}

Expand Down
2 changes: 1 addition & 1 deletion be/src/pipeline/exec/streaming_aggregation_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,7 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state)
RETURN_IF_ERROR(vectorized::AggFnEvaluator::create(
_pool, tnode.agg_node.aggregate_functions[i],
tnode.agg_node.__isset.agg_sort_infos ? tnode.agg_node.agg_sort_infos[i] : dummy,
&evaluator));
tnode.agg_node.grouping_exprs.empty(), &evaluator));
_aggregate_evaluators.push_back(evaluator);
}

Expand Down
90 changes: 86 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

#pragma once

#include <algorithm>

#include "common/exception.h"
#include "common/status.h"
#include "util/defer_op.h"
#include "vec/columns/column_complex.h"
#include "vec/columns/column_string.h"
Expand All @@ -30,6 +34,7 @@
#include "vec/core/column_numbers.h"
#include "vec/core/field.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_string.h"

namespace doris::vectorized {
Expand Down Expand Up @@ -62,6 +67,59 @@ using ConstAggregateDataPtr = const char*;
} \
} while (0)

namespace agg_nullale_property {
struct AlwaysNullable {
static bool is_valid_nullable_property(const bool without_key,
const DataTypes& argument_types_with_nullable,
const DataTypePtr result_type_with_nullable) {
return result_type_with_nullable->is_nullable();
}
};

struct AlwaysNotNullable {
static bool is_valid_nullable_property(const bool without_key,
const DataTypes& argument_types_with_nullable,
const DataTypePtr result_type_with_nullable) {
return !result_type_with_nullable->is_nullable();
}
};

// PropograteNullable is deprecated after this pr: https://github.com/apache/doris/pull/37330
// No more PropograteNullable aggregate function, use NullableAggregateFunction instead
// We keep this struct since this on branch 2.1.x, many aggregate functions on FE are still PropograteNullable.
struct PropograteNullable {
static bool is_valid_nullable_property(const bool without_key,
const DataTypes& argument_types_with_nullable,
const DataTypePtr result_type_with_nullable) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"PropograteNullable should not used after version 2.1.x");
}
};

struct NullableAggregateFunction {
static bool is_valid_nullable_property(const bool without_key,
const DataTypes& argument_types_with_nullable,
const DataTypePtr result_type_with_nullable) {
if (std::any_of(argument_types_with_nullable.begin(), argument_types_with_nullable.end(),
[](const DataTypePtr& type) { return type->is_nullable(); }) &&
result_type_with_nullable->is_nullable()) {
// One of input arguments is nullable, the result must be nullable.
return result_type_with_nullable->is_nullable();
} else {
// All column is not nullable, the result can be nullable or not.
// Depends on whether executed with group by.
if (without_key) {
// If without key, means agg is executed without group by, the result must be nullable.
return result_type_with_nullable->is_nullable();
} else {
// If not without key, means agg is executed with group by, the result must be not nullable.
return !result_type_with_nullable->is_nullable();
}
}
}
};
}; // namespace agg_nullale_property

/** Aggregate functions interface.
* Instances of classes with this interface do not contain the data itself for aggregation,
* but contain only metadata (description) of the aggregate function,
Expand All @@ -80,6 +138,10 @@ class IAggregateFunction {
/// Get the result type.
virtual DataTypePtr get_return_type() const = 0;

/// Verify the function signature.
virtual bool is_valid_signature(const bool without_key, const DataTypes& argument_types,
const DataTypePtr result_type) const = 0;

virtual ~IAggregateFunction() = default;

/** Create empty data for aggregation with `placement new` at the specified location.
Expand Down Expand Up @@ -228,12 +290,31 @@ class IAggregateFunction {
};

/// Implement method to obtain an address of 'add' function.
template <typename Derived>
template <typename Derived,
typename NullablePropertyArg = agg_nullale_property::NullableAggregateFunction>
class IAggregateFunctionHelper : public IAggregateFunction {
public:
IAggregateFunctionHelper(const DataTypes& argument_types_)
: IAggregateFunction(argument_types_) {}

using NullableProperty = NullablePropertyArg;

/// Validates a signature in two stages: first the nullability of the result is
/// checked by the NullableProperty policy, then the nullable wrappers are stripped
/// and the remaining types are handed to is_valid_signature_impl.
bool is_valid_signature(const bool without_key, const DataTypes& argument_types_with_nullable,
                        const DataTypePtr result_type_with_nullable) const override {
    const bool nullability_ok = NullableProperty::is_valid_nullable_property(
            without_key, argument_types_with_nullable, result_type_with_nullable);
    if (!nullability_ok) {
        return false;
    }
    // Nullability has been settled above; the type check works on unwrapped types.
    return is_valid_signature_impl(remove_nullable(argument_types_with_nullable),
                                   remove_nullable(result_type_with_nullable));
}

/// Type-level signature check, invoked by is_valid_signature with the nullable
/// wrappers already removed from both the argument types and the result type.
/// The default implementation rejects every signature, so a function is only
/// accepted when its class overrides this method.
/// NOTE(review): any aggregate function that does not override this hook will fail
/// is_valid_signature — confirm this is intended for functions not yet migrated.
virtual bool is_valid_signature_impl(const DataTypes& argument_types_without_nullable,
const DataTypePtr result_type_without_nullable) const {
return false;
}

void destroy_vec(AggregateDataPtr __restrict place,
const size_t num_rows) const noexcept override {
const size_t size_of_data_ = size_of_data();
Expand Down Expand Up @@ -497,8 +578,9 @@ class IAggregateFunctionHelper : public IAggregateFunction {
};

/// Implements several methods for manipulation with data. T - type of structure with data for aggregation.
template <typename T, typename Derived>
class IAggregateFunctionDataHelper : public IAggregateFunctionHelper<Derived> {
template <typename T, typename Derived,
typename NullableProperty = agg_nullale_property::NullableAggregateFunction>
class IAggregateFunctionDataHelper : public IAggregateFunctionHelper<Derived, NullableProperty> {
protected:
using Data = T;

Expand All @@ -509,7 +591,7 @@ class IAggregateFunctionDataHelper : public IAggregateFunctionHelper<Derived> {

public:
IAggregateFunctionDataHelper(const DataTypes& argument_types_)
: IAggregateFunctionHelper<Derived>(argument_types_) {}
: IAggregateFunctionHelper<Derived, NullableProperty>(argument_types_) {}

void create(AggregateDataPtr __restrict place) const override { new (place) Data; }

Expand Down
22 changes: 22 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,28 @@ class AggregateFunctionAvg final
}
}

bool is_valid_signature_impl(const DataTypes& argument_types_with_nullable,
DataTypePtr result_type_with_nullable) const override {
if (argument_types_with_nullable.size() != 1) {
return false;
}

if (is_integer(argument_types_with_nullable[0]) ||
is_float(argument_types_with_nullable[0])) {
return result_type_with_nullable->get_type_id() == TypeIndex::Float64;
}

if (is_decimal_v2(argument_types_with_nullable[0])) {
return is_decimal_v2(result_type_with_nullable);
}

if (is_decimal(argument_types_with_nullable[0])) {
return is_decimal(result_type_with_nullable);
}

return false;
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
#ifdef __clang__
Expand Down
37 changes: 34 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/common/string_ref.h"
#include "vec/common/typeid_cast.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_string.h"
#include "vec/io/io_helper.h"
#include "vec/utils/histogram_helpers.hpp"
Expand Down Expand Up @@ -175,22 +177,51 @@ struct AggregateFunctionHistogramData {

template <typename Data, typename T, bool has_input_param>
class AggregateFunctionHistogram final
: public IAggregateFunctionDataHelper<
Data, AggregateFunctionHistogram<Data, T, has_input_param>> {
: public IAggregateFunctionDataHelper<Data,
AggregateFunctionHistogram<Data, T, has_input_param>,
agg_nullale_property::AlwaysNotNullable> {
public:
using ColVecType = ColumnVectorOrDecimal<T>;

AggregateFunctionHistogram() = default;
AggregateFunctionHistogram(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data,
AggregateFunctionHistogram<Data, T, has_input_param>>(
AggregateFunctionHistogram<Data, T, has_input_param>,
agg_nullale_property::AlwaysNotNullable>(
argument_types_),
_argument_type(argument_types_[0]) {}

std::string get_name() const override { return "histogram"; }

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeString>(); }

bool is_valid_signature_impl(const DataTypes& argument_types_with_nullable,
DataTypePtr result_type_with_nullable) const override {
if (result_type_with_nullable->get_type_id() != TypeIndex::String) {
return false;
}

// According to fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java
// histogram function supports AnyDataType as its first argument.

if (argument_types_with_nullable.size() == 2) {
if (argument_types_with_nullable[1]->get_type_id() != TypeIndex::Int32) {
return false;
}
}

if (argument_types_with_nullable.size() == 3) {
if (argument_types_with_nullable[1]->get_type_id() != TypeIndex::Float64) {
return false;
}
if (argument_types_with_nullable[2]->get_type_id() != TypeIndex::Int32) {
return false;
}
}

return true;
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
if constexpr (has_input_param) {
Expand Down
8 changes: 5 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_null.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
namespace doris::vectorized {

template <typename NestFunction, bool result_is_nullable, typename Derived>
class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived> {
class AggregateFunctionNullBaseInline
: public IAggregateFunctionHelper<Derived, typename NestFunction::NullableProperty> {
protected:
std::unique_ptr<NestFunction> nested_function;
size_t prefix_size;
Expand Down Expand Up @@ -72,7 +73,7 @@ class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived>
public:
AggregateFunctionNullBaseInline(IAggregateFunction* nested_function_,
const DataTypes& arguments)
: IAggregateFunctionHelper<Derived>(arguments),
: IAggregateFunctionHelper<Derived, typename NestFunction::NullableProperty>(arguments),
nested_function {assert_cast<NestFunction*>(nested_function_)} {
if (result_is_nullable) {
prefix_size = nested_function->align_of_data();
Expand All @@ -82,7 +83,8 @@ class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived>
}

void set_version(const int version_) override {
IAggregateFunctionHelper<Derived>::set_version(version_);
IAggregateFunctionHelper<Derived, typename NestFunction::NullableProperty>::set_version(
version_);
nested_function->set_version(version_);
}

Expand Down
14 changes: 11 additions & 3 deletions be/src/vec/exprs/vectorized_agg_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ AggregateFunctionPtr get_agg_state_function(const DataTypes& argument_types,
argument_types, return_type);
}

AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, const bool without_key)
: _fn(desc.fn),
_is_merge(desc.agg_expr.is_merge_agg),
_without_key(without_key),
_return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)) {
bool nullable = true;
if (desc.__isset.is_nullable) {
Expand All @@ -83,8 +84,8 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
}

Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
AggFnEvaluator** result) {
*result = pool->add(AggFnEvaluator::create_unique(desc.nodes[0]).release());
const bool without_key, AggFnEvaluator** result) {
*result = pool->add(AggFnEvaluator::create_unique(desc.nodes[0], without_key).release());
auto& agg_fn_evaluator = *result;
int node_idx = 0;
for (int i = 0; i < desc.nodes[0].num_children; ++i) {
Expand Down Expand Up @@ -213,6 +214,12 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc,
_function = transform_to_sort_agg_function(_function, _argument_types_with_sort,
_sort_description, state);
}

if (!_function->is_valid_signature(_without_key, argument_types, _data_type)) {
return Status::InvalidArgument("Argument of aggregate function {} is invalid.",
_fn.name.function_name);
}

_expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name);
return Status::OK();
}
Expand Down Expand Up @@ -320,6 +327,7 @@ AggFnEvaluator* AggFnEvaluator::clone(RuntimeState* state, ObjectPool* pool) {
AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state)
: _fn(evaluator._fn),
_is_merge(evaluator._is_merge),
_without_key(evaluator._without_key),
_argument_types_with_sort(evaluator._argument_types_with_sort),
_real_argument_types(evaluator._real_argument_types),
_return_type(evaluator._return_type),
Expand Down
6 changes: 4 additions & 2 deletions be/src/vec/exprs/vectorized_agg_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class AggFnEvaluator {

public:
static Status create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
AggFnEvaluator** result);
const bool without_key, AggFnEvaluator** result);

Status prepare(RuntimeState* state, const RowDescriptor& desc,
const SlotDescriptor* intermediate_slot_desc,
Expand Down Expand Up @@ -109,8 +109,10 @@ class AggFnEvaluator {
const TFunction _fn;

const bool _is_merge;
// True when the aggregation is executed without a GROUP BY key. The flag is needed to
// distinguish aggregations executed without grouping keys from those executed with them.
const bool _without_key;

AggFnEvaluator(const TExprNode& desc);
AggFnEvaluator(const TExprNode& desc, const bool without_key);
AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state);

Status _calc_argument_columns(Block* block);
Expand Down

0 comments on commit a30c982

Please sign in to comment.