Skip to content

Commit

Permalink
Add support for AggregateFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
aronbierbaum committed Mar 4, 2024
1 parent 600db62 commit 21804f2
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 3 deletions.
30 changes: 29 additions & 1 deletion clickhouse_sqlalchemy/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .compilers.sqlcompiler import ClickHouseSQLCompiler
from .compilers.typecompiler import ClickHouseTypeCompiler
from .reflection import ClickHouseInspector
from .util import get_inner_spec
from .util import get_inner_spec, parse_arguments
from .. import types

# Column specifications
Expand Down Expand Up @@ -54,6 +54,8 @@
'_lowcardinality': types.LowCardinality,
'_tuple': types.Tuple,
'_map': types.Map,
'_aggregatefunction': types.AggregateFunction,
'_simpleaggregatefunction': types.SimpleAggregateFunction,
}


Expand Down Expand Up @@ -230,6 +232,32 @@ def _get_column_type(self, name, spec):
coltype = self.ischema_names['_lowcardinality']
return coltype(self._get_column_type(name, inner))

elif spec.startswith('AggregateFunction'):
params = spec[18:-1]

arguments = parse_arguments(params)
agg_func, inner = arguments[0], arguments[1:]

inner_types = [
self._get_column_type(name, param)
for param in inner
]
coltype = self.ischema_names['_aggregatefunction']
return coltype(agg_func, *inner_types)

elif spec.startswith('SimpleAggregateFunction'):
params = spec[24:-1]

arguments = parse_arguments(params)
agg_func, inner = arguments[0], arguments[1:]

inner_types = [
self._get_column_type(name, param)
for param in inner
]
coltype = self.ischema_names['_simpleaggregatefunction']
return coltype(agg_func, *inner_types)

elif spec.startswith('Tuple'):
inner = spec[6:-1]
coltype = self.ischema_names['_tuple']
Expand Down
26 changes: 26 additions & 0 deletions clickhouse_sqlalchemy/drivers/compilers/typecompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,29 @@ def visit_map(self, type_, **kw):
self.process(key_type, **kw),
self.process(value_type, **kw)
)

def visit_aggregatefunction(self, type_, **kw):
    """Render the DDL string for an ``AggregateFunction`` column type.

    :param type_: type object exposing ``agg_func`` (a string or an
        object with a ``compile(dialect=...)`` method) and
        ``nested_types`` (the argument types of the aggregate function).
    :param kw: extra keyword arguments forwarded to ``self.process`` for
        every nested type.
    :return: ``AggregateFunction(<func>[, <type>, ...])``.
    """
    agg_func = type_.agg_func
    if isinstance(agg_func, str):
        # Plain strings are emitted verbatim.
        agg_str = agg_func
    else:
        # Anything else is compiled against the active dialect.
        agg_str = str(agg_func.compile(dialect=self.dialect))

    # Join the function and its argument types as one list so a type
    # without nested argument types renders as ``AggregateFunction(name)``
    # rather than leaving a dangling ``, `` before the closing parenthesis.
    arguments = [agg_str]
    arguments.extend(
        self.process(type_api.to_instance(val), **kw)
        for val in type_.nested_types
    )
    return "AggregateFunction(%s)" % ", ".join(arguments)

def visit_simpleaggregatefunction(self, type_, **kw):
    """Render the DDL string for a ``SimpleAggregateFunction`` column type.

    :param type_: type object exposing ``agg_func`` (a string or an
        object with a ``compile(dialect=...)`` method) and
        ``nested_types``.
    :param kw: extra keyword arguments forwarded to ``self.process`` for
        every nested type.
    :return: ``SimpleAggregateFunction(<func>, <type>, ...)``.
    """
    rendered_types = []
    for nested in type_.nested_types:
        nested_instance = type_api.to_instance(nested)
        rendered_types.append(self.process(nested_instance, **kw))

    func_spec = type_.agg_func
    if not isinstance(func_spec, str):
        # Non-string functions are compiled against the active dialect.
        func_spec = str(func_spec.compile(dialect=self.dialect))

    joined_types = ", ".join(rendered_types)
    return "SimpleAggregateFunction(%s, %s)" % (func_spec, joined_types)
