Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aggregate types #297

Merged
merged 4 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion clickhouse_sqlalchemy/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .compilers.sqlcompiler import ClickHouseSQLCompiler
from .compilers.typecompiler import ClickHouseTypeCompiler
from .reflection import ClickHouseInspector
from .util import get_inner_spec
from .util import get_inner_spec, parse_arguments
from .. import types

# Column specifications
Expand Down Expand Up @@ -54,6 +54,8 @@
'_lowcardinality': types.LowCardinality,
'_tuple': types.Tuple,
'_map': types.Map,
'_aggregatefunction': types.AggregateFunction,
'_simpleaggregatefunction': types.SimpleAggregateFunction,
}


Expand Down Expand Up @@ -230,6 +232,32 @@ def _get_column_type(self, name, spec):
coltype = self.ischema_names['_lowcardinality']
return coltype(self._get_column_type(name, inner))

elif spec.startswith('AggregateFunction'):
params = spec[18:-1]

arguments = parse_arguments(params)
agg_func, inner = arguments[0], arguments[1:]

inner_types = [
self._get_column_type(name, param)
for param in inner
]
coltype = self.ischema_names['_aggregatefunction']
return coltype(agg_func, *inner_types)

elif spec.startswith('SimpleAggregateFunction'):
params = spec[24:-1]

arguments = parse_arguments(params)
agg_func, inner = arguments[0], arguments[1:]

inner_types = [
self._get_column_type(name, param)
for param in inner
]
coltype = self.ischema_names['_simpleaggregatefunction']
return coltype(agg_func, *inner_types)

elif spec.startswith('Tuple'):
inner = spec[6:-1]
coltype = self.ischema_names['_tuple']
Expand Down
26 changes: 26 additions & 0 deletions clickhouse_sqlalchemy/drivers/compilers/typecompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,29 @@ def visit_map(self, type_, **kw):
self.process(key_type, **kw),
self.process(value_type, **kw)
)

def visit_aggregatefunction(self, type_, **kw):
types = [type_api.to_instance(val) for val in type_.nested_types]
type_strings = [self.process(val, **kw) for val in types]

if isinstance(type_.agg_func, str):
agg_str = type_.agg_func
else:
agg_str = str(type_.agg_func.compile(dialect=self.dialect))

return "AggregateFunction(%s, %s)" % (
agg_str, ", ".join(type_strings)
)

def visit_simpleaggregatefunction(self, type_, **kw):
types = [type_api.to_instance(val) for val in type_.nested_types]
type_strings = [self.process(val, **kw) for val in types]

if isinstance(type_.agg_func, str):
agg_str = type_.agg_func
else:
agg_str = str(type_.agg_func.compile(dialect=self.dialect))

return "SimpleAggregateFunction(%s, %s)" % (
agg_str, ", ".join(type_strings)
)
27 changes: 27 additions & 0 deletions clickhouse_sqlalchemy/drivers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,30 @@ def get_inner_spec(spec):
break

return spec[offset + 1:i]


def parse_arguments(param_string):
"""
Given a string of function arguments, parse them into a tuple.
"""
params = []
bracket_level = 0
current_param = ''

for char in param_string:
if char == '(':
bracket_level += 1
elif char == ')':
bracket_level -= 1
elif char == ',' and bracket_level == 0:
params.append(current_param.strip())
current_param = ''
continue

current_param += char

# Append the last parameter
if current_param:
params.append(current_param.strip())

return tuple(params)
4 changes: 4 additions & 0 deletions clickhouse_sqlalchemy/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
'Nested',
'Tuple',
'Map',
'AggregateFunction',
'SimpleAggregateFunction',
]

from .common import String
Expand Down Expand Up @@ -66,6 +68,8 @@
from .common import Decimal
from .common import Tuple
from .common import Map
from .common import AggregateFunction
from .common import SimpleAggregateFunction
from .ip import IPv4
from .ip import IPv6
from .nested import Nested
60 changes: 59 additions & 1 deletion clickhouse_sqlalchemy/types/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from sqlalchemy.sql.type_api import to_instance
from typing import Type, Union

from sqlalchemy import types
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.type_api import to_instance


class ClickHouseTypeEngine(types.TypeEngine):
Expand Down Expand Up @@ -37,6 +40,11 @@ def __init__(self, item_type):
self.item_type_impl = to_instance(item_type)
super(Array, self).__init__()

def __repr__(self):
nested_type_str = \
f'{self.item_type_impl.__module__}.{self.item_type_impl!r}'
return f'Array({nested_type_str})'

@property
def python_type(self):
return list
Expand Down Expand Up @@ -73,6 +81,10 @@ def __init__(self, nested_type):
self.nested_type = to_instance(nested_type)
super(LowCardinality, self).__init__()

def __repr__(self):
nested_type_str = f'{self.nested_type.__module__}.{self.nested_type!r}'
return f'LowCardinality({nested_type_str})'


class Int8(Int):
__visit_name__ = 'int8'
Expand Down Expand Up @@ -188,3 +200,49 @@ def __init__(self, key_type, value_type):
self.key_type = key_type
self.value_type = value_type
super(Map, self).__init__()


