diff --git a/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py b/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py index b19319a9..9d385db3 100644 --- a/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py +++ b/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py @@ -4,6 +4,8 @@ from sqlalchemy.sql import type_api from sqlalchemy.util import inspect_getfullargspec +import clickhouse_sqlalchemy.sql.functions # noqa:F401 + from ... import types diff --git a/clickhouse_sqlalchemy/sql/functions.py b/clickhouse_sqlalchemy/sql/functions.py new file mode 100644 index 00000000..ba84f273 --- /dev/null +++ b/clickhouse_sqlalchemy/sql/functions.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING, TypeVar + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql import coercions +from sqlalchemy.sql import roles +from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy.sql.functions import GenericFunction + +from clickhouse_sqlalchemy import types + +if TYPE_CHECKING: + from sqlalchemy.sql._typing import _ColumnExpressionArgument + +_T = TypeVar('_T', bound=Any) + + +class quantile(GenericFunction[_T]): + inherit_cache = True + + def __init__( + self, level: float, expr: _ColumnExpressionArgument[Any], + condition: _ColumnExpressionArgument[Any] = None, **kwargs: Any + ): + arg: ColumnElement[Any] = coercions.expect( + roles.ExpressionElementRole, expr, apply_propagate_attrs=self + ) + + args = [arg] + if condition is not None: + condition = coercions.expect( + roles.ExpressionElementRole, condition, + apply_propagate_attrs=self + ) + args.append(condition) + + self.level = level + + if isinstance(arg.type, (types.Decimal, types.Float, types.Int)): + return_type = types.Float64 + elif isinstance(arg.type, types.DateTime): + return_type = types.DateTime + elif isinstance(arg.type, types.Date): + return_type = types.Date + else: + raise RuntimeError(f'Unsupported type {type(arg)}') + + kwargs['type_'] = return_type + kwargs['_parsed_args'] = args + super().__init__(arg, **kwargs) + + +class quantileIf(quantile[_T]): + inherit_cache = True + + +@compiles(quantile, 'clickhouse') +@compiles(quantileIf, 'clickhouse') +def compile_quantile(element, compiler, **kwargs): + args_str = compiler.function_argspec(element, **kwargs) + return f'{element.name}({element.level}){args_str}' diff --git a/docs/features.rst b/docs/features.rst index 75206092..1f4f123b 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -46,7 +46,17 @@ Tables created in declarative way have lowercase with words separated by underscores naming convention. But you can easy set you own via SQLAlchemy ``__tablename__`` attribute. -SQLAlchemy ``func`` proxy for real ClickHouse functions can be also used. + +Functions ++++++++++ + +Many of the ClickHouse functions can be called using the SQLAlchemy ``func`` +proxy. A few of aggregate functions require special handling though. There +following functions are supported: + +* ``func.quantile(0.5, column1)`` becomes ``quantile(0.5)(column1)`` +* ``func.quantileIf(0.5, column1, column2 > 10)`` becomes ``quantileIf(0.5)(column1, column2 > 10)`` + Dialect-specific options ++++++++++++++++++++++++ diff --git a/tests/sql/test_functions.py b/tests/sql/test_functions.py new file mode 100644 index 00000000..bba18679 --- /dev/null +++ b/tests/sql/test_functions.py @@ -0,0 +1,33 @@ +from sqlalchemy import Column, func + +from clickhouse_sqlalchemy import types, Table + +from tests.testcase import CompilationTestCase + + +class FunctionTestCase(CompilationTestCase): + table = Table( + 't1', CompilationTestCase.metadata(), + Column('x', types.Int32, primary_key=True), + Column('time', types.DateTime) + ) + + def test_quantile(self): + func0 = func.quantile(0.5, self.table.c.x) + self.assertIsInstance(func0.type, types.Float64) + func1 = func.quantile(0.5, self.table.c.time) + self.assertIsInstance(func1.type, types.DateTime) + self.assertEqual( + self.compile(self.session.query(func0)), + 'SELECT quantile(0.5)(t1.x) AS quantile_1 FROM t1' + ) + + func2 = func.quantileIf(0.5, self.table.c.x, self.table.c.x > 10) + + self.assertEqual( + self.compile( + self.session.query(func2) + ), + 'SELECT quantileIf(0.5)(t1.x, t1.x > %(x_1)s) AS ' + + '"quantileIf_1" FROM t1' + )