4 changes: 4 additions & 0 deletions clickhouse_sqlalchemy/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
'Nested',
'Tuple',
'Map',
'AggregateFunction',
'SimpleAggregateFunction',
]

from .common import String
Expand Down Expand Up @@ -66,6 +68,8 @@
from .common import Decimal
from .common import Tuple
from .common import Map
from .common import AggregateFunction
from .common import SimpleAggregateFunction
from .ip import IPv4
from .ip import IPv6
from .nested import Nested
23 changes: 22 additions & 1 deletion clickhouse_sqlalchemy/types/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sqlalchemy.sql.type_api import to_instance
from sqlalchemy import types
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.type_api import to_instance


class ClickHouseTypeEngine(types.TypeEngine):
Expand Down Expand Up @@ -188,3 +189,23 @@ def __init__(self, key_type, value_type):
self.key_type = key_type
self.value_type = value_type
super(Map, self).__init__()


class AggregateFunction(ClickHouseTypeEngine):
    """ClickHouse ``AggregateFunction`` column type.

    :param agg_func: the aggregate function, given either as a string or
        as a SQLAlchemy ``Function`` expression; stored unchanged on
        ``agg_func``.
    :param nested_types: argument types of the aggregate function. Each
        one is passed through ``to_instance`` and the results are stored
        on ``nested_types``.
    """
    __visit_name__ = 'aggregatefunction'

    # The union annotation is quoted on purpose: an unquoted
    # ``Function | str`` is evaluated when the function is defined and
    # raises TypeError on Python < 3.10, breaking import of the module.
    def __init__(self, agg_func: 'Function | str',
                 *nested_types: ClickHouseTypeEngine):
        self.agg_func = agg_func
        self.nested_types = [to_instance(val) for val in nested_types]
        super(AggregateFunction, self).__init__()


class SimpleAggregateFunction(ClickHouseTypeEngine):
    """ClickHouse ``SimpleAggregateFunction`` column type.

    :param agg_func: the aggregate function, given either as a string or
        as a SQLAlchemy ``Function`` expression; stored unchanged on
        ``agg_func``.
    :param nested_types: value types of the column. Each one is passed
        through ``to_instance`` and the results are stored on
        ``nested_types``.
    """
    __visit_name__ = 'simpleaggregatefunction'

    # The union annotation is quoted on purpose: an unquoted
    # ``Function | str`` is evaluated when the function is defined and
    # raises TypeError on Python < 3.10, breaking import of the module.
    def __init__(self, agg_func: 'Function | str',
                 *nested_types: ClickHouseTypeEngine):
        self.agg_func = agg_func
        self.nested_types = [to_instance(val) for val in nested_types]
        super(SimpleAggregateFunction, self).__init__()
32 changes: 32 additions & 0 deletions tests/test_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,38 @@ def test_create_table_map(self):
'ENGINE = Memory'
)

def test_create_aggregate_function(self):
    # CREATE TABLE must render the AggregateFunction column together with
    # its compiled function and argument type.
    expected_ddl = (
        'CREATE TABLE t1 ('
        'total AggregateFunction(sum(), UInt32)) '
        'ENGINE = Memory'
    )

    total_column = Column(
        'total', types.AggregateFunction(func.sum(), types.UInt32)
    )
    table = Table('t1', self.metadata(), total_column, engines.Memory())

    compiled_ddl = self.compile(CreateTable(table))
    self.assertEqual(compiled_ddl, expected_ddl)

def test_create_simple_aggregate_function(self):
    # CREATE TABLE must render the SimpleAggregateFunction column together
    # with its compiled function and value type.
    expected_ddl = (
        'CREATE TABLE t1 ('
        'total SimpleAggregateFunction(sum(), UInt32)) '
        'ENGINE = Memory'
    )

    total_column = Column(
        'total', types.SimpleAggregateFunction(func.sum(), types.UInt32)
    )
    table = Table('t1', self.metadata(), total_column, engines.Memory())

    compiled_ddl = self.compile(CreateTable(table))
    self.assertEqual(compiled_ddl, expected_ddl)

