Skip to content

Commit

Permalink
Merge pull request #297 from aronbierbaum/add_aggregate_types
Browse files Browse the repository at this point in the history
Add aggregate types
  • Loading branch information
xzkostyan authored Mar 26, 2024
2 parents 0794e94 + e77c82f commit 83c5b46
Show file tree
Hide file tree
Showing 9 changed files with 312 additions and 3 deletions.
30 changes: 29 additions & 1 deletion clickhouse_sqlalchemy/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .compilers.sqlcompiler import ClickHouseSQLCompiler
from .compilers.typecompiler import ClickHouseTypeCompiler
from .reflection import ClickHouseInspector
from .util import get_inner_spec
from .util import get_inner_spec, parse_arguments
from .. import types

# Column specifications
Expand Down Expand Up @@ -54,6 +54,8 @@
'_lowcardinality': types.LowCardinality,
'_tuple': types.Tuple,
'_map': types.Map,
'_aggregatefunction': types.AggregateFunction,
'_simpleaggregatefunction': types.SimpleAggregateFunction,
}


Expand Down Expand Up @@ -230,6 +232,32 @@ def _get_column_type(self, name, spec):
coltype = self.ischema_names['_lowcardinality']
return coltype(self._get_column_type(name, inner))

elif spec.startswith('AggregateFunction'):
params = spec[18:-1]

arguments = parse_arguments(params)
agg_func, inner = arguments[0], arguments[1:]

inner_types = [
self._get_column_type(name, param)
for param in inner
]
coltype = self.ischema_names['_aggregatefunction']
return coltype(agg_func, *inner_types)

elif spec.startswith('SimpleAggregateFunction'):
params = spec[24:-1]

arguments = parse_arguments(params)
agg_func, inner = arguments[0], arguments[1:]

inner_types = [
self._get_column_type(name, param)
for param in inner
]
coltype = self.ischema_names['_simpleaggregatefunction']
return coltype(agg_func, *inner_types)

elif spec.startswith('Tuple'):
inner = spec[6:-1]
coltype = self.ischema_names['_tuple']
Expand Down
26 changes: 26 additions & 0 deletions clickhouse_sqlalchemy/drivers/compilers/typecompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,29 @@ def visit_map(self, type_, **kw):
self.process(key_type, **kw),
self.process(value_type, **kw)
)

def visit_aggregatefunction(self, type_, **kw):
types = [type_api.to_instance(val) for val in type_.nested_types]
type_strings = [self.process(val, **kw) for val in types]

if isinstance(type_.agg_func, str):
agg_str = type_.agg_func
else:
agg_str = str(type_.agg_func.compile(dialect=self.dialect))

return "AggregateFunction(%s, %s)" % (
agg_str, ", ".join(type_strings)
)

def visit_simpleaggregatefunction(self, type_, **kw):
types = [type_api.to_instance(val) for val in type_.nested_types]
type_strings = [self.process(val, **kw) for val in types]

if isinstance(type_.agg_func, str):
agg_str = type_.agg_func
else:
agg_str = str(type_.agg_func.compile(dialect=self.dialect))

return "SimpleAggregateFunction(%s, %s)" % (
agg_str, ", ".join(type_strings)
)
27 changes: 27 additions & 0 deletions clickhouse_sqlalchemy/drivers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,30 @@ def get_inner_spec(spec):
break

return spec[offset + 1:i]


def parse_arguments(param_string):
"""
Given a string of function arguments, parse them into a tuple.
"""
params = []
bracket_level = 0
current_param = ''

for char in param_string:
if char == '(':
bracket_level += 1
elif char == ')':
bracket_level -= 1
elif char == ',' and bracket_level == 0:
params.append(current_param.strip())
current_param = ''
continue

current_param += char

# Append the last parameter
if current_param:
params.append(current_param.strip())

return tuple(params)
4 changes: 4 additions & 0 deletions clickhouse_sqlalchemy/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
'Nested',
'Tuple',
'Map',
'AggregateFunction',
'SimpleAggregateFunction',
]

from .common import String
Expand Down Expand Up @@ -66,6 +68,8 @@
from .common import Decimal
from .common import Tuple
from .common import Map
from .common import AggregateFunction
from .common import SimpleAggregateFunction
from .ip import IPv4
from .ip import IPv6
from .nested import Nested
60 changes: 59 additions & 1 deletion clickhouse_sqlalchemy/types/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from sqlalchemy.sql.type_api import to_instance
from typing import Type, Union

from sqlalchemy import types
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.type_api import to_instance


class ClickHouseTypeEngine(types.TypeEngine):
Expand Down Expand Up @@ -37,6 +40,11 @@ def __init__(self, item_type):
self.item_type_impl = to_instance(item_type)
super(Array, self).__init__()

def __repr__(self):
nested_type_str = \
f'{self.item_type_impl.__module__}.{self.item_type_impl!r}'
return f'Array({nested_type_str})'

@property
def python_type(self):
return list
Expand Down Expand Up @@ -73,6 +81,10 @@ def __init__(self, nested_type):
self.nested_type = to_instance(nested_type)
super(LowCardinality, self).__init__()

def __repr__(self):
nested_type_str = f'{self.nested_type.__module__}.{self.nested_type!r}'
return f'LowCardinality({nested_type_str})'


