Skip to content

Commit

Permalink
Merge pull request ClickHouse#54947 from amosbird/minmax-combinator
Browse files Browse the repository at this point in the history
Introduce -ArgMin/-ArgMax combinators.
  • Loading branch information
alexey-milovidov authored Oct 30, 2023
2 parents 88440d4 + ff86fad commit 3631e47
Show file tree
Hide file tree
Showing 43 changed files with 282 additions and 62 deletions.
1 change: 1 addition & 0 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ ReflowComments: false
AlignEscapedNewlinesLeft: false
AlignEscapedNewlines: DontAlign
AlignTrailingComments: false
InsertBraces: WrapLikely

# Not changed:
AccessModifierOffset: -4
Expand Down
11 changes: 10 additions & 1 deletion docs/en/sql-reference/aggregate-functions/combinators.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ INSERT INTO map_map VALUES
('2000-01-01', '2000-01-01 00:00:00', (['c', 'd', 'e'], [10, 10, 10])),
('2000-01-01', '2000-01-01 00:01:00', (['d', 'e', 'f'], [10, 10, 10])),
('2000-01-01', '2000-01-01 00:01:00', (['f', 'g', 'g'], [10, 10, 10]));

SELECT
timeslot,
sumMap(status),
Expand Down Expand Up @@ -317,6 +317,15 @@ FROM people
└────────┴───────────────────────────┘
```

## -ArgMin

The suffix -ArgMin can be appended to the name of any aggregate function. In this case, the aggregate function accepts an additional argument, which should be any comparable expression. The aggregate function processes only the rows that have the minimum value for the specified extra expression.

Examples: `sumArgMin(column, expr)`, `countArgMin(expr)`, `avgArgMin(x, expr)` and so on.

## -ArgMax

Similar to the -ArgMin suffix, but the aggregate function processes only the rows that have the maximum value for the specified extra expression.

## Related Content

Expand Down
2 changes: 1 addition & 1 deletion src/AggregateFunctions/AggregateFunctionFactory.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>

#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeNullable.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include <unordered_set>

#include <AggregateFunctions/AggregateFunctionNull.h>
#include <AggregateFunctions/Combinators/AggregateFunctionNull.h>

#include <Columns/ColumnsNumber.h>

Expand Down
1 change: 0 additions & 1 deletion src/AggregateFunctions/AggregateFunctionSequenceNextNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <Common/assert_cast.h>

#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionNull.h>

#include <type_traits>
#include <bitset>
Expand Down
2 changes: 0 additions & 2 deletions src/AggregateFunctions/AggregateFunctionWindowFunnel.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#include <IO/WriteHelpers.h>
#include <Common/assert_cast.h>

#include <AggregateFunctions/AggregateFunctionNull.h>

namespace DB
{
struct Settings;
Expand Down
11 changes: 6 additions & 5 deletions src/AggregateFunctions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
include("${ClickHouse_SOURCE_DIR}/cmake/dbms_glob_sources.cmake")
add_headers_and_sources(clickhouse_aggregate_functions .)
add_headers_and_sources(clickhouse_aggregate_functions Combinators)

extract_into_parent_list(clickhouse_aggregate_functions_sources dbms_sources
IAggregateFunction.cpp
AggregateFunctionFactory.cpp
AggregateFunctionCombinatorFactory.cpp
AggregateFunctionState.cpp
Combinators/AggregateFunctionCombinatorFactory.cpp
Combinators/AggregateFunctionState.cpp
AggregateFunctionCount.cpp
parseAggregateFunctionParameters.cpp
)
extract_into_parent_list(clickhouse_aggregate_functions_headers dbms_headers
IAggregateFunction.h
IAggregateFunctionCombinator.h
Combinators/IAggregateFunctionCombinator.h
AggregateFunctionFactory.h
AggregateFunctionCombinatorFactory.h
AggregateFunctionState.h
Combinators/AggregateFunctionCombinatorFactory.h
Combinators/AggregateFunctionState.h
AggregateFunctionCount.cpp
FactoryHelpers.h
parseAggregateFunctionParameters.h
Expand Down
93 changes: 93 additions & 0 deletions src/AggregateFunctions/Combinators/AggregateFunctionArgMinMax.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#include "AggregateFunctionArgMinMax.h"
#include "AggregateFunctionCombinatorFactory.h"

#include <AggregateFunctions/AggregateFunctionMinMaxAny.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeString.h>

namespace DB
{

namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}

namespace
{
/// Combinator that wraps an aggregate function into AggregateFunctionArgMinMax.
/// `Data` is the key-data template (see the *Capitalized structs in this file); it supplies
/// the combinator suffix through its static name() and is instantiated with a value holder
/// chosen from the type of the last argument.
template <template <typename> class Data>
class AggregateFunctionCombinatorArgMinMax final : public IAggregateFunctionCombinator
{
private:
    /// Creates the wrapping aggregate function whose key is stored in `Data<KeyHolder>`.
    template <typename KeyHolder>
    static AggregateFunctionPtr build(const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params)
    {
        return std::make_shared<AggregateFunctionArgMinMax<Data<KeyHolder>>>(nested_function, arguments, params);
    }

public:
    /// The combinator suffix, as reported by Data<...>::name().
    String getName() const override { return Data<SingleValueDataGeneric<>>::name(); }

    /// Returns the argument types for the nested function: all arguments except the last one,
    /// which is the extra key expression. Throws if there are no arguments at all.
    DataTypes transformArguments(const DataTypes & arguments) const override
    {
        if (arguments.empty())
            throw Exception(
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                "Incorrect number of arguments for aggregate function with {} suffix",
                getName());

        const auto nested_begin = arguments.begin();
        const auto nested_end = arguments.end() - 1;
        return DataTypes(nested_begin, nested_end);
    }

    /// Picks the key holder that matches the type of the last argument and builds the
    /// wrapping function around `nested_function`. Types without a dedicated holder fall
    /// back to SingleValueDataGeneric<>.
    AggregateFunctionPtr transformAggregateFunction(
        const AggregateFunctionPtr & nested_function,
        const AggregateFunctionProperties &,
        const DataTypes & arguments,
        const Array & params) const override
    {
        const DataTypePtr & key_data_type = arguments.back();
        WhichDataType key_kind(key_data_type);

        /// One check per numeric type listed by FOR_NUMERIC_TYPES.
#define DISPATCH(TYPE) \
        if (key_kind.idx == TypeIndex::TYPE) \
            return build<SingleValueDataFixed<TYPE>>(nested_function, arguments, params); /// NOLINT
        FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH

        if (key_kind.idx == TypeIndex::Date)
            return build<SingleValueDataFixed<DataTypeDate::FieldType>>(nested_function, arguments, params);

        if (key_kind.idx == TypeIndex::DateTime)
            return build<SingleValueDataFixed<DataTypeDateTime::FieldType>>(nested_function, arguments, params);

        if (key_kind.idx == TypeIndex::DateTime64)
            return build<SingleValueDataFixed<DateTime64>>(nested_function, arguments, params);

        if (key_kind.idx == TypeIndex::Decimal32)
            return build<SingleValueDataFixed<Decimal32>>(nested_function, arguments, params);

        if (key_kind.idx == TypeIndex::Decimal64)
            return build<SingleValueDataFixed<Decimal64>>(nested_function, arguments, params);

        if (key_kind.idx == TypeIndex::Decimal128)
            return build<SingleValueDataFixed<Decimal128>>(nested_function, arguments, params);

        if (key_kind.idx == TypeIndex::Decimal256)
            return build<SingleValueDataFixed<Decimal256>>(nested_function, arguments, params);

        if (key_kind.idx == TypeIndex::String)
            return build<SingleValueDataString>(nested_function, arguments, params);

        return build<SingleValueDataGeneric<>>(nested_function, arguments, params);
    }
};

/// Key data for the -ArgMin combinator: everything comes from AggregateFunctionMinData<Data>,
/// this struct only provides name(), which is used as the combinator suffix and is appended
/// to the nested function's name.
template <typename Data>
struct AggregateFunctionArgMinDataCapitalized : AggregateFunctionMinData<Data>
{
    /// Capitalized suffix of the combinator.
    static const char * name()
    {
        return "ArgMin";
    }
};

/// Key data for the -ArgMax combinator: everything comes from AggregateFunctionMaxData<Data>,
/// this struct only provides name(), which is used as the combinator suffix and is appended
/// to the nested function's name.
template <typename Data>
struct AggregateFunctionArgMaxDataCapitalized : AggregateFunctionMaxData<Data>
{
    /// Capitalized suffix of the combinator.
    static const char * name()
    {
        return "ArgMax";
    }
};

}

/// Registers the -ArgMin and -ArgMax combinators (in that order) in the given factory.
void registerAggregateFunctionCombinatorMinMax(AggregateFunctionCombinatorFactory & factory)
{
    using ArgMinCombinator = AggregateFunctionCombinatorArgMinMax<AggregateFunctionArgMinDataCapitalized>;
    using ArgMaxCombinator = AggregateFunctionCombinatorArgMinMax<AggregateFunctionArgMaxDataCapitalized>;

    factory.registerCombinator(std::make_shared<ArgMinCombinator>());
    factory.registerCombinator(std::make_shared<ArgMaxCombinator>());
}

}
111 changes: 111 additions & 0 deletions src/AggregateFunctions/Combinators/AggregateFunctionArgMinMax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#pragma once

#include <algorithm>
#include <type_traits>

#include <AggregateFunctions/IAggregateFunction.h>

namespace DB
{

/// Aggregate function created by the -ArgMin / -ArgMax combinators.
///
/// It wraps `nested_function` and keeps one extra `Key` object per aggregation state.
/// State layout: [ nested function state | padding | Key ], where the Key starts at
/// `key_offset` (size of the nested state rounded up to alignof(Key)).
///
/// The last argument column (`key_col`) is the key expression. A row is forwarded to the
/// nested function only when its key replaces the stored key (the nested state is reset
/// first) or is equal to the stored key; all other rows are ignored.
template <typename Key>
class AggregateFunctionArgMinMax final : public IAggregateFunctionHelper<AggregateFunctionArgMinMax<Key>>
{
private:
    AggregateFunctionPtr nested_function;
    /// Default serialization of the key argument type; used to write/read the Key.
    SerializationPtr serialization;
    /// Index of the key column among the arguments (always the last one).
    size_t key_col;
    /// Byte offset of the Key inside the aggregation state.
    size_t key_offset;

    Key & key(AggregateDataPtr __restrict place) const { return *reinterpret_cast<Key *>(place + key_offset); }
    const Key & key(ConstAggregateDataPtr __restrict place) const { return *reinterpret_cast<const Key *>(place + key_offset); }

public:
    /// `arguments` are the arguments of the combined function, i.e. the nested function's
    /// arguments followed by the key expression; it must not be empty.
    AggregateFunctionArgMinMax(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
        : IAggregateFunctionHelper<AggregateFunctionArgMinMax<Key>>{arguments, params, nested_function_->getResultType()}
        , nested_function{nested_function_}
        , serialization(arguments.back()->getDefaultSerialization())
        , key_col{arguments.size() - 1}
        , key_offset{(nested_function->sizeOfData() + alignof(Key) - 1) / alignof(Key) * alignof(Key)}
    {
    }

    String getName() const override { return nested_function->getName() + Key::name(); }

    bool isState() const override { return nested_function->isState(); }

    bool isVersioned() const override { return nested_function->isVersioned(); }

    size_t getVersionFromRevision(size_t revision) const override { return nested_function->getVersionFromRevision(revision); }

    size_t getDefaultVersion() const override { return nested_function->getDefaultVersion(); }

    bool allocatesMemoryInArena() const override { return nested_function->allocatesMemoryInArena() || Key::allocatesMemoryInArena(); }

    /// The state contains a Key in addition to the nested state, so destruction may be
    /// skipped only if both parts are trivially destructible.
    bool hasTrivialDestructor() const override
    {
        return nested_function->hasTrivialDestructor() && std::is_trivially_destructible_v<Key>;
    }

    size_t sizeOfData() const override { return key_offset + sizeof(Key); }

    /// `key_offset` is aligned relative to the beginning of the state, so the state itself
    /// has to satisfy the alignment of the Key as well as that of the nested state.
    size_t alignOfData() const override { return std::max(nested_function->alignOfData(), alignof(Key)); }

    void create(AggregateDataPtr __restrict place) const override
    {
        nested_function->create(place);
        new (place + key_offset) Key;
    }

    /// Destroys both parts of the state. The Key was placement-constructed in create(),
    /// so its destructor has to be invoked explicitly.
    void destroy(AggregateDataPtr __restrict place) const noexcept override
    {
        key(place).~Key();
        nested_function->destroy(place);
    }

    /// Destroys the Key owned by this combinator and lets the nested function decide
    /// what else has to be destroyed.
    void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override
    {
        key(place).~Key();
        nested_function->destroyUpToState(place);
    }

    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
    {
        if (key(place).changeIfBetter(*columns[key_col], row_num, arena))
        {
            /// A new best key: everything accumulated for the previous key is discarded.
            /// Only the nested state is reset here - the Key must stay alive.
            nested_function->destroy(place);
            nested_function->create(place);
            nested_function->add(place, columns, row_num, arena);
        }
        else if (key(place).isEqualTo(*columns[key_col], row_num))
        {
            nested_function->add(place, columns, row_num, arena);
        }
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
    {
        if (key(place).changeIfBetter(key(rhs), arena))
        {
            /// The other state has a better key: drop our nested state and take theirs.
            /// Only the nested state is reset here - the Key must stay alive.
            nested_function->destroy(place);
            nested_function->create(place);
            nested_function->merge(place, rhs, arena);
        }
        else if (key(place).isEqualTo(key(rhs)))
        {
            nested_function->merge(place, rhs, arena);
        }
    }

    /// Serialized form: the nested state followed by the key.
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
    {
        nested_function->serialize(place, buf, version);
        key(place).write(buf, *serialization);
    }

    /// Reads the state in the same order as serialize() writes it.
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
    {
        nested_function->deserialize(place, buf, version, arena);
        key(place).read(buf, *serialization, arena);
    }

    /// The result is the result of the nested function; the key itself is not returned.
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
    {
        nested_function->insertResultInto(place, to, arena);
    }

    void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
    {
        nested_function->insertMergeResultInto(place, to, arena);
    }

    AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
};

}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <AggregateFunctions/AggregateFunctionArray.h>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include <Common/typeid_cast.h>
#include "AggregateFunctionArray.h"
#include "AggregateFunctionCombinatorFactory.h"

#include <Common/typeid_cast.h>

namespace DB
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <Common/StringUtils/StringUtils.h>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include "AggregateFunctionCombinatorFactory.h"

#include <Common/StringUtils/StringUtils.h>

namespace DB
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include <AggregateFunctions/IAggregateFunctionCombinator.h>

#include "IAggregateFunctionCombinator.h"

#include <string>
#include <unordered_map>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include <AggregateFunctions/AggregateFunctionDistinct.h>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include "AggregateFunctionDistinct.h"
#include "AggregateFunctionCombinatorFactory.h"

#include <AggregateFunctions/Helpers.h>
#include <Common/typeid_cast.h>


namespace DB
{
struct Settings;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <AggregateFunctions/AggregateFunctionForEach.h>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include "AggregateFunctionForEach.h"
#include "AggregateFunctionCombinatorFactory.h"

#include <Common/typeid_cast.h>


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/AggregateFunctionIf.h>
#include "AggregateFunctionCombinatorFactory.h"
#include "AggregateFunctionIf.h"
#include "AggregateFunctionNull.h"

namespace DB
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "AggregateFunctionMap.h"
#include "AggregateFunctions/AggregateFunctionCombinatorFactory.h"
#include "Functions/FunctionHelpers.h"
#include "AggregateFunctionCombinatorFactory.h"

#include <Functions/FunctionHelpers.h>

namespace DB
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#pragma once

#include <unordered_map>
#include <base/sort.h>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnMap.h>
Expand All @@ -14,16 +13,16 @@
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include "DataTypes/Serializations/ISerialization.h"
#include <base/IPv4andIPv6.h>
#include "base/types.h"
#include <Common/formatIPv6.h>
#include <base/sort.h>
#include <base/types.h>
#include <Common/Arena.h>
#include "AggregateFunctions/AggregateFunctionFactory.h"
#include <Common/formatIPv6.h>

namespace DB
{
Expand Down
Loading

0 comments on commit 3631e47

Please sign in to comment.