From 407bbc32a2e4b1c8c8d47caedce884d56ee4ca71 Mon Sep 17 00:00:00 2001 From: Aron Bierbaum Date: Mon, 18 Mar 2024 15:35:36 -0500 Subject: [PATCH] Fix custom select statement https://github.com/xzkostyan/clickhouse-sqlalchemy/pull/233 --- .../drivers/compilers/sqlcompiler.py | 28 +++++++++++++++++++ clickhouse_sqlalchemy/orm/query.py | 16 +++++++++++ 2 files changed, 44 insertions(+) diff --git a/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py b/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py index 45c8fe49..9d385db3 100644 --- a/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py +++ b/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py @@ -10,6 +10,34 @@ class ClickHouseSQLCompiler(compiler.SQLCompiler): + CUSTOM_SELECT_ATTRS = [ + '_with_cube', '_with_rollup', '_with_totals', '_final_clause', + '_sample_clause', '_limit_by_clause', '_array_join' + ] + + def visit_select( + self, + select_stmt, + **kwargs, + ): + orig_compile_state_factory = select_stmt._compile_state_factory + + def compile_state_factory(self, *args, **kwargs): + result = orig_compile_state_factory(self, *args, **kwargs) + + if hasattr(result, 'select_statement'): + # Fix missed attributes + for attr in ClickHouseSQLCompiler.CUSTOM_SELECT_ATTRS: + val = getattr(result.select_statement, attr, None) + + if val is not None: + setattr(result.statement, attr, val) + + return result + + select_stmt._compile_state_factory = compile_state_factory + return super().visit_select(select_stmt=select_stmt, **kwargs) + def visit_mod_binary(self, binary, operator, **kw): return self.process(binary.left, **kw) + ' %% ' + \ self.process(binary.right, **kw) diff --git a/clickhouse_sqlalchemy/orm/query.py b/clickhouse_sqlalchemy/orm/query.py index 1b72e947..f1a6d5ea 100644 --- a/clickhouse_sqlalchemy/orm/query.py +++ b/clickhouse_sqlalchemy/orm/query.py @@ -19,6 +19,22 @@ class Query(BaseQuery): _limit_by = None _array_join = None + def _statement_20(self, for_statement=False, use_legacy_query_style=True): + orig_smt = super(Query, self)._statement_20( + for_statement=for_statement, + use_legacy_query_style=use_legacy_query_style + ) + + orig_smt._with_cube = self._with_cube + orig_smt._with_rollup = self._with_rollup + orig_smt._with_totals = self._with_totals + orig_smt._final_clause = self._final + orig_smt._sample_clause = sample_clause(self._sample) + orig_smt._limit_by_clause = self._limit_by + orig_smt._array_join = self._array_join + + return orig_smt + def _compile_context(self, *args, **kwargs): context = super(Query, self)._compile_context(*args, **kwargs) query = context.query