From 5d1a77925cb01ecd23c3658459d877c910894d60 Mon Sep 17 00:00:00 2001 From: Alan Cai Date: Thu, 16 May 2024 11:08:40 -0700 Subject: [PATCH] Remove hard-coded aggregations from parser and ast --- partiql-ast/api/partiql-ast.api | 69 ++++--------------- .../org/partiql/ast/helpers/ToLegacyAst.kt | 39 ++++++----- .../kotlin/org/partiql/ast/sql/SqlDialect.kt | 11 +-- .../ast/sql/internal/InternalSqlDialect.kt | 13 +--- .../src/main/resources/partiql_ast.ion | 10 +-- .../partiql/ast/helpers/ToLegacyAstTest.kt | 11 +-- .../org/partiql/ast/sql/SqlDialectTest.kt | 16 ++--- .../lang/syntax/impl/PartiQLPigVisitor.kt | 56 +++++++++------ partiql-parser/src/main/antlr/PartiQL.g4 | 11 ++- .../parser/internal/PartiQLParserDefault.kt | 26 ++----- .../PartiQLParserFunctionCallTests.kt | 24 ++++--- .../org/partiql/planner/internal/Env.kt | 2 + .../planner/internal/PathResolverAgg.kt | 13 ---- .../internal/transforms/NormalizeSelect.kt | 6 +- .../internal/transforms/RelConverter.kt | 57 +++++++++++---- partiql-spi/api/partiql-spi.api | 10 +-- .../partiql/spi/connector/sql/SqlBuiltins.kt | 3 +- .../spi/connector/sql/builtins/AggCount.kt | 2 +- .../connector/sql/builtins/AggCountStar.kt | 25 ------- 19 files changed, 169 insertions(+), 235 deletions(-) delete mode 100644 partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt diff --git a/partiql-ast/api/partiql-ast.api b/partiql-ast/api/partiql-ast.api index 0c2783a126..688c53ecf2 100644 --- a/partiql-ast/api/partiql-ast.api +++ b/partiql-ast/api/partiql-ast.api @@ -14,11 +14,10 @@ public final class org/partiql/ast/Ast { public static final fun excludeStepCollWildcard ()Lorg/partiql/ast/Exclude$Step$CollWildcard; public static final fun excludeStepStructField (Lorg/partiql/ast/Identifier$Symbol;)Lorg/partiql/ast/Exclude$Step$StructField; public static final fun excludeStepStructWildcard ()Lorg/partiql/ast/Exclude$Step$StructWildcard; - public static final fun exprAgg (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/Expr$Agg; public static final fun exprBagOp (Lorg/partiql/ast/SetOp;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Ljava/lang/Boolean;)Lorg/partiql/ast/Expr$BagOp; public static final fun exprBetween (Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Ljava/lang/Boolean;)Lorg/partiql/ast/Expr$Between; public static final fun exprBinary (Lorg/partiql/ast/Expr$Binary$Op;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;)Lorg/partiql/ast/Expr$Binary; - public static final fun exprCall (Lorg/partiql/ast/Identifier;Ljava/util/List;)Lorg/partiql/ast/Expr$Call; + public static final fun exprCall (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/Expr$Call; public static final fun exprCanCast (Lorg/partiql/ast/Expr;Lorg/partiql/ast/Type;)Lorg/partiql/ast/Expr$CanCast; public static final fun exprCanLosslessCast (Lorg/partiql/ast/Expr;Lorg/partiql/ast/Type;)Lorg/partiql/ast/Expr$CanLosslessCast; public static final fun exprCase (Lorg/partiql/ast/Expr;Ljava/util/List;Lorg/partiql/ast/Expr;)Lorg/partiql/ast/Expr$Case; @@ -516,29 +515,6 @@ public abstract class org/partiql/ast/Expr : org/partiql/ast/AstNode { public fun accept (Lorg/partiql/ast/visitor/AstVisitor;Ljava/lang/Object;)Ljava/lang/Object; } -public final class org/partiql/ast/Expr$Agg : org/partiql/ast/Expr { - public static final field Companion Lorg/partiql/ast/Expr$Agg$Companion; - public final field args Ljava/util/List; - public final field function Lorg/partiql/ast/Identifier; - public final field setq Lorg/partiql/ast/SetQuantifier; - public fun (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;)V - public fun accept (Lorg/partiql/ast/visitor/AstVisitor;Ljava/lang/Object;)Ljava/lang/Object; - public static final fun builder ()Lorg/partiql/ast/builder/ExprAggBuilder; - public final fun component1 ()Lorg/partiql/ast/Identifier; - public final fun component2 ()Ljava/util/List; - public final fun component3 ()Lorg/partiql/ast/SetQuantifier; - public final fun copy (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/Expr$Agg; - public static synthetic fun copy$default (Lorg/partiql/ast/Expr$Agg;Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;ILjava/lang/Object;)Lorg/partiql/ast/Expr$Agg; - public fun equals (Ljava/lang/Object;)Z - public fun getChildren ()Ljava/util/List; - public fun hashCode ()I - public fun toString ()Ljava/lang/String; -} - -public final class org/partiql/ast/Expr$Agg$Companion { - public final fun builder ()Lorg/partiql/ast/builder/ExprAggBuilder; -} - public final class org/partiql/ast/Expr$BagOp : org/partiql/ast/Expr { public static final field Companion Lorg/partiql/ast/Expr$BagOp$Companion; public final field lhs Lorg/partiql/ast/Expr; @@ -636,13 +612,15 @@ public final class org/partiql/ast/Expr$Call : org/partiql/ast/Expr { public static final field Companion Lorg/partiql/ast/Expr$Call$Companion; public final field args Ljava/util/List; public final field function Lorg/partiql/ast/Identifier; - public fun (Lorg/partiql/ast/Identifier;Ljava/util/List;)V + public final field setq Lorg/partiql/ast/SetQuantifier; + public fun (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;)V public fun accept (Lorg/partiql/ast/visitor/AstVisitor;Ljava/lang/Object;)Ljava/lang/Object; public static final fun builder ()Lorg/partiql/ast/builder/ExprCallBuilder; public final fun component1 ()Lorg/partiql/ast/Identifier; public final fun component2 ()Ljava/util/List; - public final fun copy (Lorg/partiql/ast/Identifier;Ljava/util/List;)Lorg/partiql/ast/Expr$Call; - public static synthetic fun copy$default (Lorg/partiql/ast/Expr$Call;Lorg/partiql/ast/Identifier;Ljava/util/List;ILjava/lang/Object;)Lorg/partiql/ast/Expr$Call; + public final fun component3 ()Lorg/partiql/ast/SetQuantifier; + public final fun copy (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/Expr$Call; + public static synthetic fun copy$default (Lorg/partiql/ast/Expr$Call;Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;ILjava/lang/Object;)Lorg/partiql/ast/Expr$Call; public fun equals (Ljava/lang/Object;)Z public fun getChildren ()Ljava/util/List; public fun hashCode ()I @@ -4028,16 +4006,14 @@ public final class org/partiql/ast/builder/AstBuilder { public static synthetic fun excludeStepStructField$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Identifier$Symbol;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Exclude$Step$StructField; public final fun excludeStepStructWildcard (Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Exclude$Step$StructWildcard; public static synthetic fun excludeStepStructWildcard$default (Lorg/partiql/ast/builder/AstBuilder;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Exclude$Step$StructWildcard; - public final fun exprAgg (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Expr$Agg; - public static synthetic fun exprAgg$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Expr$Agg; public final fun exprBagOp (Lorg/partiql/ast/SetOp;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Expr$BagOp; public static synthetic fun exprBagOp$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/SetOp;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Expr$BagOp; public final fun exprBetween (Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Expr$Between; public static synthetic fun exprBetween$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Expr$Between; public final fun exprBinary (Lorg/partiql/ast/Expr$Binary$Op;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Expr$Binary; public static synthetic fun exprBinary$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Expr$Binary$Op;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Expr$Binary; - public final fun exprCall (Lorg/partiql/ast/Identifier;Ljava/util/List;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Expr$Call; - public static synthetic fun exprCall$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Identifier;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Expr$Call; + public final fun exprCall (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Expr$Call; + public static synthetic fun exprCall$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Expr$Call; public final fun exprCanCast (Lorg/partiql/ast/Expr;Lorg/partiql/ast/Type;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Expr$CanCast; public static synthetic fun exprCanCast$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Type;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Expr$CanCast; public final fun exprCanLosslessCast (Lorg/partiql/ast/Expr;Lorg/partiql/ast/Type;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Expr$CanLosslessCast; @@ -4497,22 +4473,6 @@ public final class org/partiql/ast/builder/ExcludeStepStructWildcardBuilder { public final fun build ()Lorg/partiql/ast/Exclude$Step$StructWildcard; } -public final class org/partiql/ast/builder/ExprAggBuilder { - public fun ()V - public fun (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;)V - public synthetic fun (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun args (Ljava/util/List;)Lorg/partiql/ast/builder/ExprAggBuilder; - public final fun build ()Lorg/partiql/ast/Expr$Agg; - public final fun function (Lorg/partiql/ast/Identifier;)Lorg/partiql/ast/builder/ExprAggBuilder; - public final fun getArgs ()Ljava/util/List; - public final fun getFunction ()Lorg/partiql/ast/Identifier; - public final fun getSetq ()Lorg/partiql/ast/SetQuantifier; - public final fun setArgs (Ljava/util/List;)V - public final fun setFunction (Lorg/partiql/ast/Identifier;)V - public final fun setSetq (Lorg/partiql/ast/SetQuantifier;)V - public final fun setq (Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/builder/ExprAggBuilder; -} - public final class org/partiql/ast/builder/ExprBagOpBuilder { public fun ()V public fun (Lorg/partiql/ast/SetOp;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Expr;Ljava/lang/Boolean;)V @@ -4569,15 +4529,18 @@ public final class org/partiql/ast/builder/ExprBinaryBuilder { public final class org/partiql/ast/builder/ExprCallBuilder { public fun ()V - public fun (Lorg/partiql/ast/Identifier;Ljava/util/List;)V - public synthetic fun (Lorg/partiql/ast/Identifier;Ljava/util/List;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;)V + public synthetic fun (Lorg/partiql/ast/Identifier;Ljava/util/List;Lorg/partiql/ast/SetQuantifier;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun args (Ljava/util/List;)Lorg/partiql/ast/builder/ExprCallBuilder; public final fun build ()Lorg/partiql/ast/Expr$Call; public final fun function (Lorg/partiql/ast/Identifier;)Lorg/partiql/ast/builder/ExprCallBuilder; public final fun getArgs ()Ljava/util/List; public final fun getFunction ()Lorg/partiql/ast/Identifier; + public final fun getSetq ()Lorg/partiql/ast/SetQuantifier; public final fun setArgs (Ljava/util/List;)V public final fun setFunction (Lorg/partiql/ast/Identifier;)V + public final fun setSetq (Lorg/partiql/ast/SetQuantifier;)V + public final fun setq (Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/builder/ExprCallBuilder; } public final class org/partiql/ast/builder/ExprCanCastBuilder { @@ -6406,8 +6369,6 @@ public abstract class org/partiql/ast/sql/SqlDialect : org/partiql/ast/visitor/A public fun visitExcludeStepStructField (Lorg/partiql/ast/Exclude$Step$StructField;Lorg/partiql/ast/sql/SqlBlock;)Lorg/partiql/ast/sql/SqlBlock; public synthetic fun visitExcludeStepStructWildcard (Lorg/partiql/ast/Exclude$Step$StructWildcard;Ljava/lang/Object;)Ljava/lang/Object; public fun visitExcludeStepStructWildcard (Lorg/partiql/ast/Exclude$Step$StructWildcard;Lorg/partiql/ast/sql/SqlBlock;)Lorg/partiql/ast/sql/SqlBlock; - public synthetic fun visitExprAgg (Lorg/partiql/ast/Expr$Agg;Ljava/lang/Object;)Ljava/lang/Object; - public fun visitExprAgg (Lorg/partiql/ast/Expr$Agg;Lorg/partiql/ast/sql/SqlBlock;)Lorg/partiql/ast/sql/SqlBlock; public synthetic fun visitExprBagOp (Lorg/partiql/ast/Expr$BagOp;Ljava/lang/Object;)Ljava/lang/Object; public fun visitExprBagOp (Lorg/partiql/ast/Expr$BagOp;Lorg/partiql/ast/sql/SqlBlock;)Lorg/partiql/ast/sql/SqlBlock; public synthetic fun visitExprBetween (Lorg/partiql/ast/Expr$Between;Ljava/lang/Object;)Ljava/lang/Object; @@ -6681,8 +6642,6 @@ public abstract class org/partiql/ast/util/AstRewriter : org/partiql/ast/visitor public fun visitExcludeStepStructField (Lorg/partiql/ast/Exclude$Step$StructField;Ljava/lang/Object;)Lorg/partiql/ast/AstNode; public synthetic fun visitExcludeStepStructWildcard (Lorg/partiql/ast/Exclude$Step$StructWildcard;Ljava/lang/Object;)Ljava/lang/Object; public fun visitExcludeStepStructWildcard (Lorg/partiql/ast/Exclude$Step$StructWildcard;Ljava/lang/Object;)Lorg/partiql/ast/AstNode; - public synthetic fun visitExprAgg (Lorg/partiql/ast/Expr$Agg;Ljava/lang/Object;)Ljava/lang/Object; - public fun visitExprAgg (Lorg/partiql/ast/Expr$Agg;Ljava/lang/Object;)Lorg/partiql/ast/AstNode; public synthetic fun visitExprBagOp (Lorg/partiql/ast/Expr$BagOp;Ljava/lang/Object;)Ljava/lang/Object; public fun visitExprBagOp (Lorg/partiql/ast/Expr$BagOp;Ljava/lang/Object;)Lorg/partiql/ast/AstNode; public synthetic fun visitExprBetween (Lorg/partiql/ast/Expr$Between;Ljava/lang/Object;)Ljava/lang/Object; @@ -7013,7 +6972,6 @@ public abstract class org/partiql/ast/visitor/AstBaseVisitor : org/partiql/ast/v public fun visitExcludeStepStructField (Lorg/partiql/ast/Exclude$Step$StructField;Ljava/lang/Object;)Ljava/lang/Object; public fun visitExcludeStepStructWildcard (Lorg/partiql/ast/Exclude$Step$StructWildcard;Ljava/lang/Object;)Ljava/lang/Object; public fun visitExpr (Lorg/partiql/ast/Expr;Ljava/lang/Object;)Ljava/lang/Object; - public fun visitExprAgg (Lorg/partiql/ast/Expr$Agg;Ljava/lang/Object;)Ljava/lang/Object; public fun visitExprBagOp (Lorg/partiql/ast/Expr$BagOp;Ljava/lang/Object;)Ljava/lang/Object; public fun visitExprBetween (Lorg/partiql/ast/Expr$Between;Ljava/lang/Object;)Ljava/lang/Object; public fun visitExprBinary (Lorg/partiql/ast/Expr$Binary;Ljava/lang/Object;)Ljava/lang/Object; @@ -7207,7 +7165,6 @@ public abstract interface class org/partiql/ast/visitor/AstVisitor { public abstract fun visitExcludeStepStructField (Lorg/partiql/ast/Exclude$Step$StructField;Ljava/lang/Object;)Ljava/lang/Object; public abstract fun visitExcludeStepStructWildcard (Lorg/partiql/ast/Exclude$Step$StructWildcard;Ljava/lang/Object;)Ljava/lang/Object; public abstract fun visitExpr (Lorg/partiql/ast/Expr;Ljava/lang/Object;)Ljava/lang/Object; - public abstract fun visitExprAgg (Lorg/partiql/ast/Expr$Agg;Ljava/lang/Object;)Ljava/lang/Object; public abstract fun visitExprBagOp (Lorg/partiql/ast/Expr$BagOp;Ljava/lang/Object;)Ljava/lang/Object; public abstract fun visitExprBetween (Lorg/partiql/ast/Expr$Between;Ljava/lang/Object;)Ljava/lang/Object; public abstract fun visitExprBinary (Lorg/partiql/ast/Expr$Binary;Ljava/lang/Object;)Ljava/lang/Object; diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt index d2a07a7982..fd7fdb2a90 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt @@ -309,26 +309,29 @@ private class AstTranslator(val metas: Map) : AstBaseVisi } val funcName = (node.function as Identifier.Symbol).symbol.lowercase() val args = node.args.translate(ctx) - call(funcName, args, metas) + when (funcName.isAggregateCall() || node.setq != null) { // Use existing assumption that function call with set quantifier is an aggregation function + true -> { + val setq = node.setq?.toLegacySetQuantifier() ?: all() + // COUNT(*) is represented as COUNT() in default AST + // Legacy AST translates COUNT(*) to COUNT(1) + if (funcName == "count" && args.isEmpty()) { + return callAgg(setq, "count", lit(ionInt(1)), metas) + } + // Default Case + if (node.args.size != 1) { + error("Cannot translate `call_agg` with more than one argument") + } + val arg = visitExpr(node.args[0], ctx) + callAgg(setq, funcName, arg, metas) + } + false -> call(funcName, args, metas) + } } - override fun visitExprAgg(node: Expr.Agg, ctx: Ctx) = translate(node) { metas -> - val setq = node.setq?.toLegacySetQuantifier() ?: all() - // Legacy AST translates COUNT(*) to COUNT(1) - if (node.function is Identifier.Symbol && (node.function as Identifier.Symbol).symbol == "COUNT_STAR") { - return callAgg(setq, "count", lit(ionInt(1)), metas) - } - // Default Case - if (node.args.size != 1) { - error("Legacy `call_agg` must have exactly one argument") - } - if (node.function is Identifier.Qualified) { - error("Qualified identifiers are not allowed in legacy AST `call_agg` function identifiers") - } - // Legacy parser/ast always inserts ALL quantifier - val funcName = (node.function as Identifier.Symbol).symbol.lowercase() - val arg = visitExpr(node.args[0], ctx) - callAgg(setq, funcName, arg, metas) + private fun String.isAggregateCall(): Boolean { + // like PartiQLPigVisitor, keep legacy behavior the same as before + // since it is legacy, hard-coded aggregation logic + return listOf("count", "avg", "sum", "min", "max", "any", "some", "every").contains(this) } override fun visitExprUnary(node: Expr.Unary, ctx: Ctx) = translate(node) { metas -> diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt index 6a589a0556..97a863566e 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt @@ -304,17 +304,10 @@ public abstract class SqlDialect : AstBaseVisitor() { override fun visitExprPathStepUnpivot(node: Expr.Path.Step.Unpivot, head: SqlBlock): SqlBlock = head concat r(".*") override fun visitExprCall(node: Expr.Call, head: SqlBlock): SqlBlock { - var h = head - h = visitIdentifier(node.function, h) - h = h concat list { node.args } - return h - } - - override fun visitExprAgg(node: Expr.Agg, head: SqlBlock): SqlBlock { var h = head val f = node.function - // Special case - if (f is Identifier.Symbol && f.symbol == "COUNT_STAR") { + // Special case -- COUNT() maps to COUNT(*) + if (f is Identifier.Symbol && f.symbol == "COUNT" && node.args.isEmpty()) { return h concat r("COUNT(*)") } val start = if (node.setq != null) "(${node.setq.name} " else "(" diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlDialect.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlDialect.kt index b01b56e680..6c93aa009b 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlDialect.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlDialect.kt @@ -330,20 +330,13 @@ internal abstract class InternalSqlDialect : AstBaseVisitor ::= functionCall - : qualifiedName PAREN_LEFT ( expr ( COMMA expr )* )? PAREN_RIGHT + : qualifiedName PAREN_LEFT ASTERISK PAREN_RIGHT + | qualifiedName PAREN_LEFT ( setQuantifierStrategy? expr ( COMMA expr )* )? PAREN_RIGHT ; pathStep @@ -831,6 +826,8 @@ nonReserved /* PartiQL */ | EXCLUDED | EXISTS | SIZE + /* Other words not in above */ + | ANY | SOME ; /** diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt index 27a1b39da4..8dc143a3b2 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt @@ -68,7 +68,6 @@ import org.partiql.ast.excludeStepCollIndex import org.partiql.ast.excludeStepCollWildcard import org.partiql.ast.excludeStepStructField import org.partiql.ast.excludeStepStructWildcard -import org.partiql.ast.exprAgg import org.partiql.ast.exprBagOp import org.partiql.ast.exprBetween import org.partiql.ast.exprBinary @@ -1865,11 +1864,11 @@ internal class PartiQLParserDefault : PartiQLParser { val path = ctx.qualifiedName().qualifier.map { visitSymbolPrimitive(it) } val name = identifierSymbol("char_length", Identifier.CaseSensitivity.INSENSITIVE) if (path.isEmpty()) { - exprCall(name, args) + exprCall(name, args, setq = null) // setq = null for scalar fn } else { val root = path.first() val steps = path.drop(1) + listOf(name) - exprCall(identifierQualified(root, steps), args) + exprCall(identifierQualified(root, steps), args, setq = null) } } else -> visitNonReservedFunctionCall(ctx, args) @@ -1880,7 +1879,7 @@ internal class PartiQLParserDefault : PartiQLParser { } private fun visitNonReservedFunctionCall(ctx: GeneratedParser.FunctionCallContext, args: List): Expr.Call { val function = visitQualifiedName(ctx.qualifiedName()) - return exprCall(function, args) + return exprCall(function, args, convertSetQuantifier(ctx.setQuantifierStrategy())) } /** @@ -1912,7 +1911,7 @@ internal class PartiQLParserDefault : PartiQLParser { // normal form val function = "SUBSTRING".toIdentifier() val args = visitOrEmpty(ctx.expr()) - exprCall(function, args) + exprCall(function, args, setq = null) // setq = null for scalar fn } else { // special form val value = visitExpr(ctx.expr(0)) @@ -1930,7 +1929,7 @@ internal class PartiQLParserDefault : PartiQLParser { // normal form val function = "POSITION".toIdentifier() val args = visitOrEmpty(ctx.expr()) - exprCall(function, args) + exprCall(function, args, setq = null) // setq = null for scalar fn } else { // special form val lhs = visitExpr(ctx.expr(0)) @@ -1965,14 +1964,6 @@ internal class PartiQLParserDefault : PartiQLParser { } } - /** - * COUNT(*) - */ - override fun visitCountAll(ctx: GeneratedParser.CountAllContext) = translate(ctx) { - val function = "COUNT_STAR".toIdentifier() - exprAgg(function, emptyList(), SetQuantifier.ALL) - } - override fun visitExtract(ctx: GeneratedParser.ExtractContext) = translate(ctx) { val field = try { DatetimeField.valueOf(ctx.IDENTIFIER().text.uppercase()) @@ -1999,13 +1990,6 @@ internal class PartiQLParserDefault : PartiQLParser { exprTrim(value, chars, spec) } - override fun visitAggregateBase(ctx: GeneratedParser.AggregateBaseContext) = translate(ctx) { - val function = ctx.func.text.toIdentifier() - val args = listOf(visitExpr(ctx.expr())) - val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) - exprAgg(function, args, setq) - } - /** * Window Functions */ diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt index a562805895..a35372845c 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt @@ -22,7 +22,8 @@ class PartiQLParserFunctionCallTests { query { exprCall( function = identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), - args = emptyList() + args = emptyList(), + setq = null ) } ) @@ -33,7 +34,8 @@ class PartiQLParserFunctionCallTests { query { exprCall( function = identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), - args = emptyList() + args = emptyList(), + setq = null ) } ) @@ -44,7 +46,8 @@ class PartiQLParserFunctionCallTests { query { exprCall( function = identifierSymbol("upper", Identifier.CaseSensitivity.INSENSITIVE), - args = emptyList() + args = emptyList(), + setq = null ) } ) @@ -55,7 +58,8 @@ class PartiQLParserFunctionCallTests { query { exprCall( function = identifierSymbol("upper", Identifier.CaseSensitivity.SENSITIVE), - args = emptyList() + args = emptyList(), + setq = null ) } ) @@ -72,7 +76,8 @@ class PartiQLParserFunctionCallTests { identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), ) ), - args = emptyList() + args = emptyList(), + setq = null ) } ) @@ -89,7 +94,8 @@ class PartiQLParserFunctionCallTests { identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), ) ), - args = emptyList() + args = emptyList(), + setq = null ) } ) @@ -106,7 +112,8 @@ class PartiQLParserFunctionCallTests { identifierSymbol("upper", Identifier.CaseSensitivity.INSENSITIVE), ) ), - args = emptyList() + args = emptyList(), + setq = null ) } ) @@ -123,7 +130,8 @@ class PartiQLParserFunctionCallTests { identifierSymbol("upper", Identifier.CaseSensitivity.SENSITIVE), ) ), - args = emptyList() + args = emptyList(), + setq = null ) } ) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index 664b321122..478ca2f3a5 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -34,6 +34,8 @@ import org.partiql.value.PartiQLValueType * * See TypeEnv for the variables type environment. * + * TODO: function resolution between scalar functions and aggregations. + * * @property session */ internal class Env(private val session: PartiQLPlanner.Session) { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt index 7a19e15ede..c976626784 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt @@ -8,19 +8,6 @@ import org.partiql.spi.connector.ConnectorHandle import org.partiql.spi.connector.ConnectorMetadata import org.partiql.spi.fn.FnExperimental -/** - * Today, all aggregations are hard-coded into the grammar. We cannot implement user-defined aggregations until - * the grammar and AST are updated appropriately. We should not have an aggregation node in the AST, just a call node. - * During planning, we would then check if a call is an aggregation and translate the AST to the appropriate algebra. - * - * PartiQL.g4 - * - * aggregate - * : func=COUNT PAREN_LEFT ASTERISK PAREN_RIGHT - * | func=(COUNT|MAX|MIN|SUM|AVG|EVERY|ANY|SOME) PAREN_LEFT setQuantifierStrategy? expr PAREN_RIGHT - * ; - * - */ @OptIn(FnExperimental::class) internal class PathResolverAgg( catalog: ConnectorMetadata, diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt index 0c4e1f068c..cd2649dd64 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt @@ -216,7 +216,8 @@ internal object NormalizeSelect { return selectValue( constructor = exprCall( function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE), - args = tupleUnionArgs + args = tupleUnionArgs, + setq = null // setq = null for scalar fn ), setq = select.setq ) @@ -256,7 +257,8 @@ internal object NormalizeSelect { setq = node.setq, constructor = exprCall( function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE), - args = tupleUnionArgs + args = tupleUnionArgs, + setq = null // setq = null for scalar fn ) ) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt index 21259c3bed..f45fef89ca 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt @@ -28,6 +28,7 @@ import org.partiql.ast.SetOp import org.partiql.ast.SetQuantifier import org.partiql.ast.Sort import org.partiql.ast.builder.ast +import org.partiql.ast.exprLit import org.partiql.ast.exprVar import org.partiql.ast.helpers.toBinder import org.partiql.ast.identifierSymbol @@ -71,6 +72,7 @@ import org.partiql.planner.internal.ir.rexOpVarLocal import org.partiql.types.StaticType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.boolValue +import org.partiql.value.int32Value import org.partiql.value.stringValue import org.partiql.planner.internal.ir.Identifier as InternalId @@ -367,12 +369,20 @@ internal object RelConverter { is InternalId.Qualified -> error("Qualified aggregation calls are not supported.") is InternalId.Symbol -> id.symbol.lowercase() } - val setq = when (expr.setq) { - null -> Rel.Op.Aggregate.SetQuantifier.ALL - SetQuantifier.ALL -> Rel.Op.Aggregate.SetQuantifier.ALL - SetQuantifier.DISTINCT -> Rel.Op.Aggregate.SetQuantifier.DISTINCT + if (name == "count" && expr.args.isEmpty()) { + relOpAggregateCallUnresolved( + name, + Rel.Op.Aggregate.SetQuantifier.ALL, + args = listOf(exprLit(int32Value(1)).toRex(env)) + ) + } else { + val setq = when (expr.setq) { + null -> Rel.Op.Aggregate.SetQuantifier.ALL + SetQuantifier.ALL -> Rel.Op.Aggregate.SetQuantifier.ALL + SetQuantifier.DISTINCT -> Rel.Op.Aggregate.SetQuantifier.DISTINCT + } + relOpAggregateCallUnresolved(name, setq, args) } - relOpAggregateCallUnresolved(name, setq, args) }.toMutableList() // Add GROUP_AS aggregation @@ -584,12 +594,12 @@ internal object RelConverter { private object AggregationTransform : AstRewriter() { private data class Context( - val aggregations: MutableList, + val aggregations: MutableList, val keys: List ) - fun apply(node: Expr.SFW): Pair> { - val aggs = mutableListOf() + fun apply(node: Expr.SFW): Pair> { + val aggs = mutableListOf() val keys = node.groupBy?.keys ?: emptyList() val context = Context(aggs, keys) val select = super.visitExprSFW(node, context) as Expr.SFW @@ -607,13 +617,32 @@ internal object RelConverter { // only rewrite top-level SFW override fun visitExprSFW(node: Expr.SFW, ctx: Context): AstNode = node - override fun visitExprAgg(node: Expr.Agg, ctx: Context) = ast { - val id = identifierSymbol { - symbol = syntheticAgg(ctx.aggregations.size) - caseSensitivity = org.partiql.ast.Identifier.CaseSensitivity.INSENSITIVE + override fun visitExprCall(node: Expr.Call, ctx: Context) = ast { + // TODO replace w/ proper function resolution to determine whether a function call is a scalar or aggregate. + // may require further modification of SPI interfaces to support + when (node.function.isAggregateCall()) { + true -> { + val id = identifierSymbol { + symbol = syntheticAgg(ctx.aggregations.size) + caseSensitivity = org.partiql.ast.Identifier.CaseSensitivity.INSENSITIVE + } + ctx.aggregations += node + exprVar(id, Expr.Var.Scope.DEFAULT) + } + else -> node + } + } + + private fun String.isAggregateCall(): Boolean { + // currently hard-coded + return listOf("count", "avg", "sum", "min", "max", "any", "some", "every").contains(this) + } + + private fun Identifier.isAggregateCall(): Boolean { + return when (this) { + is Identifier.Symbol -> this.symbol.lowercase().isAggregateCall() + is Identifier.Qualified -> this.steps.last().symbol.lowercase().isAggregateCall() } - ctx.aggregations += node - exprVar(id, Expr.Var.Scope.DEFAULT) } } diff --git a/partiql-spi/api/partiql-spi.api b/partiql-spi/api/partiql-spi.api index 415526a69c..a5208ea8da 100644 --- a/partiql-spi/api/partiql-spi.api +++ b/partiql-spi/api/partiql-spi.api @@ -265,14 +265,8 @@ public final class org/partiql/spi/connector/sql/builtins/Agg_AVG__INT__INT : or public fun getSignature ()Lorg/partiql/spi/fn/AggSignature; } -public final class org/partiql/spi/connector/sql/builtins/Agg_COUNT_STAR____INT32 : org/partiql/spi/fn/Agg { - public static final field INSTANCE Lorg/partiql/spi/connector/sql/builtins/Agg_COUNT_STAR____INT32; - public fun accumulator ()Lorg/partiql/spi/fn/Agg$Accumulator; - public fun getSignature ()Lorg/partiql/spi/fn/AggSignature; -} - -public final class org/partiql/spi/connector/sql/builtins/Agg_COUNT__ANY__INT32 : org/partiql/spi/fn/Agg { - public static final field INSTANCE Lorg/partiql/spi/connector/sql/builtins/Agg_COUNT__ANY__INT32; +public final class org/partiql/spi/connector/sql/builtins/Agg_COUNT__ANY__INT64 : org/partiql/spi/fn/Agg { + public static final field INSTANCE Lorg/partiql/spi/connector/sql/builtins/Agg_COUNT__ANY__INT64; public fun accumulator ()Lorg/partiql/spi/fn/Agg$Accumulator; public fun getSignature ()Lorg/partiql/spi/fn/AggSignature; } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt index 5df65ed9e6..b3e305b930 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt @@ -446,8 +446,7 @@ internal object SqlBuiltins { Agg_AVG__FLOAT32__FLOAT32, Agg_AVG__FLOAT64__FLOAT64, Agg_AVG__ANY__ANY, - Agg_COUNT__ANY__INT32, - Agg_COUNT_STAR____INT32, + Agg_COUNT__ANY__INT64, Agg_EVERY__BOOL__BOOL, Agg_EVERY__ANY__BOOL, Agg_MAX__INT8__INT8, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt index 8264a5f807..63dc2be90c 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt @@ -13,7 +13,7 @@ import org.partiql.value.PartiQLValueType.ANY import org.partiql.value.PartiQLValueType.INT64 @OptIn(PartiQLValueExperimental::class, FnExperimental::class) -public object Agg_COUNT__ANY__INT32 : Agg { +public object Agg_COUNT__ANY__INT64 : Agg { override val signature: AggSignature = AggSignature( name = "count", diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt deleted file mode 100644 index cb63fc7411..0000000000 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt +++ /dev/null @@ -1,25 +0,0 @@ -// ktlint-disable filename -@file:Suppress("ClassName") - -package org.partiql.spi.connector.sql.builtins - -import org.partiql.spi.connector.sql.builtins.internal.AccumulatorCountStar -import org.partiql.spi.fn.Agg -import org.partiql.spi.fn.AggSignature -import org.partiql.spi.fn.FnExperimental -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType.INT64 - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -public object Agg_COUNT_STAR____INT32 : Agg { - - override val signature: AggSignature = AggSignature( - name = "count_star", - returns = INT64, - parameters = listOf(), - isNullable = false, - isDecomposable = true - ) - - override fun accumulator(): Agg.Accumulator = AccumulatorCountStar() -}