class AggregateFunction(ClickHouseTypeEngine):
__visit_name__ = 'aggregatefunction'

def __init__(
self,
agg_func: Union[Function, str],
*nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine],
):
self.agg_func = agg_func
self.nested_types = [to_instance(val) for val in nested_types]
super(AggregateFunction, self).__init__()

def __repr__(self) -> str:
type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types]

if isinstance(self.agg_func, str):
agg_str = self.agg_func
else:
agg_str = f'sa.func.{self.agg_func}'

return f"AggregateFunction({agg_str}, {', '.join(type_strs)})"


class SimpleAggregateFunction(ClickHouseTypeEngine):
__visit_name__ = 'simpleaggregatefunction'

def __init__(
self,
agg_func: Union[Function, str],
*nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine],
):
self.agg_func = agg_func
self.nested_types = [to_instance(val) for val in nested_types]
super(SimpleAggregateFunction, self).__init__()

def __repr__(self) -> str:
type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types]

if isinstance(self.agg_func, str):
agg_str = self.agg_func
else:
agg_str = f'sa.func.{self.agg_func}'

return f"SimpleAggregateFunction({agg_str}, {', '.join(type_strs)})"
20 changes: 20 additions & 0 deletions docs/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,26 @@ You can specify cluster for materialized view in inner table definition.
{'clickhouse_cluster': 'my_cluster'}
)

Materialized views can also store the aggregated data in a table using the
``AggregatingMergeTree`` engine. The aggregate columns are defined using
``AggregateFunction`` or ``SimpleAggregateFunction``.

.. code-block:: python


# Define storage for Materialized View
class GroupedStatistics(Base):
date = Column(types.Date, primary_key=True)
metric1 = Column(SimpleAggregateFunction(sa.func.sum(), types.Int32), nullable=False)

__table_args__ = (
engines.AggregatingMergeTree(
partition_by=func.toYYYYMM(date),
order_by=(date, )
),
)


Basic DDL support
-----------------

Expand Down
50 changes: 50 additions & 0 deletions tests/drivers/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest import TestCase

from clickhouse_sqlalchemy.drivers.util import get_inner_spec, parse_arguments


class GetInnerSpecTestCase(TestCase):
def test_get_inner_spec(self):
self.assertEqual(
get_inner_spec("DateTime('Europe/Paris')"), "'Europe/Paris'"
)
self.assertEqual(get_inner_spec('Decimal(18, 2)'), "18, 2")
self.assertEqual(get_inner_spec('DateTime64(3)'), "3")


class ParseArgumentsTestCase(TestCase):
def test_parse_arguments(self):
self.assertEqual(
parse_arguments('uniq, UInt64'), ('uniq', 'UInt64')
)
self.assertEqual(
parse_arguments('anyIf, String, UInt8'),
('anyIf', 'String', 'UInt8')
)
self.assertEqual(
parse_arguments('quantiles(0.5, 0.9), UInt64'),
('quantiles(0.5, 0.9)', 'UInt64')
)
self.assertEqual(
parse_arguments('sum, Int64, Int64'), ('sum', 'Int64', 'Int64')
)
self.assertEqual(
parse_arguments('sum, Nullable(Int64), Int64'),
('sum', 'Nullable(Int64)', 'Int64')
)
self.assertEqual(
parse_arguments('Float32, Decimal(18, 2)'),
('Float32', 'Decimal(18, 2)')
)
self.assertEqual(
parse_arguments('sum, Float32, Decimal(18, 2)'),
('sum', 'Float32', 'Decimal(18, 2)')
)
self.assertEqual(
parse_arguments('quantiles(0.5, 0.9), UInt64'),
('quantiles(0.5, 0.9)', 'UInt64')
)
self.assertEqual(
parse_arguments("sumIf(total, status = 'accepted'), Float32"),
("sumIf(total, status = 'accepted')", "Float32")
)
33 changes: 33 additions & 0 deletions tests/test_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,39 @@ def test_create_table_map(self):
'ENGINE = Memory'
)

def test_create_aggregate_function(self):
table = Table(
't1', self.metadata(),
Column('total', types.AggregateFunction(func.sum(), types.UInt32)),
engines.Memory()
)

self.assertEqual(
self.compile(CreateTable(table)),
'CREATE TABLE t1 ('
'total AggregateFunction(sum(), UInt32)) '
'ENGINE = Memory'
)

@require_server_version(22, 8, 21)
def test_create_simple_aggregate_function(self):
table = Table(
't1', self.metadata(),
Column(
'total', types.SimpleAggregateFunction(
func.sum(), types.UInt32
)
),
engines.Memory()
)

self.assertEqual(
self.compile(CreateTable(table)),
'CREATE TABLE t1 ('
'total SimpleAggregateFunction(sum(), UInt32)) '
'ENGINE = Memory'
)

def test_table_create_on_cluster(self):
create_sql = (
'CREATE TABLE t1 ON CLUSTER test_cluster '
Expand Down
Loading