def test_table_create_on_cluster(self):
create_sql = (
'CREATE TABLE t1 ON CLUSTER test_cluster '
Expand Down
64 changes: 63 additions & 1 deletion tests/test_reflection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import enum
from sqlalchemy import Column, inspect, types as sa_types
from sqlalchemy import Column, func, inspect, types as sa_types

from clickhouse_sqlalchemy import types, engines, Table

from tests.testcase import BaseTestCase
from tests.util import require_server_version, with_native_and_http_sessions

Expand Down Expand Up @@ -166,3 +167,64 @@ def test_datetime(self):

self.assertIsInstance(coltype, types.DateTime)
self.assertIsNone(coltype.timezone)

def test_aggregate_function(self):
    # Case 1: function given as a SQLAlchemy expression, one argument type.
    sum_columns = self._type_round_trip(
        types.AggregateFunction(func.sum(), types.UInt16)
    )
    sum_type = sum_columns[0]['type']
    self.assertIsInstance(sum_type, types.AggregateFunction)
    self.assertEqual(sum_type.agg_func, 'sum')
    self.assertEqual(len(sum_type.nested_types), 1)
    self.assertIsInstance(sum_type.nested_types[0], types.UInt16)

    # Case 2: parametric function given as a string.
    quantiles_columns = self._type_round_trip(
        types.AggregateFunction('quantiles(0.5, 0.9)', types.UInt32)
    )
    quantiles_type = quantiles_columns[0]['type']
    self.assertIsInstance(quantiles_type, types.AggregateFunction)
    self.assertEqual(quantiles_type.agg_func, 'quantiles(0.5, 0.9)')
    self.assertEqual(len(quantiles_type.nested_types), 1)
    self.assertIsInstance(quantiles_type.nested_types[0], types.UInt32)

    # Case 3: function taking two argument types.
    rank_columns = self._type_round_trip(
        types.AggregateFunction(
            func.rankCorr(), types.Float32, types.Float32
        )
    )
    rank_type = rank_columns[0]['type']
    self.assertIsInstance(rank_type, types.AggregateFunction)
    self.assertEqual(rank_type.agg_func, 'rankCorr')
    self.assertEqual(len(rank_type.nested_types), 2)
    self.assertIsInstance(rank_type.nested_types[0], types.Float32)
    self.assertIsInstance(rank_type.nested_types[1], types.Float32)

    # Case 4: argument type that itself carries parameters.
    decimal_columns = self._type_round_trip(
        types.AggregateFunction('sum', types.Decimal(18, 2))
    )
    decimal_type = decimal_columns[0]['type']
    self.assertIsInstance(decimal_type, types.AggregateFunction)
    self.assertEqual(decimal_type.agg_func, 'sum')
    self.assertEqual(len(decimal_type.nested_types), 1)
    self.assertIsInstance(decimal_type.nested_types[0], types.Decimal)
    self.assertEqual(decimal_type.nested_types[0].precision, 18)
    self.assertEqual(decimal_type.nested_types[0].scale, 2)

def test_simple_aggregate_function(self):
    # Case 1: function given as a SQLAlchemy expression.
    expr_columns = self._type_round_trip(
        types.SimpleAggregateFunction(func.sum(), types.UInt64)
    )
    expr_type = expr_columns[0]['type']
    self.assertIsInstance(expr_type, types.SimpleAggregateFunction)
    self.assertEqual(expr_type.agg_func, 'sum')
    self.assertEqual(len(expr_type.nested_types), 1)
    self.assertIsInstance(expr_type.nested_types[0], types.UInt64)

    # Case 2: function given as a plain string.
    string_columns = self._type_round_trip(
        types.SimpleAggregateFunction('sum', types.Float64)
    )
    string_type = string_columns[0]['type']
    self.assertIsInstance(string_type, types.SimpleAggregateFunction)
    self.assertEqual(string_type.agg_func, 'sum')
    self.assertEqual(len(string_type.nested_types), 1)
    self.assertIsInstance(string_type.nested_types[0], types.Float64)

0 comments on commit 21804f2

Please sign in to comment.