diff --git a/clickhouse_sqlalchemy/drivers/base.py b/clickhouse_sqlalchemy/drivers/base.py index 79855709..0cf2b2ef 100644 --- a/clickhouse_sqlalchemy/drivers/base.py +++ b/clickhouse_sqlalchemy/drivers/base.py @@ -13,7 +13,7 @@ from .compilers.sqlcompiler import ClickHouseSQLCompiler from .compilers.typecompiler import ClickHouseTypeCompiler from .reflection import ClickHouseInspector -from .util import get_inner_spec +from .util import get_inner_spec, parse_arguments from .. import types # Column specifications @@ -54,6 +54,8 @@ '_lowcardinality': types.LowCardinality, '_tuple': types.Tuple, '_map': types.Map, + '_aggregatefunction': types.AggregateFunction, + '_simpleaggregatefunction': types.SimpleAggregateFunction, } @@ -230,6 +232,32 @@ def _get_column_type(self, name, spec): coltype = self.ischema_names['_lowcardinality'] return coltype(self._get_column_type(name, inner)) + elif spec.startswith('AggregateFunction'): + params = spec[18:-1] + + arguments = parse_arguments(params) + agg_func, inner = arguments[0], arguments[1:] + + inner_types = [ + self._get_column_type(name, param) + for param in inner + ] + coltype = self.ischema_names['_aggregatefunction'] + return coltype(agg_func, *inner_types) + + elif spec.startswith('SimpleAggregateFunction'): + params = spec[24:-1] + + arguments = parse_arguments(params) + agg_func, inner = arguments[0], arguments[1:] + + inner_types = [ + self._get_column_type(name, param) + for param in inner + ] + coltype = self.ischema_names['_simpleaggregatefunction'] + return coltype(agg_func, *inner_types) + elif spec.startswith('Tuple'): inner = spec[6:-1] coltype = self.ischema_names['_tuple'] diff --git a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py index 26647841..dbe558fc 100644 --- a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py +++ b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py @@ -131,3 +131,29 @@ def visit_map(self, type_, **kw): self.process(key_type, **kw), self.process(value_type, **kw) ) + + def visit_aggregatefunction(self, type_, **kw): + types = [type_api.to_instance(val) for val in type_.nested_types] + type_strings = [self.process(val, **kw) for val in types] + + if isinstance(type_.agg_func, str): + agg_str = type_.agg_func + else: + agg_str = str(type_.agg_func.compile(dialect=self.dialect)) + + return "AggregateFunction(%s, %s)" % ( + agg_str, ", ".join(type_strings) + ) + + def visit_simpleaggregatefunction(self, type_, **kw): + types = [type_api.to_instance(val) for val in type_.nested_types] + type_strings = [self.process(val, **kw) for val in types] + + if isinstance(type_.agg_func, str): + agg_str = type_.agg_func + else: + agg_str = str(type_.agg_func.compile(dialect=self.dialect)) + + return "SimpleAggregateFunction(%s, %s)" % ( + agg_str, ", ".join(type_strings) + ) diff --git a/clickhouse_sqlalchemy/types/__init__.py b/clickhouse_sqlalchemy/types/__init__.py index 502e8a0f..372948ff 100644 --- a/clickhouse_sqlalchemy/types/__init__.py +++ b/clickhouse_sqlalchemy/types/__init__.py @@ -33,6 +33,8 @@ 'Nested', 'Tuple', 'Map', + 'AggregateFunction', + 'SimpleAggregateFunction', ] from .common import String @@ -66,6 +68,8 @@ from .common import Decimal from .common import Tuple from .common import Map +from .common import AggregateFunction +from .common import SimpleAggregateFunction from .ip import IPv4 from .ip import IPv6 from .nested import Nested diff --git a/clickhouse_sqlalchemy/types/common.py b/clickhouse_sqlalchemy/types/common.py index e300e719..7c7a5410 100644 --- a/clickhouse_sqlalchemy/types/common.py +++ b/clickhouse_sqlalchemy/types/common.py @@ -1,5 +1,6 @@ -from sqlalchemy.sql.type_api import to_instance from sqlalchemy import types +from sqlalchemy.sql.functions import Function +from sqlalchemy.sql.type_api import to_instance class ClickHouseTypeEngine(types.TypeEngine): @@ -197,3 +198,43 @@ def __init__(self, key_type, value_type): self.key_type = key_type self.value_type = value_type super(Map, self).__init__() + + +class AggregateFunction(ClickHouseTypeEngine): + __visit_name__ = 'aggregatefunction' + + def __init__(self, agg_func: Function | str, + *nested_types: ClickHouseTypeEngine): + self.agg_func = agg_func + self.nested_types = [to_instance(val) for val in nested_types] + super(AggregateFunction, self).__init__() + + def __repr__(self) -> str: + type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types] + + if isinstance(self.agg_func, str): + agg_str = self.agg_func + else: + agg_str = f'sa.func.{self.agg_func}' + + return f"AggregateFunction({agg_str}, {', '.join(type_strs)})" + + +class SimpleAggregateFunction(ClickHouseTypeEngine): + __visit_name__ = 'simpleaggregatefunction' + + def __init__(self, agg_func: Function | str, + *nested_types: ClickHouseTypeEngine): + self.agg_func = agg_func + self.nested_types = [to_instance(val) for val in nested_types] + super(SimpleAggregateFunction, self).__init__() + + def __repr__(self) -> str: + type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types] + + if isinstance(self.agg_func, str): + agg_str = self.agg_func + else: + agg_str = f'sa.func.{self.agg_func}' + + return f"SimpleAggregateFunction({agg_str}, {', '.join(type_strs)})" diff --git a/tests/test_ddl.py b/tests/test_ddl.py index 10f8b932..a48a4cc1 100644 --- a/tests/test_ddl.py +++ b/tests/test_ddl.py @@ -311,6 +311,38 @@ def test_create_table_map(self): 'ENGINE = Memory' ) + def test_create_aggregate_function(self): + table = Table( + 't1', self.metadata(), + Column('total', types.AggregateFunction(func.sum(), types.UInt32)), + engines.Memory() + ) + + self.assertEqual( + self.compile(CreateTable(table)), + 'CREATE TABLE t1 (' + 'total AggregateFunction(sum(), UInt32)) ' + 'ENGINE = Memory' + ) + + def test_create_simple_aggregate_function(self): + table = Table( + 't1', self.metadata(), + Column( + 'total', types.SimpleAggregateFunction( + func.sum(), types.UInt32 + ) + ), + engines.Memory() + ) + + self.assertEqual( + self.compile(CreateTable(table)), + 'CREATE TABLE t1 (' + 'total SimpleAggregateFunction(sum(), UInt32)) ' + 'ENGINE = Memory' + ) + def test_table_create_on_cluster(self): create_sql = ( 'CREATE TABLE t1 ON CLUSTER test_cluster ' diff --git a/tests/test_reflection.py b/tests/test_reflection.py index 4c5a3f41..b424a943 100644 --- a/tests/test_reflection.py +++ b/tests/test_reflection.py @@ -1,7 +1,8 @@ import enum -from sqlalchemy import Column, inspect, types as sa_types +from sqlalchemy import Column, func, inspect, types as sa_types from clickhouse_sqlalchemy import types, engines, Table + from tests.testcase import BaseTestCase from tests.util import require_server_version, with_native_and_http_sessions @@ -166,3 +167,64 @@ def test_datetime(self): self.assertIsInstance(coltype, types.DateTime) self.assertIsNone(coltype.timezone) + + def test_aggregate_function(self): + coltype = self._type_round_trip( + types.AggregateFunction(func.sum(), types.UInt16) + )[0]['type'] + + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.UInt16) + + coltype = self._type_round_trip( + types.AggregateFunction('quantiles(0.5, 0.9)', types.UInt32) + )[0]['type'] + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'quantiles(0.5, 0.9)') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.UInt32) + + coltype = self._type_round_trip( + types.AggregateFunction( + func.rankCorr(), types.Float32, types.Float32 + ) + )[0]['type'] + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'rankCorr') + self.assertEqual(len(coltype.nested_types), 2) + self.assertIsInstance(coltype.nested_types[0], types.Float32) + self.assertIsInstance(coltype.nested_types[1], types.Float32) + + coltype = self._type_round_trip( + types.AggregateFunction( + 'sum', types.Decimal(18, 2) + ) + )[0]['type'] + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.Decimal) + self.assertEqual(coltype.nested_types[0].precision, 18) + self.assertEqual(coltype.nested_types[0].scale, 2) + + def test_simple_aggregate_function(self): + coltype = self._type_round_trip( + types.SimpleAggregateFunction(func.sum(), types.UInt64) + )[0]['type'] + + self.assertIsInstance(coltype, types.SimpleAggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.UInt64) + + coltype = self._type_round_trip( + types.SimpleAggregateFunction( + 'sum', types.Float64 + ) + )[0]['type'] + self.assertIsInstance(coltype, types.SimpleAggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.Float64)