class Int8(Int):
__visit_name__ = 'int8'
Expand Down Expand Up @@ -188,3 +200,49 @@ def __init__(self, key_type, value_type):
self.key_type = key_type
self.value_type = value_type
super(Map, self).__init__()


class AggregateFunction(ClickHouseTypeEngine):
__visit_name__ = 'aggregatefunction'

def __init__(
self,
agg_func: Union[Function, str],
*nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine],
):
self.agg_func = agg_func
self.nested_types = [to_instance(val) for val in nested_types]
super(AggregateFunction, self).__init__()

def __repr__(self) -> str:
type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types]

if isinstance(self.agg_func, str):
agg_str = self.agg_func
else:
agg_str = f'sa.func.{self.agg_func}'

return f"AggregateFunction({agg_str}, {', '.join(type_strs)})"


class SimpleAggregateFunction(ClickHouseTypeEngine):
__visit_name__ = 'simpleaggregatefunction'

def __init__(
self,
agg_func: Union[Function, str],
*nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine],
):
self.agg_func = agg_func
self.nested_types = [to_instance(val) for val in nested_types]
super(SimpleAggregateFunction, self).__init__()

def __repr__(self) -> str:
type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types]

if isinstance(self.agg_func, str):
agg_str = self.agg_func
else:
agg_str = f'sa.func.{self.agg_func}'

return f"SimpleAggregateFunction({agg_str}, {', '.join(type_strs)})"
20 changes: 20 additions & 0 deletions docs/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,26 @@ You can specify cluster for materialized view in inner table definition.
{'clickhouse_cluster': 'my_cluster'}
)
Materialized views can also store the aggregated data in a table using the
``AggregatingMergeTree`` engine. The aggregate columns are defined using
``AggregateFunction`` or ``SimpleAggregateFunction``.

.. code-block:: python
# Define storage for Materialized View
class GroupedStatistics(Base):
date = Column(types.Date, primary_key=True)
metric1 = Column(SimpleAggregateFunction(sa.func.sum(), types.Int32), nullable=False)
__table_args__ = (
engines.AggregatingMergeTree(
partition_by=func.toYYYYMM(date),
order_by=(date, )
),
)
Basic DDL support
-----------------

Expand Down
50 changes: 50 additions & 0 deletions tests/drivers/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest import TestCase

from clickhouse_sqlalchemy.drivers.util import get_inner_spec, parse_arguments


class GetInnerSpecTestCase(TestCase):
def test_get_inner_spec(self):
self.assertEqual(
get_inner_spec("DateTime('Europe/Paris')"), "'Europe/Paris'"
)
self.assertEqual(get_inner_spec('Decimal(18, 2)'), "18, 2")
self.assertEqual(get_inner_spec('DateTime64(3)'), "3")


class ParseArgumentsTestCase(TestCase):
def test_parse_arguments(self):
self.assertEqual(
parse_arguments('uniq, UInt64'), ('uniq', 'UInt64')
)
self.assertEqual(
parse_arguments('anyIf, String, UInt8'),
('anyIf', 'String', 'UInt8')
)
self.assertEqual(
parse_arguments('quantiles(0.5, 0.9), UInt64'),
('quantiles(0.5, 0.9)', 'UInt64')
)
self.assertEqual(
parse_arguments('sum, Int64, Int64'), ('sum', 'Int64', 'Int64')
)
self.assertEqual(
parse_arguments('sum, Nullable(Int64), Int64'),
('sum', 'Nullable(Int64)', 'Int64')
)
self.assertEqual(
parse_arguments('Float32, Decimal(18, 2)'),
('Float32', 'Decimal(18, 2)')
)
self.assertEqual(
parse_arguments('sum, Float32, Decimal(18, 2)'),
('sum', 'Float32', 'Decimal(18, 2)')
)
self.assertEqual(
parse_arguments('quantiles(0.5, 0.9), UInt64'),
('quantiles(0.5, 0.9)', 'UInt64')
)
self.assertEqual(
parse_arguments("sumIf(total, status = 'accepted'), Float32"),
("sumIf(total, status = 'accepted')", "Float32")
)
33 changes: 33 additions & 0 deletions tests/test_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,39 @@ def test_create_table_map(self):
'ENGINE = Memory'
)

def test_create_aggregate_function(self):
table = Table(
't1', self.metadata(),
Column('total', types.AggregateFunction(func.sum(), types.UInt32)),
engines.Memory()
)

self.assertEqual(
self.compile(CreateTable(table)),
'CREATE TABLE t1 ('
'total AggregateFunction(sum(), UInt32)) '
'ENGINE = Memory'
)

@require_server_version(22, 8, 21)
def test_create_simple_aggregate_function(self):
table = Table(
't1', self.metadata(),
Column(
'total', types.SimpleAggregateFunction(
func.sum(), types.UInt32
)
),
engines.Memory()
)

self.assertEqual(
self.compile(CreateTable(table)),
'CREATE TABLE t1 ('
'total SimpleAggregateFunction(sum(), UInt32)) '
'ENGINE = Memory'
)

def test_table_create_on_cluster(self):
create_sql = (
'CREATE TABLE t1 ON CLUSTER test_cluster '
Expand Down
Loading

0 comments on commit 83c5b46

Please sign in to comment.