diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 78a12498..216f5af0 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -13,6 +13,7 @@ jobs: - "3.11" - "3.12" clickhouse-version: + - 24.1.8.22 - 23.8.4.69 - 22.5.1.2079 - 19.3.5 @@ -77,7 +78,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Finished - uses: coverallsapp/github-action@v2.2.3 + uses: coverallsapp/github-action@v2.3.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} parallel-finished: true diff --git a/CHANGELOG.md b/CHANGELOG.md index 2065c17a..ce234955 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ ## [Unreleased] +## [0.3.2] - 2024-06-12 +### Added +- ``quantile`` and ``quantileIf`` functions. Pull request [#303](https://github.com/xzkostyan/clickhouse-sqlalchemy/pull/303) by [aronbierbaum](https://github.com/aronbierbaum). +- ``AggregateFunction`` and ``SimpleAggregateFunction`` aggregate types. Pull request [#297](https://github.com/xzkostyan/clickhouse-sqlalchemy/pull/297) by [aronbierbaum](https://github.com/aronbierbaum). +- Date32 type. Pull request [#307](https://github.com/xzkostyan/clickhouse-sqlalchemy/pull/307) by [BTheunissen](https://github.com/BTheunissen). Solves issue [#302](https://github.com/xzkostyan/clickhouse-sqlalchemy/issues/302). + +### Fixed +- Broken nested Map types. Pull request [#315](https://github.com/xzkostyan/clickhouse-sqlalchemy/pull/315) by [aksenof](https://github.com/aksenof). Solves issue [#314](https://github.com/xzkostyan/clickhouse-sqlalchemy/issues/314). + ## [0.3.1] - 2024-03-14 ### Added - ``SETTINGS`` clause. Pull request [#292](https://github.com/xzkostyan/clickhouse-sqlalchemy/pull/292) by [limonyellow](https://github.com/limonyellow). @@ -344,7 +353,8 @@ Log, TinyLog, Null. - Chunked `INSERT INTO` in one request. - Engines: MergeTree, CollapsingMergeTree, SummingMergeTree, Buffer, Memory. 
-[Unreleased]: https://github.com/xzkostyan/clickhouse-sqlalchemy/compare/0.3.1...HEAD +[Unreleased]: https://github.com/xzkostyan/clickhouse-sqlalchemy/compare/0.3.2...HEAD +[0.3.2]: https://github.com/xzkostyan/clickhouse-sqlalchemy/compare/0.3.1...0.3.2 [0.3.1]: https://github.com/xzkostyan/clickhouse-sqlalchemy/compare/0.3.0...0.3.1 [0.3.0]: https://github.com/xzkostyan/clickhouse-sqlalchemy/compare/0.2.5...0.3.0 [0.2.5]: https://github.com/xzkostyan/clickhouse-sqlalchemy/compare/0.2.4...0.2.5 diff --git a/clickhouse_sqlalchemy/__init__.py b/clickhouse_sqlalchemy/__init__.py index 1943cc4e..8f5893c7 100644 --- a/clickhouse_sqlalchemy/__init__.py +++ b/clickhouse_sqlalchemy/__init__.py @@ -4,7 +4,7 @@ from .sql import Table, MaterializedView, select -VERSION = (0, 3, 1) +VERSION = (0, 3, 2) __version__ = '.'.join(str(x) for x in VERSION) diff --git a/clickhouse_sqlalchemy/drivers/asynch/base.py b/clickhouse_sqlalchemy/drivers/asynch/base.py index 0f757f9e..5a28ea6e 100644 --- a/clickhouse_sqlalchemy/drivers/asynch/base.py +++ b/clickhouse_sqlalchemy/drivers/asynch/base.py @@ -4,17 +4,24 @@ from sqlalchemy.pool import AsyncAdaptedQueuePool from .connector import AsyncAdapt_asynch_dbapi -from ..native.base import ClickHouseDialect_native +from ..native.base import ClickHouseDialect_native, ClickHouseExecutionContext # Export connector version VERSION = (0, 0, 1, None) +class ClickHouseAsynchExecutionContext(ClickHouseExecutionContext): + def create_server_side_cursor(self): + return self.create_default_cursor() + + class ClickHouseDialect_asynch(ClickHouseDialect_native): driver = 'asynch' + execution_ctx_cls = ClickHouseAsynchExecutionContext is_async = True supports_statement_cache = True + supports_server_side_cursors = True @classmethod def import_dbapi(cls): diff --git a/clickhouse_sqlalchemy/drivers/base.py b/clickhouse_sqlalchemy/drivers/base.py index 79855709..7e102df1 100644 --- a/clickhouse_sqlalchemy/drivers/base.py +++ 
b/clickhouse_sqlalchemy/drivers/base.py @@ -13,7 +13,7 @@ from .compilers.sqlcompiler import ClickHouseSQLCompiler from .compilers.typecompiler import ClickHouseTypeCompiler from .reflection import ClickHouseInspector -from .util import get_inner_spec +from .util import get_inner_spec, parse_arguments from .. import types # Column specifications @@ -35,6 +35,7 @@ 'UInt16': types.UInt16, 'UInt8': types.UInt8, 'Date': types.Date, + 'Date32': types.Date32, 'DateTime': types.DateTime, 'DateTime64': types.DateTime64, 'Float64': types.Float64, @@ -49,11 +50,14 @@ 'FixedString': types.String, 'Enum8': types.Enum8, 'Enum16': types.Enum16, + 'Object(\'json\')': types.JSON, '_array': types.Array, '_nullable': types.Nullable, '_lowcardinality': types.LowCardinality, '_tuple': types.Tuple, '_map': types.Map, + '_aggregatefunction': types.AggregateFunction, + '_simpleaggregatefunction': types.SimpleAggregateFunction, } @@ -132,6 +136,16 @@ class ClickHouseDialect(default.DefaultDialect): inspector = ClickHouseInspector + def __init__( + self, + json_serializer=None, + json_deserializer=None, + **kwargs, + ): + default.DefaultDialect.__init__(self, **kwargs) + self._json_deserializer = json_deserializer + self._json_serializer = json_serializer + def initialize(self, connection): super(ClickHouseDialect, self).initialize(connection) @@ -230,6 +244,32 @@ def _get_column_type(self, name, spec): coltype = self.ischema_names['_lowcardinality'] return coltype(self._get_column_type(name, inner)) + elif spec.startswith('AggregateFunction'): + params = spec[18:-1] + + arguments = parse_arguments(params) + agg_func, inner = arguments[0], arguments[1:] + + inner_types = [ + self._get_column_type(name, param) + for param in inner + ] + coltype = self.ischema_names['_aggregatefunction'] + return coltype(agg_func, *inner_types) + + elif spec.startswith('SimpleAggregateFunction'): + params = spec[24:-1] + + arguments = parse_arguments(params) + agg_func, inner = arguments[0], arguments[1:] + 
+ inner_types = [ + self._get_column_type(name, param) + for param in inner + ] + coltype = self.ischema_names['_simpleaggregatefunction'] + return coltype(agg_func, *inner_types) + elif spec.startswith('Tuple'): inner = spec[6:-1] coltype = self.ischema_names['_tuple'] @@ -244,7 +284,7 @@ def _get_column_type(self, name, spec): coltype = self.ischema_names['_map'] inner_types = [ self._get_column_type(name, t.strip()) - for t in inner.split(',') + for t in inner.split(',', 1) ] return coltype(*inner_types) diff --git a/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py b/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py index 69d6ceac..45c8fe49 100644 --- a/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py +++ b/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py @@ -4,6 +4,8 @@ from sqlalchemy.sql import type_api from sqlalchemy.util import inspect_getfullargspec +import clickhouse_sqlalchemy.sql.functions # noqa:F401 + from ... import types diff --git a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py index 26647841..0c5ab472 100644 --- a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py +++ b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py @@ -60,6 +60,9 @@ def visit_uint256(self, type_, **kw): def visit_date(self, type_, **kw): return 'Date' + def visit_date32(self, type_, **kw): + return 'Date32' + def visit_datetime(self, type_, **kw): if type_.timezone: return "DateTime('%s')" % type_.timezone @@ -84,6 +87,9 @@ def visit_numeric(self, type_, **kw): def visit_boolean(self, type_, **kw): return 'Bool' + def visit_json(self, type_, **kw): + return 'JSON' + def visit_nested(self, nested, **kwargs): ddl_compiler = self.dialect.ddl_compiler(self.dialect, None) cols_create = [ @@ -118,10 +124,26 @@ def visit_ipv6(self, type_, **kw): return 'IPv6' def visit_tuple(self, type_, **kw): - cols = ( - self.process(type_api.to_instance(nested_type), **kw) + cols = [] + 
is_named_type = all([ + isinstance(nested_type, tuple) and len(nested_type) == 2 for nested_type in type_.nested_types - ) + ]) + if is_named_type: + for nested_type in type_.nested_types: + name = nested_type[0] + name_type = nested_type[1] + inner_type = self.process( + type_api.to_instance(name_type), + **kw + ) + cols.append( + f'{name} {inner_type}') + else: + cols = ( + self.process(type_api.to_instance(nested_type), **kw) + for nested_type in type_.nested_types + ) return 'Tuple(%s)' % ', '.join(cols) def visit_map(self, type_, **kw): @@ -131,3 +153,29 @@ def visit_map(self, type_, **kw): self.process(key_type, **kw), self.process(value_type, **kw) ) + + def visit_aggregatefunction(self, type_, **kw): + types = [type_api.to_instance(val) for val in type_.nested_types] + type_strings = [self.process(val, **kw) for val in types] + + if isinstance(type_.agg_func, str): + agg_str = type_.agg_func + else: + agg_str = str(type_.agg_func.compile(dialect=self.dialect)) + + return "AggregateFunction(%s, %s)" % ( + agg_str, ", ".join(type_strings) + ) + + def visit_simpleaggregatefunction(self, type_, **kw): + types = [type_api.to_instance(val) for val in type_.nested_types] + type_strings = [self.process(val, **kw) for val in types] + + if isinstance(type_.agg_func, str): + agg_str = type_.agg_func + else: + agg_str = str(type_.agg_func.compile(dialect=self.dialect)) + + return "SimpleAggregateFunction(%s, %s)" % ( + agg_str, ", ".join(type_strings) + ) diff --git a/clickhouse_sqlalchemy/drivers/http/escaper.py b/clickhouse_sqlalchemy/drivers/http/escaper.py index ab259f10..9c572942 100644 --- a/clickhouse_sqlalchemy/drivers/http/escaper.py +++ b/clickhouse_sqlalchemy/drivers/http/escaper.py @@ -1,6 +1,7 @@ from datetime import date, datetime from decimal import Decimal import enum +import uuid class Escaper(object): @@ -49,6 +50,9 @@ def escape_datetime64(self, item): def escape_decimal(self, item): return float(item) + def escape_uuid(self, item): + return 
str(item) + def escape_item(self, item): if item is None: return 'NULL' @@ -75,5 +79,7 @@ def escape_item(self, item): ) + "}" elif isinstance(item, enum.Enum): return self.escape_string(item.name) + elif isinstance(item, uuid.UUID): + return self.escape_uuid(item) else: raise Exception("Unsupported object {}".format(item)) diff --git a/clickhouse_sqlalchemy/drivers/native/connector.py b/clickhouse_sqlalchemy/drivers/native/connector.py index 01481efa..94f4624a 100644 --- a/clickhouse_sqlalchemy/drivers/native/connector.py +++ b/clickhouse_sqlalchemy/drivers/native/connector.py @@ -138,7 +138,8 @@ def _prepare(self, context=None): execute_kwargs = { 'settings': settings, 'external_tables': external_tables, - 'types_check': execution_options.get('types_check', False) + 'types_check': execution_options.get('types_check', False), + 'query_id': execution_options.get('query_id', None) } return execute, execute_kwargs diff --git a/clickhouse_sqlalchemy/drivers/util.py b/clickhouse_sqlalchemy/drivers/util.py index 25783daf..ca5ca887 100644 --- a/clickhouse_sqlalchemy/drivers/util.py +++ b/clickhouse_sqlalchemy/drivers/util.py @@ -16,3 +16,30 @@ def get_inner_spec(spec): break return spec[offset + 1:i] + + +def parse_arguments(param_string): + """ + Given a string of function arguments, parse them into a tuple. 
+ """ + params = [] + bracket_level = 0 + current_param = '' + + for char in param_string: + if char == '(': + bracket_level += 1 + elif char == ')': + bracket_level -= 1 + elif char == ',' and bracket_level == 0: + params.append(current_param.strip()) + current_param = '' + continue + + current_param += char + + # Append the last parameter + if current_param: + params.append(current_param.strip()) + + return tuple(params) diff --git a/clickhouse_sqlalchemy/sql/functions.py b/clickhouse_sqlalchemy/sql/functions.py new file mode 100644 index 00000000..b4cd3ed5 --- /dev/null +++ b/clickhouse_sqlalchemy/sql/functions.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeVar + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql import coercions, roles +from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy.sql.functions import GenericFunction + +from clickhouse_sqlalchemy import types + +if TYPE_CHECKING: + from sqlalchemy.sql._typing import _ColumnExpressionArgument + +_T = TypeVar('_T', bound=Any) + + +class quantile(GenericFunction[_T]): + inherit_cache = True + + def __init__( + self, level: float, expr: _ColumnExpressionArgument[Any], + condition: _ColumnExpressionArgument[Any] = None, **kwargs: Any + ): + arg: ColumnElement[Any] = coercions.expect( + roles.ExpressionElementRole, expr, apply_propagate_attrs=self + ) + + args = [arg] + if condition is not None: + condition = coercions.expect( + roles.ExpressionElementRole, condition, + apply_propagate_attrs=self + ) + args.append(condition) + + self.level = level + + if isinstance(arg.type, (types.Decimal, types.Float, types.Int)): + return_type = types.Float64 + elif isinstance(arg.type, types.DateTime): + return_type = types.DateTime + elif isinstance(arg.type, types.Date): + return_type = types.Date + else: + return_type = types.Float64 + + kwargs['type_'] = return_type + kwargs['_parsed_args'] = args + super().__init__(arg, **kwargs) + + 
+class quantileIf(quantile[_T]): + inherit_cache = True + + +@compiles(quantile, 'clickhouse') +@compiles(quantileIf, 'clickhouse') +def compile_quantile(element, compiler, **kwargs): + args_str = compiler.function_argspec(element, **kwargs) + return f'{element.name}({element.level}){args_str}' diff --git a/clickhouse_sqlalchemy/types/__init__.py b/clickhouse_sqlalchemy/types/__init__.py index 502e8a0f..67747f88 100644 --- a/clickhouse_sqlalchemy/types/__init__.py +++ b/clickhouse_sqlalchemy/types/__init__.py @@ -22,6 +22,7 @@ 'Float32', 'Float64', 'Date', + 'Date32', 'DateTime', 'DateTime64', 'Enum', @@ -30,9 +31,12 @@ 'Decimal', 'IPv4', 'IPv6', + 'JSON', 'Nested', 'Tuple', 'Map', + 'AggregateFunction', + 'SimpleAggregateFunction', ] from .common import String @@ -58,14 +62,18 @@ from .common import Float32 from .common import Float64 from .common import Date +from .common import Date32 from .common import DateTime from .common import DateTime64 from .common import Enum from .common import Enum8 from .common import Enum16 from .common import Decimal +from .common import JSON from .common import Tuple from .common import Map +from .common import AggregateFunction +from .common import SimpleAggregateFunction from .ip import IPv4 from .ip import IPv6 from .nested import Nested diff --git a/clickhouse_sqlalchemy/types/common.py b/clickhouse_sqlalchemy/types/common.py index 947a8708..458a320c 100644 --- a/clickhouse_sqlalchemy/types/common.py +++ b/clickhouse_sqlalchemy/types/common.py @@ -1,5 +1,8 @@ -from sqlalchemy.sql.type_api import to_instance +from typing import Type, Union + from sqlalchemy import types +from sqlalchemy.sql.functions import Function +from sqlalchemy.sql.type_api import to_instance class ClickHouseTypeEngine(types.TypeEngine): @@ -27,6 +30,10 @@ class Boolean(types.Boolean, ClickHouseTypeEngine): pass +class JSON(types.JSON, ClickHouseTypeEngine): + __visit_name__ = 'json' + + class Array(ClickHouseTypeEngine): __visit_name__ = 'array' @@ -37,6 
+44,11 @@ def __init__(self, item_type): self.item_type_impl = to_instance(item_type) super(Array, self).__init__() + def __repr__(self): + nested_type_str = \ + f'{self.item_type_impl.__module__}.{self.item_type_impl!r}' + return f'Array({nested_type_str})' + @property def python_type(self): return list @@ -73,6 +85,10 @@ def __init__(self, nested_type): self.nested_type = to_instance(nested_type) super(LowCardinality, self).__init__() + def __repr__(self): + nested_type_str = f'{self.nested_type.__module__}.{self.nested_type!r}' + return f'LowCardinality({nested_type_str})' + class Int8(Int): __visit_name__ = 'int8' @@ -134,6 +150,10 @@ class Date(types.Date, ClickHouseTypeEngine): __visit_name__ = 'date' +class Date32(types.Date, ClickHouseTypeEngine): + __visit_name__ = 'date32' + + class DateTime(types.DateTime, ClickHouseTypeEngine): __visit_name__ = 'datetime' @@ -188,3 +208,49 @@ def __init__(self, key_type, value_type): self.key_type = key_type self.value_type = value_type super(Map, self).__init__() + + +class AggregateFunction(ClickHouseTypeEngine): + __visit_name__ = 'aggregatefunction' + + def __init__( + self, + agg_func: Union[Function, str], + *nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine], + ): + self.agg_func = agg_func + self.nested_types = [to_instance(val) for val in nested_types] + super(AggregateFunction, self).__init__() + + def __repr__(self) -> str: + type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types] + + if isinstance(self.agg_func, str): + agg_str = self.agg_func + else: + agg_str = f'sa.func.{self.agg_func}' + + return f"AggregateFunction({agg_str}, {', '.join(type_strs)})" + + +class SimpleAggregateFunction(ClickHouseTypeEngine): + __visit_name__ = 'simpleaggregatefunction' + + def __init__( + self, + agg_func: Union[Function, str], + *nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine], + ): + self.agg_func = agg_func + self.nested_types = [to_instance(val) for val in 
nested_types] + super(SimpleAggregateFunction, self).__init__() + + def __repr__(self) -> str: + type_strs = [f'{val.__module__}.{val!r}' for val in self.nested_types] + + if isinstance(self.agg_func, str): + agg_str = self.agg_func + else: + agg_str = f'sa.func.{self.agg_func}' + + return f"SimpleAggregateFunction({agg_str}, {', '.join(type_strs)})" diff --git a/docs/features.rst b/docs/features.rst index 75206092..20dce7e8 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -46,7 +46,17 @@ Tables created in declarative way have lowercase with words separated by underscores naming convention. But you can easy set you own via SQLAlchemy ``__tablename__`` attribute. -SQLAlchemy ``func`` proxy for real ClickHouse functions can be also used. + +Functions ++++++++++ + +Many of the ClickHouse functions can be called using the SQLAlchemy ``func`` +proxy. A few of the aggregate functions require special handling though. The +following functions are supported: + +* ``func.quantile(0.5, column1)`` becomes ``quantile(0.5)(column1)`` +* ``func.quantileIf(0.5, column1, column2 > 10)`` becomes ``quantileIf(0.5)(column1, column2 > 10)`` + Dialect-specific options ++++++++++++++++++++++++ @@ -589,6 +599,26 @@ You can specify cluster for materialized view in inner table definition. {'clickhouse_cluster': 'my_cluster'} ) +Materialized views can also store the aggregated data in a table using the +``AggregatingMergeTree`` engine. The aggregate columns are defined using +``AggregateFunction`` or ``SimpleAggregateFunction``. + + .. 
code-block:: python + + + # Define storage for Materialized View + class GroupedStatistics(Base): + date = Column(types.Date, primary_key=True) + metric1 = Column(SimpleAggregateFunction(sa.func.sum(), types.Int32), nullable=False) + + __table_args__ = ( + engines.AggregatingMergeTree( + partition_by=func.toYYYYMM(date), + order_by=(date, ) + ), + ) + + Basic DDL support ----------------- diff --git a/tests/drivers/asynch/test_cursor.py b/tests/drivers/asynch/test_cursor.py index c2a723be..91b83cb3 100644 --- a/tests/drivers/asynch/test_cursor.py +++ b/tests/drivers/asynch/test_cursor.py @@ -2,11 +2,9 @@ from sqlalchemy.util.concurrency import greenlet_spawn from tests.testcase import AsynchSessionTestCase -from tests.util import run_async class CursorTestCase(AsynchSessionTestCase): - @run_async async def test_execute_without_context(self): raw = await self.session.bind.raw_connection() cur = await greenlet_spawn(lambda: raw.cursor()) @@ -20,7 +18,6 @@ async def test_execute_without_context(self): raw.close() - @run_async async def test_execute_with_context(self): rv = await self.session.execute( text('SELECT * FROM system.numbers LIMIT 1') @@ -28,10 +25,21 @@ async def test_execute_with_context(self): self.assertEqual(len(rv.fetchall()), 1) - @run_async async def test_check_iter_cursor(self): rv = await self.session.execute( text('SELECT number FROM system.numbers LIMIT 5') ) self.assertListEqual(list(rv), [(x,) for x in range(5)]) + + async def test_execute_with_stream(self): + async with self.connection.stream( + text("SELECT * FROM system.numbers LIMIT 10"), + execution_options={'max_block_size': 1} + ) as result: + idx = 0 + async for r in result: + self.assertEqual(r[0], idx) + idx += 1 + + self.assertEqual(idx, 10) diff --git a/tests/drivers/asynch/test_insert.py b/tests/drivers/asynch/test_insert.py index e69737f6..0b897031 100644 --- a/tests/drivers/asynch/test_insert.py +++ b/tests/drivers/asynch/test_insert.py @@ -4,11 +4,9 @@ from asynch.errors import 
TypeMismatchError from tests.testcase import AsynchSessionTestCase -from tests.util import run_async class NativeInsertTestCase(AsynchSessionTestCase): - @run_async async def test_rowcount_return1(self): metadata = self.metadata() table = Table( @@ -37,7 +35,6 @@ async def test_rowcount_return1(self): ) self.assertEqual(rv.rowcount, -1) - @run_async async def test_types_check(self): metadata = self.metadata() table = Table( diff --git a/tests/drivers/asynch/test_select.py b/tests/drivers/asynch/test_select.py index 86fb4151..b1536c65 100644 --- a/tests/drivers/asynch/test_select.py +++ b/tests/drivers/asynch/test_select.py @@ -3,12 +3,9 @@ from clickhouse_sqlalchemy import engines, types, Table from tests.session import mocked_engine from tests.testcase import AsynchSessionTestCase -from tests.util import run_async class SanityTestCase(AsynchSessionTestCase): - - @run_async async def test_sanity(self): with mocked_engine(self.session) as engine: metadata = self.metadata() diff --git a/tests/drivers/http/test_escaping.py b/tests/drivers/http/test_escaping.py index a456ecc9..ef3057e7 100644 --- a/tests/drivers/http/test_escaping.py +++ b/tests/drivers/http/test_escaping.py @@ -1,5 +1,6 @@ from decimal import Decimal from datetime import date +import uuid from sqlalchemy import Column, literal @@ -29,6 +30,10 @@ def test_escaper(self): self.assertEqual(e.escape([10.0]), '[10.0]') self.assertEqual(e.escape([date(2017, 1, 2)]), "['2017-01-02']") self.assertEqual(e.escape(dict(x=10, y=20)), {'x': 10, 'y': 20}) + self.assertEqual( + e.escape([uuid.UUID("ef3e3d4b-c782-4993-83fc-894ff0aba8ff")]), + '[ef3e3d4b-c782-4993-83fc-894ff0aba8ff]' + ) with self.assertRaises(Exception) as ex: e.escape([object()]) diff --git a/tests/drivers/native/test_cursor.py b/tests/drivers/native/test_cursor.py index e38274df..c4d84088 100644 --- a/tests/drivers/native/test_cursor.py +++ b/tests/drivers/native/test_cursor.py @@ -1,3 +1,5 @@ +import uuid + from sqlalchemy import text from 
tests.testcase import NativeSessionTestCase @@ -47,3 +49,14 @@ def test_with_settings_in_execution_options(self): dict(rv.context.execution_options), {"settings": {"final": 1}} ) self.assertEqual(len(rv.fetchall()), 1) + + def test_set_query_id(self): + query_id = str(uuid.uuid4()) + rv = self.session.execute( + text( + f"SELECT query_id " + f"FROM system.processes " + f"WHERE query_id = '{query_id}'" + ), execution_options={'query_id': query_id} + ) + self.assertEqual(rv.fetchall()[0][0], query_id) diff --git a/tests/drivers/test_clickhouse_dialect.py b/tests/drivers/test_clickhouse_dialect.py index 122b0ec4..f7c439e2 100644 --- a/tests/drivers/test_clickhouse_dialect.py +++ b/tests/drivers/test_clickhouse_dialect.py @@ -86,7 +86,7 @@ def test_get_view_names_with_schema(self): db_views = self.dialect.get_view_names(self.connection, test_database) self.assertNotIn(self.table.name, db_views) - def test_reflecttable(self): + def test_reflect_table(self): self.table.create(self.session.bind) meta = self.metadata() insp = inspect(self.session.bind) @@ -95,7 +95,7 @@ def test_reflecttable(self): self.assertEqual(self.table.name, reflected_table.name) - def test_reflecttable_with_schema(self): + def test_reflect_table_with_schema(self): # Imitates calling sequence for clients like Superset that look # across schemas. 
meta = self.metadata() @@ -146,9 +146,8 @@ class ClickHouseAsynchDialectTestCase(BaseAsynchTestCase): session = asynch_session - @run_async - async def setUp(self): - super().setUp() + def setUp(self): + super(ClickHouseAsynchDialectTestCase, self).setUp() self.test_metadata = self.metadata() self.table = Table( 'test_exists_table', @@ -156,12 +155,11 @@ async def setUp(self): Column('x', types.Int32, primary_key=True), engines.Memory() ) - await self.run_sync(self.test_metadata.drop_all) + run_async(self.connection.run_sync)(self.test_metadata.drop_all) - @run_async - async def tearDown(self): - await self.run_sync(self.test_metadata.drop_all) - super().tearDown() + def tearDown(self): + run_async(self.connection.run_sync)(self.test_metadata.drop_all) + super(ClickHouseAsynchDialectTestCase, self).tearDown() async def run_inspector_method(self, method, *args, **kwargs): def _run(conn): @@ -170,7 +168,6 @@ def _run(conn): return await self.run_sync(_run) - @run_async async def test_has_table(self): self.assertFalse( await self.run_inspector_method('has_table', self.table.name) @@ -182,7 +179,6 @@ async def test_has_table(self): await self.run_inspector_method('has_table', self.table.name) ) - @run_async async def test_has_table_with_schema(self): self.assertFalse( await self.run_inspector_method( @@ -199,7 +195,6 @@ async def test_has_table_with_schema(self): ) ) - @run_async async def test_get_table_names(self): await self.run_sync(self.test_metadata.create_all) @@ -207,7 +202,6 @@ async def test_get_table_names(self): self.assertIn(self.table.name, db_tables) - @run_async async def test_get_table_names_with_schema(self): await self.run_sync(self.test_metadata.create_all) @@ -218,7 +212,6 @@ async def test_get_table_names_with_schema(self): self.assertIn('columns', db_tables) - @run_async async def test_get_view_names(self): await self.run_sync(self.test_metadata.create_all) @@ -226,7 +219,6 @@ async def test_get_view_names(self): self.assertNotIn(self.table.name, 
db_views) - @run_async async def test_get_view_names_with_schema(self): await self.run_sync(self.test_metadata.create_all) @@ -237,8 +229,7 @@ async def test_get_view_names_with_schema(self): self.assertNotIn(self.table.name, db_views) - @run_async - async def test_reflecttable(self): + async def test_reflect_table(self): await self.run_sync(self.test_metadata.create_all) meta = self.metadata() @@ -247,8 +238,7 @@ async def test_reflecttable(self): self.assertEqual(self.table.name, reflected_table.name) - @run_async - async def test_reflecttable_with_schema(self): + async def test_reflect_table_with_schema(self): # Imitates calling sequence for clients like Superset that look # across schemas. meta = self.metadata() @@ -260,17 +250,15 @@ async def test_reflecttable_with_schema(self): if self.server_version >= (18, 16, 0): self.assertIsNone(reflected_table.engine) - @run_async async def test_get_schema_names(self): schemas = await self.run_inspector_method('get_schema_names') self.assertIn(test_database, schemas) - def test_columns_compilation(self): + async def test_columns_compilation(self): # should not raise UnsupportedCompilationError col = Column('x', types.Nullable(types.Int32)) self.assertEqual(str(col.type), 'Nullable(Int32)') - @run_async @require_server_version(19, 16, 2, is_async=True) async def test_empty_set_expr(self): numbers = Table( @@ -310,7 +298,8 @@ def test_server_version_http(self): def test_server_version_native(self): return self._test_server_version_any(system_native_uri) - @run_async + +class CachedServerVersionAsyncTestCase(BaseAsynchTestCase): async def test_server_version_asynch(self): engine_session = make_session(create_async_engine( system_asynch_uri, diff --git a/tests/drivers/test_util.py b/tests/drivers/test_util.py new file mode 100644 index 00000000..daab91a0 --- /dev/null +++ b/tests/drivers/test_util.py @@ -0,0 +1,50 @@ +from unittest import TestCase + +from clickhouse_sqlalchemy.drivers.util import get_inner_spec, 
parse_arguments + + +class GetInnerSpecTestCase(TestCase): + def test_get_inner_spec(self): + self.assertEqual( + get_inner_spec("DateTime('Europe/Paris')"), "'Europe/Paris'" + ) + self.assertEqual(get_inner_spec('Decimal(18, 2)'), "18, 2") + self.assertEqual(get_inner_spec('DateTime64(3)'), "3") + + +class ParseArgumentsTestCase(TestCase): + def test_parse_arguments(self): + self.assertEqual( + parse_arguments('uniq, UInt64'), ('uniq', 'UInt64') + ) + self.assertEqual( + parse_arguments('anyIf, String, UInt8'), + ('anyIf', 'String', 'UInt8') + ) + self.assertEqual( + parse_arguments('quantiles(0.5, 0.9), UInt64'), + ('quantiles(0.5, 0.9)', 'UInt64') + ) + self.assertEqual( + parse_arguments('sum, Int64, Int64'), ('sum', 'Int64', 'Int64') + ) + self.assertEqual( + parse_arguments('sum, Nullable(Int64), Int64'), + ('sum', 'Nullable(Int64)', 'Int64') + ) + self.assertEqual( + parse_arguments('Float32, Decimal(18, 2)'), + ('Float32', 'Decimal(18, 2)') + ) + self.assertEqual( + parse_arguments('sum, Float32, Decimal(18, 2)'), + ('sum', 'Float32', 'Decimal(18, 2)') + ) + self.assertEqual( + parse_arguments('quantiles(0.5, 0.9), UInt64'), + ('quantiles(0.5, 0.9)', 'UInt64') + ) + self.assertEqual( + parse_arguments("sumIf(total, status = 'accepted'), Float32"), + ("sumIf(total, status = 'accepted')", "Float32") + ) diff --git a/tests/sql/test_functions.py b/tests/sql/test_functions.py new file mode 100644 index 00000000..bba18679 --- /dev/null +++ b/tests/sql/test_functions.py @@ -0,0 +1,33 @@ +from sqlalchemy import Column, func + +from clickhouse_sqlalchemy import types, Table + +from tests.testcase import CompilationTestCase + + +class FunctionTestCase(CompilationTestCase): + table = Table( + 't1', CompilationTestCase.metadata(), + Column('x', types.Int32, primary_key=True), + Column('time', types.DateTime) + ) + + def test_quantile(self): + func0 = func.quantile(0.5, self.table.c.x) + self.assertIsInstance(func0.type, types.Float64) + func1 = func.quantile(0.5, 
self.table.c.time) + self.assertIsInstance(func1.type, types.DateTime) + self.assertEqual( + self.compile(self.session.query(func0)), + 'SELECT quantile(0.5)(t1.x) AS quantile_1 FROM t1' + ) + + func2 = func.quantileIf(0.5, self.table.c.x, self.table.c.x > 10) + + self.assertEqual( + self.compile( + self.session.query(func2) + ), + 'SELECT quantileIf(0.5)(t1.x, t1.x > %(x_1)s) AS ' + + '"quantileIf_1" FROM t1' + ) diff --git a/tests/sql/test_selectable.py b/tests/sql/test_selectable.py index 1fa3fd75..975853d5 100644 --- a/tests/sql/test_selectable.py +++ b/tests/sql/test_selectable.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, and_ +from sqlalchemy import Column, and_, func from sqlalchemy.exc import CompileError from sqlalchemy.sql import expression @@ -585,3 +585,49 @@ def test_distinct_on(self): self.compile(query), 'SELECT DISTINCT ON (t1.x) t1.x FROM t1' ) + + def test_map_type(self): + table = self._make_table( + 't1', + Column('x', types.Int32, primary_key=True), + Column('y', types.Map(types.String, types.String)) + ) + + query = select( + func.mapKeys(table.c.y), func.mapValues(table.c.y) + ).where( + func.has(table.c.y, 'foo') + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT mapKeys(t1.y) AS "mapKeys_1", ' + 'mapValues(t1.y) AS "mapValues_1" ' + 'FROM t1 ' + 'WHERE has(t1.y, \'foo\')' + ) + + def test_nested_map_type(self): + table = self._make_table( + 't1', + Column('x', types.Int32, primary_key=True), + Column( + 'y', + types.Map( + types.String, + types.Map(types.String, types.String) + ) + ) + ) + + query = select( + func.mapKeys(table.c.y), func.mapValues(table.c.y) + ).where( + func.has(table.c.y, 'foo') + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT mapKeys(t1.y) AS "mapKeys_1", ' + 'mapValues(t1.y) AS "mapValues_1" ' + 'FROM t1 ' + 'WHERE has(t1.y, \'foo\')' + ) diff --git a/tests/test_ddl.py b/tests/test_ddl.py index 10f8b932..dfd9f549 100644 --- a/tests/test_ddl.py +++ 
b/tests/test_ddl.py @@ -296,6 +296,26 @@ def test_create_table_tuple(self): 'ENGINE = Memory' ) + def test_create_table_named_tuple(self): + table = Table( + 't1', self.metadata(), + Column( + 'x', + types.Tuple( + ('name', types.String), + ('value', types.Float32) + ) + ), + engines.Memory() + ) + + self.assertEqual( + self.compile(CreateTable(table)), + 'CREATE TABLE t1 (' + 'x Tuple(name String, value Float32)) ' + 'ENGINE = Memory' + ) + @require_server_version(21, 1, 3) def test_create_table_map(self): table = Table( @@ -311,6 +331,39 @@ def test_create_table_map(self): 'ENGINE = Memory' ) + def test_create_aggregate_function(self): + table = Table( + 't1', self.metadata(), + Column('total', types.AggregateFunction(func.sum(), types.UInt32)), + engines.Memory() + ) + + self.assertEqual( + self.compile(CreateTable(table)), + 'CREATE TABLE t1 (' + 'total AggregateFunction(sum(), UInt32)) ' + 'ENGINE = Memory' + ) + + @require_server_version(22, 8, 21) + def test_create_simple_aggregate_function(self): + table = Table( + 't1', self.metadata(), + Column( + 'total', types.SimpleAggregateFunction( + func.sum(), types.UInt32 + ) + ), + engines.Memory() + ) + + self.assertEqual( + self.compile(CreateTable(table)), + 'CREATE TABLE t1 (' + 'total SimpleAggregateFunction(sum(), UInt32)) ' + 'ENGINE = Memory' + ) + def test_table_create_on_cluster(self): create_sql = ( 'CREATE TABLE t1 ON CLUSTER test_cluster ' diff --git a/tests/test_reflection.py b/tests/test_reflection.py index 4c5a3f41..50ddbb90 100644 --- a/tests/test_reflection.py +++ b/tests/test_reflection.py @@ -1,7 +1,8 @@ import enum -from sqlalchemy import Column, inspect, types as sa_types +from sqlalchemy import Column, func, inspect, types as sa_types from clickhouse_sqlalchemy import types, engines, Table + from tests.testcase import BaseTestCase from tests.util import require_server_version, with_native_and_http_sessions @@ -166,3 +167,65 @@ def test_datetime(self): self.assertIsInstance(coltype, 
types.DateTime) self.assertIsNone(coltype.timezone) + + def test_aggregate_function(self): + coltype = self._type_round_trip( + types.AggregateFunction(func.sum(), types.UInt16) + )[0]['type'] + + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.UInt16) + + coltype = self._type_round_trip( + types.AggregateFunction('quantiles(0.5, 0.9)', types.UInt32) + )[0]['type'] + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'quantiles(0.5, 0.9)') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.UInt32) + + coltype = self._type_round_trip( + types.AggregateFunction( + func.argMin(), types.Float32, types.Float32 + ) + )[0]['type'] + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'argMin') + self.assertEqual(len(coltype.nested_types), 2) + self.assertIsInstance(coltype.nested_types[0], types.Float32) + self.assertIsInstance(coltype.nested_types[1], types.Float32) + + coltype = self._type_round_trip( + types.AggregateFunction( + 'sum', types.Decimal(18, 2) + ) + )[0]['type'] + self.assertIsInstance(coltype, types.AggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.Decimal) + self.assertEqual(coltype.nested_types[0].precision, 18) + self.assertEqual(coltype.nested_types[0].scale, 2) + + @require_server_version(22, 8, 21) + def test_simple_aggregate_function(self): + coltype = self._type_round_trip( + types.SimpleAggregateFunction(func.sum(), types.UInt64) + )[0]['type'] + + self.assertIsInstance(coltype, types.SimpleAggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + 
self.assertIsInstance(coltype.nested_types[0], types.UInt64) + + coltype = self._type_round_trip( + types.SimpleAggregateFunction( + 'sum', types.Float64 + ) + )[0]['type'] + self.assertIsInstance(coltype, types.SimpleAggregateFunction) + self.assertEqual(coltype.agg_func, 'sum') + self.assertEqual(len(coltype.nested_types), 1) + self.assertIsInstance(coltype.nested_types[0], types.Float64) diff --git a/tests/testcase.py b/tests/testcase.py index f44c9025..0d571aa8 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -1,5 +1,5 @@ import re -from contextlib import contextmanager, asynccontextmanager +from contextlib import contextmanager from unittest import TestCase from sqlalchemy import MetaData, text @@ -97,46 +97,29 @@ class BaseAsynchTestCase(BaseTestCase): session = asynch_session @classmethod - @run_async - async def setUpClass(cls): + def setUpClass(cls): # System database is always present. - await system_asynch_session.execute( + run_async(system_asynch_session.execute)( text('DROP DATABASE IF EXISTS {}'.format(cls.database)) ) - await system_asynch_session.execute( + run_async(system_asynch_session.execute)( text('CREATE DATABASE {}'.format(cls.database)) ) version = ( - await system_asynch_session.execute(text('SELECT version()')) + run_async(system_asynch_session.execute)(text('SELECT version()')) ).fetchall() cls.server_version = tuple(int(x) for x in version[0][0].split('.')) - super(BaseTestCase, cls).setUpClass() - - @asynccontextmanager - async def create_table(self, table): - metadata = self.metadata() - await self.run_sync(metadata.drop_all) - await self.run_sync(metadata.create_all) - - try: - yield - finally: - await self.run_sync(metadata.drop_all) + def setUp(self): + self.connection = run_async(self.session.connection)() + super(BaseAsynchTestCase, self).setUp() - async def get_connection(self): - return await self.session.connection() + def _callTestMethod(self, method): + return run_async(method)() async def run_sync(self, f): - conn 
= await self.get_connection() - return await conn.run_sync(f) - - async def session_scalar(self, statement): - def wrapper(session): - return session.query(statement).scalar() - - return await self.session.run_sync(wrapper) + return await self.connection.run_sync(f) class HttpSessionTestCase(BaseTestCase): diff --git a/tests/types/test_date32.py b/tests/types/test_date32.py new file mode 100644 index 00000000..b001394a --- /dev/null +++ b/tests/types/test_date32.py @@ -0,0 +1,46 @@ +import datetime + +from sqlalchemy import Column +from sqlalchemy.sql.ddl import CreateTable + +from clickhouse_sqlalchemy import types, engines, Table +from tests.testcase import BaseTestCase, CompilationTestCase +from tests.util import with_native_and_http_sessions + + +class Date32CompilationTestCase(CompilationTestCase): + required_server_version = (21, 9, 0) + + def test_create_table(self): + table = Table( + 'test', CompilationTestCase.metadata(), + Column('x', types.Date32, primary_key=True), + engines.Memory() + ) + + self.assertEqual( + self.compile(CreateTable(table)), + 'CREATE TABLE test (x Date32) ENGINE = Memory' + ) + + +@with_native_and_http_sessions +class Date32TestCase(BaseTestCase): + required_server_version = (21, 9, 0) + + table = Table( + 'test', BaseTestCase.metadata(), + Column('x', types.Date32, primary_key=True), + engines.Memory() + ) + + def test_select_insert(self): + # Use a date before epoch to validate dates before epoch can be stored. 
+ date = datetime.date(1925, 1, 1) + with self.create_table(self.table): + self.session.execute(self.table.insert(), [{'x': date}]) + result = self.session.execute(self.table.select()).scalar() + if isinstance(result, datetime.date): + self.assertEqual(result, date) + else: + self.assertEqual(result, date.isoformat()) diff --git a/tests/types/test_json.py b/tests/types/test_json.py new file mode 100644 index 00000000..b2d03858 --- /dev/null +++ b/tests/types/test_json.py @@ -0,0 +1,61 @@ +import json +from sqlalchemy import Column, text, inspect, func +from sqlalchemy.sql.ddl import CreateTable + +from clickhouse_sqlalchemy import types, engines, Table +from tests.testcase import BaseTestCase, CompilationTestCase +from tests.util import class_name_func +from parameterized import parameterized_class +from tests.session import native_session + + +class JSONCompilationTestCase(CompilationTestCase): + def test_create_table(self): + table = Table( + 'test', CompilationTestCase.metadata(), + Column('x', types.JSON), + engines.Memory() + ) + + self.assertEqual( + self.compile(CreateTable(table)), + 'CREATE TABLE test (x JSON) ENGINE = Memory' + ) + + +@parameterized_class( + [{'session': native_session}], + class_name_func=class_name_func +) +class JSONTestCase(BaseTestCase): + required_server_version = (22, 6, 1) + + table = Table( + 'test', BaseTestCase.metadata(), + Column('x', types.JSON), + engines.Memory() + ) + + def test_select_insert(self): + data = {'k1': 1, 'k2': '2', 'k3': True} + + self.table.drop(bind=self.session.bind, if_exists=True) + try: + # http session is unsupported + self.session.execute( + text('SET allow_experimental_object_type = 1;') + ) + self.session.execute(text(self.compile(CreateTable(self.table)))) + self.session.execute(self.table.insert(), [{'x': data}]) + coltype = inspect(self.session.bind).get_columns('test')[0]['type'] + self.assertIsInstance(coltype, types.JSON) + # 
https://clickhouse.com/docs/en/sql-reference/functions/json-functions#tojsonstring + # The json type returns a tuple of values by default, + # which needs to be converted to json using the + # toJSONString function. + res = self.session.query( + func.toJSONString(self.table.c.x) + ).scalar() + self.assertEqual(json.loads(res), data) + finally: + self.table.drop(bind=self.session.bind, if_exists=True) diff --git a/testsrequire.py b/testsrequire.py index baa9d53a..10c1ffe9 100644 --- a/testsrequire.py +++ b/testsrequire.py @@ -1,6 +1,7 @@ tests_require = [ 'pytest', + 'pytest-asyncio', 'sqlalchemy>=2.0.0,<2.1.0', 'greenlet>=2.0.1', 'alembic',