Skip to content

Commit

Permalink
Add support for AggregateFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
aronbierbaum committed Mar 23, 2024
1 parent 04c9f69 commit e77c82f
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 3 deletions.
30 changes: 29 additions & 1 deletion clickhouse_sqlalchemy/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .compilers.sqlcompiler import ClickHouseSQLCompiler
from .compilers.typecompiler import ClickHouseTypeCompiler
from .reflection import ClickHouseInspector
from .util import get_inner_spec
from .util import get_inner_spec, parse_arguments
from .. import types

# Column specifications
Expand Down Expand Up @@ -54,6 +54,8 @@
'_lowcardinality': types.LowCardinality,
'_tuple': types.Tuple,
'_map': types.Map,
'_aggregatefunction': types.AggregateFunction,
'_simpleaggregatefunction': types.SimpleAggregateFunction,
}


Expand Down Expand Up @@ -230,6 +232,32 @@ def _get_column_type(self, name, spec):
coltype = self.ischema_names['_lowcardinality']
return coltype(self._get_column_type(name, inner))

elif spec.startswith('AggregateFunction'):
params = spec[18:-1]

arguments = parse_arguments(params)
agg_func, inner = arguments[0], arguments[1:]

inner_types = [
self._get_column_type(name, param)
for param in inner
]
coltype = self.ischema_names['_aggregatefunction']
return coltype(agg_func, *inner_types)

elif spec.startswith('SimpleAggregateFunction'):
params = spec[24:-1]

arguments = parse_arguments(params)
agg_func, inner = arguments[0], arguments[1:]

inner_types = [
self._get_column_type(name, param)
for param in inner
]
coltype = self.ischema_names['_simpleaggregatefunction']
return coltype(agg_func, *inner_types)

elif spec.startswith('Tuple'):
inner = spec[6:-1]
coltype = self.ischema_names['_tuple']
Expand Down
26 changes: 26 additions & 0 deletions clickhouse_sqlalchemy/drivers/compilers/typecompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,29 @@ def visit_map(self, type_, **kw):
self.process(key_type, **kw),
self.process(value_type, **kw)
)

def visit_aggregatefunction(self, type_, **kw):
    """Render the DDL spec of an ``AggregateFunction`` column type.

    The result has the form ``AggregateFunction(<func>[, <type>, ...])``.

    :param type_: type object exposing ``agg_func`` (either a string or an
        expression offering ``compile(dialect=...)``) and ``nested_types``.
    :param kw: forwarded unchanged to ``self.process`` for each nested type.
    :return: the type specification as a string.
    """
    nested = [type_api.to_instance(val) for val in type_.nested_types]
    type_strings = [self.process(val, **kw) for val in nested]

    if isinstance(type_.agg_func, str):
        agg_str = type_.agg_func
    else:
        agg_str = str(type_.agg_func.compile(dialect=self.dialect))

    # Join the function and its argument types as one list, so an aggregate
    # without argument types (e.g. ``count``) is rendered as
    # ``AggregateFunction(count)`` and not ``AggregateFunction(count, )``.
    return "AggregateFunction(%s)" % ", ".join([agg_str] + type_strings)

def visit_simpleaggregatefunction(self, type_, **kw):
    """Render the DDL spec of a ``SimpleAggregateFunction`` column type.

    The result has the form
    ``SimpleAggregateFunction(<func>[, <type>, ...])``.

    :param type_: type object exposing ``agg_func`` (either a string or an
        expression offering ``compile(dialect=...)``) and ``nested_types``.
    :param kw: forwarded unchanged to ``self.process`` for each nested type.
    :return: the type specification as a string.
    """
    nested = [type_api.to_instance(val) for val in type_.nested_types]
    type_strings = [self.process(val, **kw) for val in nested]

    if isinstance(type_.agg_func, str):
        agg_str = type_.agg_func
    else:
        agg_str = str(type_.agg_func.compile(dialect=self.dialect))

    # Join the function and its argument types as one list, so that no
    # dangling ", " is emitted when ``nested_types`` is empty.
    return "SimpleAggregateFunction(%s)" % ", ".join(
        [agg_str] + type_strings
    )
4 changes: 4 additions & 0 deletions clickhouse_sqlalchemy/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
'Nested',
'Tuple',
'Map',
'AggregateFunction',
'SimpleAggregateFunction',
]

