Skip to content

Commit

Permalink
Fix custom select statement
Browse files Browse the repository at this point in the history
  • Loading branch information
aronbierbaum committed Mar 22, 2024
1 parent d332ccb commit 5f6ea05
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
28 changes: 28 additions & 0 deletions clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,34 @@


class ClickHouseSQLCompiler(compiler.SQLCompiler):
CUSTOM_SELECT_ATTRS = [
'_with_cube', '_with_rollup', '_with_totals', '_final_clause',
'_sample_clause', '_limit_by_clause', '_array_join'
]

def visit_select(
self,
select_stmt,
**kwargs,
):
orig_compile_state_factory = select_stmt._compile_state_factory

def compile_state_factory(self, *args, **kwargs):
result = orig_compile_state_factory(self, *args, **kwargs)

if hasattr(result, 'select_statement'):
# Fix missed attributes
for attr in ClickHouseSQLCompiler.CUSTOM_SELECT_ATTRS:
val = getattr(result.select_statement, attr, None)

if val is not None:
setattr(result.statement, attr, val)

return result

select_stmt._compile_state_factory = compile_state_factory
return super().visit_select(select_stmt=select_stmt, **kwargs)

def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + ' %% ' + \
self.process(binary.right, **kw)
Expand Down
16 changes: 16 additions & 0 deletions clickhouse_sqlalchemy/orm/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ class Query(BaseQuery):
_limit_by = None
_array_join = None

def _statement_20(self, for_statement=False, use_legacy_query_style=True):
orig_smt = super(Query, self)._statement_20(
for_statement=for_statement,
use_legacy_query_style=use_legacy_query_style
)

orig_smt._with_cube = self._with_cube
orig_smt._with_rollup = self._with_rollup
orig_smt._with_totals = self._with_totals
orig_smt._final_clause = self._final
orig_smt._sample_clause = sample_clause(self._sample)
orig_smt._limit_by_clause = self._limit_by
orig_smt._array_join = self._array_join

return orig_smt

def _compile_context(self, *args, **kwargs):
context = super(Query, self)._compile_context(*args, **kwargs)
query = context.query
Expand Down

0 comments on commit 5f6ea05

Please sign in to comment.