Skip to content

Commit

Permalink
Use set rather than list for hard-coded agg logic
Browse files Browse the repository at this point in the history
  • Loading branch information
alancai98 committed May 29, 2024
1 parent 5d1a779 commit 3819b99
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ private class Ctx
private class AstTranslator(val metas: Map<String, MetaContainer>) : AstBaseVisitor<PartiqlAst.PartiqlAstNode, Ctx>() {

private val pig = PartiqlAst.BUILDER()
// Currently hard-coded in legacy code
private val aggregates = setOf("count", "avg", "sum", "min", "max", "any", "some", "every")

override fun defaultReturn(node: AstNode, ctx: Ctx): Nothing {
val fromClass = node::class.qualifiedName
Expand Down Expand Up @@ -331,7 +333,7 @@ private class AstTranslator(val metas: Map<String, MetaContainer>) : AstBaseVisi
private fun String.isAggregateCall(): Boolean {
// like PartiQLPigVisitor, keep legacy behavior the same as before
// since it is legacy, hard-coded aggregation logic
return listOf("count", "avg", "sum", "min", "max", "any", "some", "every").contains(this)
return aggregates.contains(this)
}

override fun visitExprUnary(node: Expr.Unary, ctx: Ctx) = translate(node) { metas ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,9 @@ internal object RelConverter {
* Rewrites a SELECT node replacing (and extracting) each aggregation `i` with a synthetic field name `$agg_i`.
*/
private object AggregationTransform : AstRewriter<AggregationTransform.Context>() {
// currently hard-coded
@JvmStatic
private val aggregates = setOf("count", "avg", "sum", "min", "max", "any", "some", "every")

private data class Context(
val aggregations: MutableList<Expr.Call>,
Expand Down Expand Up @@ -634,8 +637,7 @@ internal object RelConverter {
}

private fun String.isAggregateCall(): Boolean {
// currently hard-coded
return listOf("count", "avg", "sum", "min", "max", "any", "some", "every").contains(this)
return aggregates.contains(this)
}

private fun Identifier.isAggregateCall(): Boolean {
Expand Down

0 comments on commit 3819b99

Please sign in to comment.