Skip to content

Commit

Permalink
Add support for quantile and quantileIf functions
Browse files Browse the repository at this point in the history
  • Loading branch information
aronbierbaum committed Mar 23, 2024
1 parent d22fa87 commit 5890bf7
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 1 deletion.
2 changes: 2 additions & 0 deletions clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from sqlalchemy.sql import type_api
from sqlalchemy.util import inspect_getfullargspec

import clickhouse_sqlalchemy.sql.functions # noqa:F401

from ... import types


Expand Down
61 changes: 61 additions & 0 deletions clickhouse_sqlalchemy/sql/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, TypeVar

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import coercions, roles
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.functions import GenericFunction

from clickhouse_sqlalchemy import types

if TYPE_CHECKING:
from sqlalchemy.sql._typing import _ColumnExpressionArgument

_T = TypeVar('_T', bound=Any)


class quantile(GenericFunction[_T]):
inherit_cache = True

def __init__(
self, level: float, expr: _ColumnExpressionArgument[Any],
condition: _ColumnExpressionArgument[Any] = None, **kwargs: Any
):
arg: ColumnElement[Any] = coercions.expect(
roles.ExpressionElementRole, expr, apply_propagate_attrs=self
)

args = [arg]
if condition is not None:
condition = coercions.expect(
roles.ExpressionElementRole, condition,
apply_propagate_attrs=self
)
args.append(condition)

self.level = level

if isinstance(arg.type, (types.Decimal, types.Float, types.Int)):
return_type = types.Float64
elif isinstance(arg.type, types.DateTime):
return_type = types.DateTime
elif isinstance(arg.type, types.Date):
return_type = types.Date
else:
return_type = types.Float64

kwargs['type_'] = return_type
kwargs['_parsed_args'] = args
super().__init__(arg, **kwargs)


class quantileIf(quantile[_T]):
inherit_cache = True


@compiles(quantile, 'clickhouse')
@compiles(quantileIf, 'clickhouse')
def compile_quantile(element, compiler, **kwargs):
args_str = compiler.function_argspec(element, **kwargs)
return f'{element.name}({element.level}){args_str}'
12 changes: 11 additions & 1 deletion docs/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,17 @@ Tables created in declarative way have lowercase with words separated by
underscores naming convention. But you can easy set you own via SQLAlchemy
``__tablename__`` attribute.

SQLAlchemy ``func`` proxy for real ClickHouse functions can be also used.

Functions
+++++++++

Many of the ClickHouse functions can be called using the SQLAlchemy ``func``
proxy. A few of aggregate functions require special handling though. There
following functions are supported:

* ``func.quantile(0.5, column1)`` becomes ``quantile(0.5)(column1)``
* ``func.quantileIf(0.5, column1, column2 > 10)`` becomes ``quantileIf(0.5)(column1, column2 > 10)``


Dialect-specific options
++++++++++++++++++++++++
Expand Down
33 changes: 33 additions & 0 deletions tests/sql/test_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from sqlalchemy import Column, func

from clickhouse_sqlalchemy import types, Table

from tests.testcase import CompilationTestCase


class FunctionTestCase(CompilationTestCase):
table = Table(
't1', CompilationTestCase.metadata(),
Column('x', types.Int32, primary_key=True),
Column('time', types.DateTime)
)

def test_quantile(self):
func0 = func.quantile(0.5, self.table.c.x)
self.assertIsInstance(func0.type, types.Float64)
func1 = func.quantile(0.5, self.table.c.time)
self.assertIsInstance(func1.type, types.DateTime)
self.assertEqual(
self.compile(self.session.query(func0)),
'SELECT quantile(0.5)(t1.x) AS quantile_1 FROM t1'
)

func2 = func.quantileIf(0.5, self.table.c.x, self.table.c.x > 10)

self.assertEqual(
self.compile(
self.session.query(func2)
),
'SELECT quantileIf(0.5)(t1.x, t1.x > %(x_1)s) AS ' +
'"quantileIf_1" FROM t1'
)

0 comments on commit 5890bf7

Please sign in to comment.