from .common import String
Expand Down Expand Up @@ -66,6 +68,8 @@
from .common import Decimal
from .common import Tuple
from .common import Map
from .common import AggregateFunction
from .common import SimpleAggregateFunction
from .ip import IPv4
from .ip import IPv6
from .nested import Nested
51 changes: 50 additions & 1 deletion clickhouse_sqlalchemy/types/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from sqlalchemy.sql.type_api import to_instance
from typing import Type, Union

from sqlalchemy import types
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.type_api import to_instance


class ClickHouseTypeEngine(types.TypeEngine):
Expand Down Expand Up @@ -197,3 +200,49 @@ def __init__(self, key_type, value_type):
self.key_type = key_type
self.value_type = value_type
super(Map, self).__init__()


class AggregateFunction(ClickHouseTypeEngine):
    """Column type describing ClickHouse ``AggregateFunction(func, types...)``.

    :param agg_func: the aggregate function, given either as a plain string
        or as a SQLAlchemy ``Function`` expression; stored as given.
    :param nested_types: argument types of the aggregate function; each one
        is passed through ``to_instance``, so the annotation allows both
        type classes and type instances.
    """
    __visit_name__ = 'aggregatefunction'

    def __init__(
        self,
        agg_func: Union[Function, str],
        *nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine],
    ):
        self.agg_func = agg_func
        self.nested_types = [to_instance(val) for val in nested_types]
        super(AggregateFunction, self).__init__()

    def __repr__(self) -> str:
        """Return a constructor-like representation of this type."""
        if isinstance(self.agg_func, str):
            # Quote string names: a bare ``sum`` or ``quantiles(0.5, 0.9)``
            # would not read back as the string that was passed in.
            agg_str = repr(self.agg_func)
        else:
            agg_str = f'sa.func.{self.agg_func}'

        # Build one argument list so that no trailing ", " is left behind
        # when there are no nested types.
        parts = [agg_str]
        parts.extend(
            f'{val.__module__}.{val!r}' for val in self.nested_types
        )
        return f"AggregateFunction({', '.join(parts)})"


class SimpleAggregateFunction(ClickHouseTypeEngine):
    """Column type describing ClickHouse ``SimpleAggregateFunction(func, types...)``.

    :param agg_func: the aggregate function, given either as a plain string
        or as a SQLAlchemy ``Function`` expression; stored as given.
    :param nested_types: argument types of the aggregate function; each one
        is passed through ``to_instance``, so the annotation allows both
        type classes and type instances.
    """
    __visit_name__ = 'simpleaggregatefunction'

    def __init__(
        self,
        agg_func: Union[Function, str],
        *nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine],
    ):
        self.agg_func = agg_func
        self.nested_types = [to_instance(val) for val in nested_types]
        super(SimpleAggregateFunction, self).__init__()

    def __repr__(self) -> str:
        """Return a constructor-like representation of this type."""
        if isinstance(self.agg_func, str):
            # Quote string names: a bare ``sum`` would not read back as the
            # string that was passed in.
            agg_str = repr(self.agg_func)
        else:
            agg_str = f'sa.func.{self.agg_func}'

        # Build one argument list so that no trailing ", " is left behind
        # when there are no nested types.
        parts = [agg_str]
        parts.extend(
            f'{val.__module__}.{val!r}' for val in self.nested_types
        )
        return f"SimpleAggregateFunction({', '.join(parts)})"
20 changes: 20 additions & 0 deletions docs/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,26 @@ You can specify cluster for materialized view in inner table definition.
{'clickhouse_cluster': 'my_cluster'}
)
Materialized views can also store the aggregated data in a table using the
``AggregatingMergeTree`` engine. The aggregate columns are defined using
``AggregateFunction`` or ``SimpleAggregateFunction``.

.. code-block:: python

    # Define storage for Materialized View
    class GroupedStatistics(Base):
        date = Column(types.Date, primary_key=True)
        metric1 = Column(SimpleAggregateFunction(sa.func.sum(), types.Int32), nullable=False)
        __table_args__ = (
            engines.AggregatingMergeTree(
                partition_by=func.toYYYYMM(date),
                order_by=(date, )
            ),
        )

Basic DDL support
-----------------

Expand Down
33 changes: 33 additions & 0 deletions tests/test_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,39 @@ def test_create_table_map(self):
'ENGINE = Memory'
)

