diff --git a/clickhouse_sqlalchemy/drivers/base.py b/clickhouse_sqlalchemy/drivers/base.py index 79855709..0cf2b2ef 100644 --- a/clickhouse_sqlalchemy/drivers/base.py +++ b/clickhouse_sqlalchemy/drivers/base.py @@ -13,7 +13,7 @@ from .compilers.sqlcompiler import ClickHouseSQLCompiler from .compilers.typecompiler import ClickHouseTypeCompiler from .reflection import ClickHouseInspector -from .util import get_inner_spec +from .util import get_inner_spec, parse_arguments from .. import types # Column specifications @@ -54,6 +54,8 @@ '_lowcardinality': types.LowCardinality, '_tuple': types.Tuple, '_map': types.Map, + '_aggregatefunction': types.AggregateFunction, + '_simpleaggregatefunction': types.SimpleAggregateFunction, } @@ -230,6 +232,32 @@ def _get_column_type(self, name, spec): coltype = self.ischema_names['_lowcardinality'] return coltype(self._get_column_type(name, inner)) + elif spec.startswith('AggregateFunction'): + params = spec[18:-1] + + arguments = parse_arguments(params) + agg_func, inner = arguments[0], arguments[1:] + + inner_types = [ + self._get_column_type(name, param) + for param in inner + ] + coltype = self.ischema_names['_aggregatefunction'] + return coltype(agg_func, *inner_types) + + elif spec.startswith('SimpleAggregateFunction'): + params = spec[24:-1] + + arguments = parse_arguments(params) + agg_func, inner = arguments[0], arguments[1:] + + inner_types = [ + self._get_column_type(name, param) + for param in inner + ] + coltype = self.ischema_names['_simpleaggregatefunction'] + return coltype(agg_func, *inner_types) + elif spec.startswith('Tuple'): inner = spec[6:-1] coltype = self.ischema_names['_tuple'] diff --git a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py index 26647841..dbe558fc 100644 --- a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py +++ b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py @@ -131,3 +131,29 @@ def visit_map(self, type_, **kw): self.process(key_type, **kw), self.process(value_type, **kw) ) + + def visit_aggregatefunction(self, type_, **kw): + types = [type_api.to_instance(val) for val in type_.nested_types] + type_strings = [self.process(val, **kw) for val in types] + + if isinstance(type_.agg_func, str): + agg_str = type_.agg_func + else: + agg_str = str(type_.agg_func.compile(dialect=self.dialect)) + + return "AggregateFunction(%s, %s)" % ( + agg_str, ", ".join(type_strings) + ) + + def visit_simpleaggregatefunction(self, type_, **kw): + types = [type_api.to_instance(val) for val in type_.nested_types] + type_strings = [self.process(val, **kw) for val in types] + + if isinstance(type_.agg_func, str): + agg_str = type_.agg_func + else: + agg_str = str(type_.agg_func.compile(dialect=self.dialect)) + + return "SimpleAggregateFunction(%s, %s)" % ( + agg_str, ", ".join(type_strings) + ) diff --git a/clickhouse_sqlalchemy/drivers/util.py b/clickhouse_sqlalchemy/drivers/util.py index 25783daf..ca5ca887 100644 --- a/clickhouse_sqlalchemy/drivers/util.py +++ b/clickhouse_sqlalchemy/drivers/util.py @@ -16,3 +16,30 @@ def get_inner_spec(spec): break return spec[offset + 1:i] + + +def parse_arguments(param_string): + """ + Given a string of function arguments, parse them into a tuple. + """ + params = [] + bracket_level = 0 + current_param = '' + + for char in param_string: + if char == '(': + bracket_level += 1 + elif char == ')': + bracket_level -= 1 + elif char == ',' and bracket_level == 0: + params.append(current_param.strip()) + current_param = '' + continue + + current_param += char + + # Append the last parameter + if current_param: + params.append(current_param.strip()) + + return tuple(params) diff --git a/clickhouse_sqlalchemy/types/__init__.py b/clickhouse_sqlalchemy/types/__init__.py index 502e8a0f..372948ff 100644 --- a/clickhouse_sqlalchemy/types/__init__.py +++ b/clickhouse_sqlalchemy/types/__init__.py @@ -33,6 +33,8 @@ 'Nested', 'Tuple', 'Map', + 'AggregateFunction', + 'SimpleAggregateFunction', ] from .common import String @@ -66,6 +68,8 @@ from .common import Decimal from .common import Tuple from .common import Map +from .common import AggregateFunction +from .common import SimpleAggregateFunction from .ip import IPv4 from .ip import IPv6 from .nested import Nested diff --git a/clickhouse_sqlalchemy/types/common.py b/clickhouse_sqlalchemy/types/common.py index 947a8708..1620bde0 100644 --- a/clickhouse_sqlalchemy/types/common.py +++ b/clickhouse_sqlalchemy/types/common.py @@ -1,5 +1,8 @@ -from sqlalchemy.sql.type_api import to_instance +from typing import Type, Union + from sqlalchemy import types +from sqlalchemy.sql.functions import Function +from sqlalchemy.sql.type_api import to_instance class ClickHouseTypeEngine(types.TypeEngine): @@ -37,6 +40,11 @@ def __init__(self, item_type): self.item_type_impl = to_instance(item_type) super(Array, self).__init__() + def __repr__(self): + nested_type_str = \ + f'{self.item_type_impl.__module__}.{self.item_type_impl!r}' + return f'Array({nested_type_str})' + @property def python_type(self): return list @@ -73,6 +81,10 @@ def __init__(self, nested_type): self.nested_type = to_instance(nested_type) super(LowCardinality, self).__init__() + def __repr__(self): + nested_type_str = f'{self.nested_type.__module__}.{self.nested_type!r}' + return f'LowCardinality({nested_type_str})' + class Int8(Int): __visit_name__ = 'int8' @@ -188,3 +200,49 @@ def __init__(self, key_type, value_type): self.key_type = key_type self.value_type = value_type super(Map, self).__init__() + + +class AggregateFunction(ClickHouseTypeEngine): + __visit_name__ = 'aggregatefunction' + + def __init__( + self, + agg_func: Union[Function, str], + *nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine], + ): + self.agg_func = agg_func + self.nested_types = [to_instance(val) for val in nested_types] + super(AggregateFunction, self).__init__() + + def __repr__(self) -> str: + type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types] + + if isinstance(self.agg_func, str): + agg_str = self.agg_func + else: + agg_str = f'sa.func.{self.agg_func}' + + return f"AggregateFunction({agg_str}, {', '.join(type_strs)})" + + +class SimpleAggregateFunction(ClickHouseTypeEngine): + __visit_name__ = 'simpleaggregatefunction' + + def __init__( + self, + agg_func: Union[Function, str], + *nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine], + ): + self.agg_func = agg_func + self.nested_types = [to_instance(val) for val in nested_types] + super(SimpleAggregateFunction, self).__init__() + + def __repr__(self) -> str: + type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types] + + if isinstance(self.agg_func, str): + agg_str = self.agg_func + else: + agg_str = f'sa.func.{self.agg_func}' + + return f"SimpleAggregateFunction({agg_str}, {', '.join(type_strs)})" diff --git a/docs/features.rst b/docs/features.rst index 1f4f123b..20dce7e8 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -599,6 +599,26 @@ You can specify cluster for materialized view in inner table definition. {'clickhouse_cluster': 'my_cluster'} ) +Materialized views can also store the aggregated data in a table using the +``AggregatingMergeTree`` engine. The aggregate columns are defined using +``AggregateFunction`` or ``SimpleAggregateFunction``. + + .. code-block:: python + + + # Define storage for Materialized View + class GroupedStatistics(Base): + date = Column(types.Date, primary_key=True) + metric1 = Column(SimpleAggregateFunction(sa.func.sum(), types.Int32), nullable=False) + + __table_args__ = ( + engines.AggregatingMergeTree( + partition_by=func.toYYYYMM(date), + order_by=(date, ) + ), + ) + + Basic DDL support ----------------- diff --git a/tests/drivers/test_util.py b/tests/drivers/test_util.py new file mode 100644 index 00000000..daab91a0 --- /dev/null +++ b/tests/drivers/test_util.py @@ -0,0 +1,50 @@ +from unittest import TestCase + +from clickhouse_sqlalchemy.drivers.util import get_inner_spec, parse_arguments + + +class GetInnerSpecTestCase(TestCase): + def test_get_inner_spec(self): + self.assertEqual( + get_inner_spec("DateTime('Europe/Paris')"), "'Europe/Paris'" + ) + self.assertEqual(get_inner_spec('Decimal(18, 2)'), "18, 2") + self.assertEqual(get_inner_spec('DateTime64(3)'), "3") + + +class ParseArgumentsTestCase(TestCase): + def test_parse_arguments(self): + self.assertEqual( + parse_arguments('uniq, UInt64'), ('uniq', 'UInt64') + ) + self.assertEqual( + parse_arguments('anyIf, String, UInt8'), + ('anyIf', 'String', 'UInt8') + ) + self.assertEqual( + parse_arguments('quantiles(0.5, 0.9), UInt64'), + ('quantiles(0.5, 0.9)', 'UInt64') + ) + self.assertEqual( + parse_arguments('sum, Int64, Int64'), ('sum', 'Int64', 'Int64') + ) + self.assertEqual( + parse_arguments('sum, Nullable(Int64), Int64'), + ('sum', 'Nullable(Int64)', 'Int64') + ) + self.assertEqual( + parse_arguments('Float32, Decimal(18, 2)'), + ('Float32', 'Decimal(18, 2)') + ) + self.assertEqual( + parse_arguments('sum, Float32, Decimal(18, 2)'), + ('sum', 'Float32', 'Decimal(18, 2)') + ) + self.assertEqual( + parse_arguments('quantiles(0.5, 0.9), UInt64'), + ('quantiles(0.5, 0.9)', 'UInt64') + ) + self.assertEqual( + parse_arguments("sumIf(total, status = 'accepted'), Float32"), + ("sumIf(total, status = 'accepted')", "Float32") + ) diff --git a/tests/test_ddl.py b/tests/test_ddl.py index 10f8b932..6ff22997 100644 --- a/tests/test_ddl.py +++ b/tests/test_ddl.py @@ -311,6 +311,39 @@ def test_create_table_map(self): 'ENGINE = Memory' ) + def test_create_aggregate_function(self): + table = Table( + 't1', self.metadata(), + Column('total', types.AggregateFunction(func.sum(), types.UInt32)), + engines.Memory() + ) + + self.assertEqual( + self.compile(CreateTable(table)), + 'CREATE TABLE t1 (' + 'total AggregateFunction(sum(), UInt32)) ' + 'ENGINE = Memory' + ) + + @require_server_version(22, 8, 21) + def test_create_simple_aggregate_function(self): + table = Table( + 't1', self.metadata(), + Column( + 'total', types.SimpleAggregateFunction( + func.sum(), types.UInt32 + ) + ), + engines.Memory() + ) + + self.assertEqual( + self.compile(CreateTable(table)), + 'CREATE TABLE t1 (' + 'total SimpleAggregateFunction(sum(), UInt32)) ' + 'ENGINE = Memory' + ) + def test_table_create_on_cluster(self): create_sql = ( 'CREATE TABLE t1 ON CLUSTER test_cluster ' diff --git a/tests/test_reflection.py b/tests/test_reflection.py index 4c5a3f41..50ddbb90 100644 --- a/tests/test_reflection.py +++ b/tests/test_reflection.py @@ -1,7 +1,8 @@ import enum -from sqlalchemy import Column, inspect, types as sa_types +from sqlalchemy import Column, func, inspect, types as sa_types from clickhouse_sqlalchemy import types, engines, Table + from tests.testcase import BaseTestCase from tests.util import require_server_version, with_native_and_http_sessions @@ -166,3 +167,65 @@ def test_datetime(self): self.assertIsInstance(coltype, types.DateTime) self.assertIsNone(coltype.timezone) + + def test_aggregate_function(self): + coltype = self._type_round_trip( + types.AggregateFunction(func.sum(), types.UInt16) + )[0]['type'] + + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.UInt16) + + coltype = self._type_round_trip( + types.AggregateFunction('quantiles(0.5, 0.9)', types.UInt32) + )[0]['type'] + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'quantiles(0.5, 0.9)') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.UInt32) + + coltype = self._type_round_trip( + types.AggregateFunction( + func.argMin(), types.Float32, types.Float32 + ) + )[0]['type'] + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'argMin') + self.assertEqual(len(coltype.nested_types), 2) + self.assertIsInstance(coltype.nested_types[0], types.Float32) + self.assertIsInstance(coltype.nested_types[1], types.Float32) + + coltype = self._type_round_trip( + types.AggregateFunction( + 'sum', types.Decimal(18, 2) + ) + )[0]['type'] + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.Decimal) + self.assertEqual(coltype.nested_types[0].precision, 18) + self.assertEqual(coltype.nested_types[0].scale, 2) + + @require_server_version(22, 8, 21) + def test_simple_aggregate_function(self): + coltype = self._type_round_trip( + types.SimpleAggregateFunction(func.sum(), types.UInt64) + )[0]['type'] + + self.assertIsInstance(coltype, types.SimpleAggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.UInt64) + + coltype = self._type_round_trip( + types.SimpleAggregateFunction( + 'sum', types.Float64 + ) + )[0]['type'] + self.assertIsInstance(coltype, types.SimpleAggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.Float64)