diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt index fd7fdb2a90..2549ea6d5c 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt @@ -74,6 +74,8 @@ private class Ctx private class AstTranslator(val metas: Map) : AstBaseVisitor() { private val pig = PartiqlAst.BUILDER() + // Currently hard-coded in legacy code + private val aggregates = setOf("count", "avg", "sum", "min", "max", "any", "some", "every") override fun defaultReturn(node: AstNode, ctx: Ctx): Nothing { val fromClass = node::class.qualifiedName @@ -331,7 +333,7 @@ private class AstTranslator(val metas: Map) : AstBaseVisi private fun String.isAggregateCall(): Boolean { // like PartiQLPigVisitor, keep legacy behavior the same as before // since it is legacy, hard-coded aggregation logic - return listOf("count", "avg", "sum", "min", "max", "any", "some", "every").contains(this) + return aggregates.contains(this) } override fun visitExprUnary(node: Expr.Unary, ctx: Ctx) = translate(node) { metas -> diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt index f45fef89ca..5d82f4acbc 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt @@ -592,6 +592,9 @@ internal object RelConverter { * Rewrites a SELECT node replacing (and extracting) each aggregation `i` with a synthetic field name `$agg_i`. */ private object AggregationTransform : AstRewriter() { + // currently hard-coded + @JvmStatic + private val aggregates = setOf("count", "avg", "sum", "min", "max", "any", "some", "every") private data class Context( val aggregations: MutableList, @@ -634,8 +637,7 @@ internal object RelConverter { } private fun String.isAggregateCall(): Boolean { - // currently hard-coded - return listOf("count", "avg", "sum", "min", "max", "any", "some", "every").contains(this) + return aggregates.contains(this) } private fun Identifier.isAggregateCall(): Boolean {