From a30c9827bb1cf91a0468b3ab3d666c09daeaa78d Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Wed, 11 Sep 2024 23:26:22 +0800 Subject: [PATCH] X --- .../exec/aggregation_sink_operator.cpp | 2 +- .../exec/analytic_source_operator.cpp | 4 +- ...istinct_streaming_aggregation_operator.cpp | 2 +- .../exec/streaming_aggregation_operator.cpp | 2 +- .../aggregate_functions/aggregate_function.h | 90 ++++++++++++++++++- .../aggregate_function_avg.h | 22 +++++ .../aggregate_function_histogram.h | 37 +++++++- .../aggregate_function_null.h | 8 +- be/src/vec/exprs/vectorized_agg_fn.cpp | 14 ++- be/src/vec/exprs/vectorized_agg_fn.h | 6 +- 10 files changed, 167 insertions(+), 20 deletions(-) diff --git a/be/src/pipeline/exec/aggregation_sink_operator.cpp b/be/src/pipeline/exec/aggregation_sink_operator.cpp index 260a599a947a0de..7d1c5ea09ee7e3f 100644 --- a/be/src/pipeline/exec/aggregation_sink_operator.cpp +++ b/be/src/pipeline/exec/aggregation_sink_operator.cpp @@ -742,7 +742,7 @@ Status AggSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { RETURN_IF_ERROR(vectorized::AggFnEvaluator::create( _pool, tnode.agg_node.aggregate_functions[i], tnode.agg_node.__isset.agg_sort_infos ? 
tnode.agg_node.agg_sort_infos[i] : dummy, - &evaluator)); + tnode.agg_node.grouping_exprs.empty(), &evaluator)); _aggregate_evaluators.push_back(evaluator); } diff --git a/be/src/pipeline/exec/analytic_source_operator.cpp b/be/src/pipeline/exec/analytic_source_operator.cpp index b521a9b583fa94c..64e49c20c463bd4 100644 --- a/be/src/pipeline/exec/analytic_source_operator.cpp +++ b/be/src/pipeline/exec/analytic_source_operator.cpp @@ -497,11 +497,11 @@ Status AnalyticSourceOperatorX::init(const TPlanNode& tnode, RuntimeState* state RETURN_IF_ERROR(OperatorX::init(tnode, state)); const TAnalyticNode& analytic_node = tnode.analytic_node; size_t agg_size = analytic_node.analytic_functions.size(); - for (int i = 0; i < agg_size; ++i) { vectorized::AggFnEvaluator* evaluator = nullptr; RETURN_IF_ERROR(vectorized::AggFnEvaluator::create( - _pool, analytic_node.analytic_functions[i], {}, &evaluator)); + // TODO: verify which without_key value analytic functions need; false is not yet confirmed. + _pool, analytic_node.analytic_functions[i], {}, /*without_key=*/ false, &evaluator)); _agg_functions.emplace_back(evaluator); } diff --git a/be/src/pipeline/exec/distinct_streaming_aggregation_operator.cpp b/be/src/pipeline/exec/distinct_streaming_aggregation_operator.cpp index 5127605097f4c5f..6ae1f4bf2d424ad 100644 --- a/be/src/pipeline/exec/distinct_streaming_aggregation_operator.cpp +++ b/be/src/pipeline/exec/distinct_streaming_aggregation_operator.cpp @@ -361,7 +361,7 @@ Status DistinctStreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* RETURN_IF_ERROR(vectorized::AggFnEvaluator::create( _pool, tnode.agg_node.aggregate_functions[i], tnode.agg_node.__isset.agg_sort_infos ? 
tnode.agg_node.agg_sort_infos[i] : dummy, - &evaluator)); + /*without_key=*/false, &evaluator)); _aggregate_evaluators.push_back(evaluator); } diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.cpp b/be/src/pipeline/exec/streaming_aggregation_operator.cpp index dfbe42c637ea568..8afbbc4ed2927c4 100644 --- a/be/src/pipeline/exec/streaming_aggregation_operator.cpp +++ b/be/src/pipeline/exec/streaming_aggregation_operator.cpp @@ -1156,7 +1156,7 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state) RETURN_IF_ERROR(vectorized::AggFnEvaluator::create( _pool, tnode.agg_node.aggregate_functions[i], tnode.agg_node.__isset.agg_sort_infos ? tnode.agg_node.agg_sort_infos[i] : dummy, - &evaluator)); + tnode.agg_node.grouping_exprs.empty(), &evaluator)); _aggregate_evaluators.push_back(evaluator); } diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index 12d629b42c89f8e..236bc8ee9f430d0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -20,6 +20,10 @@ #pragma once +#include + +#include "common/exception.h" +#include "common/status.h" #include "util/defer_op.h" #include "vec/columns/column_complex.h" #include "vec/columns/column_string.h" @@ -30,6 +34,7 @@ #include "vec/core/column_numbers.h" #include "vec/core/field.h" #include "vec/core/types.h" +#include "vec/data_types/data_type_nullable.h" #include "vec/data_types/data_type_string.h" namespace doris::vectorized { @@ -62,6 +67,59 @@ using ConstAggregateDataPtr = const char*; } \ } while (0) +namespace agg_nullale_property { +struct AlwaysNullable { // Valid only when the result type is nullable. + static bool is_valid_nullable_property(const bool without_key, + const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable) { + return result_type_with_nullable->is_nullable(); + } +}; + +struct AlwaysNotNullable { // Valid only when the result type is not nullable. + static bool is_valid_nullable_property(const 
bool without_key, + const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable) { + return !result_type_with_nullable->is_nullable(); + } +}; + +// PropograteNullable is deprecated since https://github.com/apache/doris/pull/37330; +// new aggregate functions must use NullableAggregateFunction instead. +// The struct is kept because, on branch 2.1.x, many aggregate functions on FE are still PropograteNullable. +struct PropograteNullable { + static bool is_valid_nullable_property(const bool without_key, + const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "PropograteNullable should not used after version 2.1.x"); + } +}; + +struct NullableAggregateFunction { + static bool is_valid_nullable_property(const bool without_key, + const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable) { + const bool has_nullable_argument = + std::any_of(argument_types_with_nullable.begin(), argument_types_with_nullable.end(), + [](const DataTypePtr& type) { return type->is_nullable(); }); + if (has_nullable_argument) { + // A nullable argument always requires a nullable result, with or without group by. + return result_type_with_nullable->is_nullable(); + } else { + // No argument is nullable: the required nullability depends on whether there is a group by. + if (without_key) { + // Without key means agg is executed without group by, the result must be nullable. + return result_type_with_nullable->is_nullable(); + } else { + // With key means agg is executed with group by, the result must not be nullable. + return !result_type_with_nullable->is_nullable(); + } + } + } +}; +} // namespace agg_nullale_property + /** Aggregate functions interface. 
* Instances of classes with this interface do not contain the data itself for aggregation, * but contain only metadata (description) of the aggregate function, @@ -80,6 +138,10 @@ class IAggregateFunction { /// Get the result type. virtual DataTypePtr get_return_type() const = 0; + /// Verify that the argument types and result type form a valid signature for this function. + virtual bool is_valid_signature(const bool without_key, const DataTypes& argument_types, + const DataTypePtr result_type) const = 0; + virtual ~IAggregateFunction() = default; /** Create empty data for aggregation with `placement new` at the specified location. @@ -228,12 +290,31 @@ class IAggregateFunction { }; /// Implement method to obtain an address of 'add' function. -template +template class IAggregateFunctionHelper : public IAggregateFunction { public: IAggregateFunctionHelper(const DataTypes& argument_types_) : IAggregateFunction(argument_types_) {} + using NullableProperty = NullablePropertyArg; + + bool is_valid_signature(const bool without_key, const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable) const override { + if (NullableProperty::is_valid_nullable_property(without_key, argument_types_with_nullable, + result_type_with_nullable)) { + return is_valid_signature_impl(remove_nullable(argument_types_with_nullable), + remove_nullable(result_type_with_nullable)); + } else { + return false; + } + } + + virtual bool is_valid_signature_impl(const DataTypes& argument_types_without_nullable, + const DataTypePtr result_type_without_nullable) const { + return true; // Accept by default so functions that do not override this check are not rejected. + } + void destroy_vec(AggregateDataPtr __restrict place, const size_t num_rows) const noexcept override { const size_t size_of_data_ = size_of_data(); @@ -497,8 +578,9 @@ class IAggregateFunctionHelper : public IAggregateFunction { }; /// Implements several methods for manipulation with data. T - type of structure with data for aggregation. 
-template -class IAggregateFunctionDataHelper : public IAggregateFunctionHelper { +template +class IAggregateFunctionDataHelper : public IAggregateFunctionHelper { protected: using Data = T; @@ -509,7 +591,7 @@ class IAggregateFunctionDataHelper : public IAggregateFunctionHelper { public: IAggregateFunctionDataHelper(const DataTypes& argument_types_) - : IAggregateFunctionHelper(argument_types_) {} + : IAggregateFunctionHelper(argument_types_) {} void create(AggregateDataPtr __restrict place) const override { new (place) Data; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h index 8a18a88839b4db4..fe2e0df33759a51 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h @@ -140,6 +140,28 @@ class AggregateFunctionAvg final } } + bool is_valid_signature_impl(const DataTypes& argument_types_without_nullable, + const DataTypePtr result_type_without_nullable) const override { + if (argument_types_without_nullable.size() != 1) { // avg takes exactly one argument + return false; + } + + if (is_integer(argument_types_without_nullable[0]) || + is_float(argument_types_without_nullable[0])) { + return result_type_without_nullable->get_type_id() == TypeIndex::Float64; + } + + if (is_decimal_v2(argument_types_without_nullable[0])) { + return is_decimal_v2(result_type_without_nullable); + } + + if (is_decimal(argument_types_without_nullable[0])) { + return is_decimal(result_type_without_nullable); + } + + return false; + } + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { #ifdef __clang__ diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.h b/be/src/vec/aggregate_functions/aggregate_function_histogram.h index 25fc6957321586e..f69e347b241eead 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_histogram.h +++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.h @@ -38,7 +38,9 @@ #include 
"vec/columns/columns_number.h" #include "vec/common/assert_cast.h" #include "vec/common/string_ref.h" +#include "vec/common/typeid_cast.h" #include "vec/core/types.h" +#include "vec/data_types/data_type.h" #include "vec/data_types/data_type_string.h" #include "vec/io/io_helper.h" #include "vec/utils/histogram_helpers.hpp" @@ -175,15 +177,17 @@ struct AggregateFunctionHistogramData { template class AggregateFunctionHistogram final - : public IAggregateFunctionDataHelper< - Data, AggregateFunctionHistogram> { + : public IAggregateFunctionDataHelper, + agg_nullale_property::AlwaysNotNullable> { public: using ColVecType = ColumnVectorOrDecimal; AggregateFunctionHistogram() = default; AggregateFunctionHistogram(const DataTypes& argument_types_) : IAggregateFunctionDataHelper>( + AggregateFunctionHistogram, + agg_nullale_property::AlwaysNotNullable>( argument_types_), _argument_type(argument_types_[0]) {} @@ -191,6 +195,33 @@ class AggregateFunctionHistogram final DataTypePtr get_return_type() const override { return std::make_shared(); } + bool is_valid_signature_impl(const DataTypes& argument_types_without_nullable, + const DataTypePtr result_type_without_nullable) const override { + if (result_type_without_nullable->get_type_id() != TypeIndex::String) { + return false; + } + + // According to fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java + // histogram function supports AnyDataType as its first argument, so only the optional arguments are checked. 
+ + if (argument_types_without_nullable.size() == 2) { + if (argument_types_without_nullable[1]->get_type_id() != TypeIndex::Int32) { + return false; + } + } + + if (argument_types_without_nullable.size() == 3) { + if (argument_types_without_nullable[1]->get_type_id() != TypeIndex::Float64) { + return false; + } + if (argument_types_without_nullable[2]->get_type_id() != TypeIndex::Int32) { + return false; + } + } + const size_t arity = argument_types_without_nullable.size(); // only 1, 2 or 3 arguments are checked above + return arity >= 1 && arity <= 3; + } + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { if constexpr (has_input_param) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h b/be/src/vec/aggregate_functions/aggregate_function_null.h index 382fb8f7a5310ee..7a03480d4284bdc 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_null.h +++ b/be/src/vec/aggregate_functions/aggregate_function_null.h @@ -33,7 +33,8 @@ namespace doris::vectorized { template -class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper { +class AggregateFunctionNullBaseInline + : public IAggregateFunctionHelper { protected: std::unique_ptr nested_function; size_t prefix_size; @@ -72,7 +73,7 @@ class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper public: AggregateFunctionNullBaseInline(IAggregateFunction* nested_function_, const DataTypes& arguments) - : IAggregateFunctionHelper(arguments), + : IAggregateFunctionHelper(arguments), nested_function {assert_cast(nested_function_)} { if (result_is_nullable) { prefix_size = nested_function->align_of_data(); @@ -82,7 +83,8 @@ class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper } void set_version(const int version_) override { - IAggregateFunctionHelper::set_version(version_); + IAggregateFunctionHelper::set_version( + version_); nested_function->set_version(version_); } diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index c96d84db16c89c7..5802fb9ff7df6d6 100644 --- 
a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -63,9 +63,10 @@ AggregateFunctionPtr get_agg_state_function(const DataTypes& argument_types, argument_types, return_type); } -AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) +AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, const bool without_key) : _fn(desc.fn), _is_merge(desc.agg_expr.is_merge_agg), + _without_key(without_key), _return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)) { bool nullable = true; if (desc.__isset.is_nullable) { @@ -83,8 +84,8 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) } Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info, - AggFnEvaluator** result) { - *result = pool->add(AggFnEvaluator::create_unique(desc.nodes[0]).release()); + const bool without_key, AggFnEvaluator** result) { + *result = pool->add(AggFnEvaluator::create_unique(desc.nodes[0], without_key).release()); auto& agg_fn_evaluator = *result; int node_idx = 0; for (int i = 0; i < desc.nodes[0].num_children; ++i) { @@ -213,6 +214,12 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, _function = transform_to_sort_agg_function(_function, _argument_types_with_sort, _sort_description, state); } + + if (!_function->is_valid_signature(_without_key, argument_types, _data_type)) { + return Status::InvalidArgument("Argument of aggregate function {} is invalid.", + _fn.name.function_name); + } + _expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name); return Status::OK(); } @@ -320,6 +327,7 @@ AggFnEvaluator* AggFnEvaluator::clone(RuntimeState* state, ObjectPool* pool) { AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state) : _fn(evaluator._fn), _is_merge(evaluator._is_merge), + _without_key(evaluator._without_key), _argument_types_with_sort(evaluator._argument_types_with_sort), _real_argument_types(evaluator._real_argument_types), 
_return_type(evaluator._return_type), diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h index 7dcd1b3e02bb474..d6562747fcc8084 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.h +++ b/be/src/vec/exprs/vectorized_agg_fn.h @@ -50,7 +50,7 @@ class AggFnEvaluator { public: static Status create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info, - AggFnEvaluator** result); + const bool without_key, AggFnEvaluator** result); Status prepare(RuntimeState* state, const RowDescriptor& desc, const SlotDescriptor* intermediate_slot_desc, @@ -109,8 +109,10 @@ class AggFnEvaluator { const TFunction _fn; const bool _is_merge; + // Distinguishes aggregation without keys from aggregation with keys; passed to is_valid_signature() in prepare(). + const bool _without_key; - AggFnEvaluator(const TExprNode& desc); + AggFnEvaluator(const TExprNode& desc, const bool without_key); AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state); Status _calc_argument_columns(Block* block);