def test_create_aggregate_function(self):
    """An ``AggregateFunction`` column is compiled into CREATE TABLE DDL."""
    total_column = Column(
        'total', types.AggregateFunction(func.sum(), types.UInt32)
    )
    table = Table('t1', self.metadata(), total_column, engines.Memory())

    expected_ddl = (
        'CREATE TABLE t1 ('
        'total AggregateFunction(sum(), UInt32)) '
        'ENGINE = Memory'
    )
    self.assertEqual(self.compile(CreateTable(table)), expected_ddl)

@require_server_version(22, 8, 21)
def test_create_simple_aggregate_function(self):
    """A ``SimpleAggregateFunction`` column is compiled into DDL."""
    total_type = types.SimpleAggregateFunction(func.sum(), types.UInt32)
    total_column = Column('total', total_type)
    table = Table('t1', self.metadata(), total_column, engines.Memory())

    expected_ddl = (
        'CREATE TABLE t1 ('
        'total SimpleAggregateFunction(sum(), UInt32)) '
        'ENGINE = Memory'
    )
    self.assertEqual(self.compile(CreateTable(table)), expected_ddl)

def test_table_create_on_cluster(self):
create_sql = (
'CREATE TABLE t1 ON CLUSTER test_cluster '
Expand Down
65 changes: 64 additions & 1 deletion tests/test_reflection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import enum
from sqlalchemy import Column, inspect, types as sa_types
from sqlalchemy import Column, func, inspect, types as sa_types

from clickhouse_sqlalchemy import types, engines, Table

from tests.testcase import BaseTestCase
from tests.util import require_server_version, with_native_and_http_sessions

Expand Down Expand Up @@ -166,3 +167,65 @@ def test_datetime(self):

self.assertIsInstance(coltype, types.DateTime)
self.assertIsNone(coltype.timezone)

def test_aggregate_function(self):
    """Round-trip several ``AggregateFunction`` columns via reflection."""

    def reflect(column_type, agg_name, nested_classes):
        # Reflect one column and run the checks shared by every case:
        # outer type, function name, and count/classes of nested types.
        reflected = self._type_round_trip(column_type)[0]['type']
        self.assertIsInstance(reflected, types.AggregateFunction)
        self.assertEqual(reflected.agg_func, agg_name)
        self.assertEqual(len(reflected.nested_types), len(nested_classes))
        for nested, expected_cls in zip(
            reflected.nested_types, nested_classes
        ):
            self.assertIsInstance(nested, expected_cls)
        return reflected

    # Function given as an expression, one argument type.
    reflect(
        types.AggregateFunction(func.sum(), types.UInt16),
        'sum', [types.UInt16]
    )

    # Parametric function given as a string.
    reflect(
        types.AggregateFunction('quantiles(0.5, 0.9)', types.UInt32),
        'quantiles(0.5, 0.9)', [types.UInt32]
    )

    # Function taking two argument types.
    reflect(
        types.AggregateFunction(
            func.argMin(), types.Float32, types.Float32
        ),
        'argMin', [types.Float32, types.Float32]
    )

    # Argument type that carries its own parameters.
    decimal_agg = reflect(
        types.AggregateFunction('sum', types.Decimal(18, 2)),
        'sum', [types.Decimal]
    )
    decimal_type = decimal_agg.nested_types[0]
    self.assertEqual(decimal_type.precision, 18)
    self.assertEqual(decimal_type.scale, 2)

@require_server_version(22, 8, 21)
def test_simple_aggregate_function(self):
    """Round-trip ``SimpleAggregateFunction`` columns via reflection."""

    def reflect(column_type, agg_name, nested_cls):
        # Reflect one column and verify the outer type, the function
        # name and the single nested argument type.
        reflected = self._type_round_trip(column_type)[0]['type']
        self.assertIsInstance(reflected, types.SimpleAggregateFunction)
        self.assertEqual(reflected.agg_func, agg_name)
        self.assertEqual(len(reflected.nested_types), 1)
        self.assertIsInstance(reflected.nested_types[0], nested_cls)

    # Function given as an expression.
    reflect(
        types.SimpleAggregateFunction(func.sum(), types.UInt64),
        'sum', types.UInt64
    )

    # Function given as a plain string.
    reflect(
        types.SimpleAggregateFunction('sum', types.Float64),
        'sum', types.Float64
    )

0 comments on commit e77c82f

Please sign in to comment.