diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
index ca4ea5114c26b..c0078872bd843 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
@@ -20,8 +20,11 @@
import org.apache.spark.SparkUnsupportedOperationException;
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.util.QuotingUtils;
import org.apache.spark.sql.types.DataType;
+import java.util.Map;
+
/**
* Interface for a function that produces a result value for each input row.
*
@@ -149,7 +152,10 @@ public interface ScalarFunction<R> extends BoundFunction {
* @return a result value
*/
default R produceResult(InternalRow input) {
- throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3146");
+ throw new SparkUnsupportedOperationException(
+ "SCALAR_FUNCTION_NOT_COMPATIBLE",
+ Map.of("scalarFunc", QuotingUtils.quoteIdentifier(name()))
+ );
}
}
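For context, a rough sketch of what now happens when a connector's `ScalarFunction` supplies neither the `MAGIC_METHOD_NAME` method nor `produceResult` (the function below is made up for illustration):

```scala
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{DataType, IntegerType}

// A deliberately incomplete UDF: no `invoke` magic method, no produceResult.
object BrokenPlusOne extends ScalarFunction[Int] {
  override def inputTypes(): Array[DataType] = Array(IntegerType)
  override def resultType(): DataType = IntegerType
  override def name(): String = "plus_one"
}

// Calling the default produceResult now throws
// SparkUnsupportedOperationException with condition
// SCALAR_FUNCTION_NOT_COMPATIBLE and scalarFunc = `plus_one`,
// replacing the opaque _LEGACY_ERROR_TEMP_3146 template.
// BrokenPlusOne.produceResult(InternalRow(1))
```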
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b2e9115dd512f..5d41c07b47842 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1591,7 +1591,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// If the projection list contains Stars, expand it.
case p: Project if containsStar(p.projectList) =>
- p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
+ val expanded = p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
+ if (expanded.projectList.size < p.projectList.size) {
+ checkTrailingCommaInSelect(expanded, starRemoved = true)
+ }
+ expanded
// If the filter list contains Stars, expand it.
case p: Filter if containsStar(Seq(p.condition)) =>
p.copy(expandStarExpression(p.condition, p.child))
@@ -1600,7 +1604,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
throw QueryCompilationErrors.starNotAllowedWhenGroupByOrdinalPositionUsedError()
} else {
- a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
+ val expanded = a.copy(aggregateExpressions =
+ buildExpandedProjectList(a.aggregateExpressions, a.child))
+ if (expanded.aggregateExpressions.size < a.aggregateExpressions.size) {
+ checkTrailingCommaInSelect(expanded, starRemoved = true)
+ }
+ expanded
}
case c: CollectMetrics if containsStar(c.metrics) =>
c.copy(metrics = buildExpandedProjectList(c.metrics, c.child))
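The size comparison catches a `*` that expanded to zero columns, which happens exactly when the child is a `OneRowRelation` (e.g. `SELECT *, FROM tbl`); the general case is handled bottom-up in `CheckAnalysis`. A sketch of the user-visible effect, assuming a SparkSession `spark` (message text approximate):

```scala
// The dangling comma makes the parser read `FROM` as a column and
// `some_table` as its alias over the implicit one-row relation.
spark.sql("SELECT id, FROM some_table")
// Before: [UNRESOLVED_COLUMN.WITH_SUGGESTION] for `id` (nothing resolves
//         over the implicit OneRowRelation), which is baffling.
// After:  [TRAILING_COMMA_IN_SELECT] pointing at the trailing comma.
```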
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index b600f455f16ac..4720b9dcdfa13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -173,6 +173,36 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
)
}
+ /**
+ * Checks for errors in a `SELECT` clause, such as a trailing comma or an empty select list.
+ *
+ * @param plan The logical plan of the query.
+ * @param starRemoved Whether a '*' (wildcard) was removed from the select list.
+ * @throws AnalysisException if the select list is empty or ends with a trailing comma.
+ */
+ protected def checkTrailingCommaInSelect(
+ plan: LogicalPlan,
+ starRemoved: Boolean = false): Unit = {
+ val exprList = plan match {
+ case proj: Project if proj.projectList.nonEmpty =>
+ proj.projectList
+ case agg: Aggregate if agg.aggregateExpressions.nonEmpty =>
+ agg.aggregateExpressions
+ case _ =>
+ Seq.empty
+ }
+
+ exprList.lastOption match {
+ case Some(Alias(UnresolvedAttribute(Seq(name)), _)) =>
+ if (name.equalsIgnoreCase("FROM") && plan.exists(_.isInstanceOf[OneRowRelation])) {
+ if (exprList.size > 1 || starRemoved) {
+ throw QueryCompilationErrors.trailingCommaInSelectError(exprList.last.origin)
+ }
+ }
+ case _ =>
+ }
+ }
+
def checkAnalysis(plan: LogicalPlan): Unit = {
// We should inline all CTE relations to restore the original plan shape, as the analysis check
// may need to match certain plan shapes. For dangling CTE relations, they will still be kept
@@ -210,6 +240,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier
write.table.tableNotFound(tblName)
+ // We should check for trailing comma errors first, since we would get less obvious
+ // unresolved column errors if we do it bottom up
+ case proj: Project =>
+ checkTrailingCommaInSelect(proj)
+ case agg: Aggregate =>
+ checkTrailingCommaInSelect(agg)
+
case _ =>
}
@@ -1584,7 +1621,7 @@ class PreemptedError() {
// errors have the lowest priority.
def set(error: Exception with SparkThrowable, priority: Option[Int] = None): Unit = {
val calculatedPriority = priority.getOrElse {
- error.getErrorClass match {
+ error.getCondition match {
case c if c.startsWith("INTERNAL_ERROR") => 1
case _ => 2
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index e22a4b941b30c..8181078c519fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -24,20 +24,12 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._
-/**
- * A helper class used to detect duplicate relations fast in `DeduplicateRelations`. Two relations
- * are duplicated if:
- * 1. they are the same class.
- * 2. they have the same output attribute IDs.
- *
- * The first condition is necessary because the CTE relation definition node and reference node have
- * the same output attribute IDs but they are not duplicated.
- */
-case class RelationWrapper(cls: Class[_], outputAttrIds: Seq[Long])
-
object DeduplicateRelations extends Rule[LogicalPlan] {
+
+ type ExprIdMap = mutable.HashMap[Class[_], mutable.HashSet[Long]]
+
override def apply(plan: LogicalPlan): LogicalPlan = {
- val newPlan = renewDuplicatedRelations(mutable.HashSet.empty, plan)._1
+ val newPlan = renewDuplicatedRelations(mutable.HashMap.empty, plan)._1
// Wait for `ResolveMissingReferences` to resolve missing attributes first
def noMissingInput(p: LogicalPlan) = !p.exists(_.missingInput.nonEmpty)
@@ -86,10 +78,10 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
}
private def existDuplicatedExprId(
- existingRelations: mutable.HashSet[RelationWrapper],
- plan: RelationWrapper): Boolean = {
- existingRelations.filter(_.cls == plan.cls)
- .exists(_.outputAttrIds.intersect(plan.outputAttrIds).nonEmpty)
+ existingRelations: ExprIdMap,
+ planClass: Class[_], exprIds: Seq[Long]): Boolean = {
+ val attrSet = existingRelations.getOrElse(planClass, mutable.HashSet.empty)
+ exprIds.exists(attrSet.contains)
}
/**
@@ -100,20 +92,16 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
* whether the plan is changed or not)
*/
private def renewDuplicatedRelations(
- existingRelations: mutable.HashSet[RelationWrapper],
+ existingRelations: ExprIdMap,
plan: LogicalPlan): (LogicalPlan, Boolean) = plan match {
case p: LogicalPlan if p.isStreaming => (plan, false)
case m: MultiInstanceRelation =>
- val planWrapper = RelationWrapper(m.getClass, m.output.map(_.exprId.id))
- if (existingRelations.contains(planWrapper)) {
- val newNode = m.newInstance()
- newNode.copyTagsFrom(m)
- (newNode, true)
- } else {
- existingRelations.add(planWrapper)
- (m, false)
- }
+ deduplicateAndRenew[LogicalPlan with MultiInstanceRelation](
+ existingRelations,
+ m,
+ _.output.map(_.exprId.id),
+ node => node.newInstance().asInstanceOf[LogicalPlan with MultiInstanceRelation])
case p: Project =>
deduplicateAndRenew[Project](
@@ -207,7 +195,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
}
private def deduplicate(
- existingRelations: mutable.HashSet[RelationWrapper],
+ existingRelations: ExprIdMap,
plan: LogicalPlan): (LogicalPlan, Boolean) = {
var planChanged = false
val newPlan = if (plan.children.nonEmpty) {
@@ -291,20 +279,21 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
}
private def deduplicateAndRenew[T <: LogicalPlan](
- existingRelations: mutable.HashSet[RelationWrapper], plan: T,
+ existingRelations: ExprIdMap, plan: T,
getExprIds: T => Seq[Long],
copyNewPlan: T => T): (LogicalPlan, Boolean) = {
var (newPlan, planChanged) = deduplicate(existingRelations, plan)
if (newPlan.resolved) {
val exprIds = getExprIds(newPlan.asInstanceOf[T])
if (exprIds.nonEmpty) {
- val planWrapper = RelationWrapper(newPlan.getClass, exprIds)
- if (existDuplicatedExprId(existingRelations, planWrapper)) {
+ if (existDuplicatedExprId(existingRelations, newPlan.getClass, exprIds)) {
newPlan = copyNewPlan(newPlan.asInstanceOf[T])
newPlan.copyTagsFrom(plan)
(newPlan, true)
} else {
- existingRelations.add(planWrapper)
+ val attrSet = existingRelations.getOrElseUpdate(newPlan.getClass, mutable.HashSet.empty)
+ exprIds.foreach(attrSet.add)
+ existingRelations.put(newPlan.getClass, attrSet)
(newPlan, planChanged)
}
} else {
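In effect, the per-plan `RelationWrapper` set becomes one mutable map keyed by node class, so the duplicate test is a hash lookup per expression ID instead of a linear scan that intersects full ID lists. A minimal sketch of the containment semantics (names hypothetical):

```scala
import scala.collection.mutable

object ExprIdMapSketch {
  type ExprIdMap = mutable.HashMap[Class[_], mutable.HashSet[Long]]

  // Duplicate check: is any candidate id already seen for this node class?
  def existsDuplicate(seen: ExprIdMap, cls: Class[_], ids: Seq[Long]): Boolean =
    ids.exists(seen.getOrElse(cls, mutable.HashSet.empty).contains)

  // Registration: fold the ids into the per-class set.
  def record(seen: ExprIdMap, cls: Class[_], ids: Seq[Long]): Unit =
    ids.foreach(seen.getOrElseUpdate(cls, mutable.HashSet.empty).add)
}
```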
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
index 2642b4a1c5daa..0f9b93cc2986d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
@@ -36,7 +36,7 @@ class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
// df.drop(col("non-existing-column"))
val dropped = d.dropList.map {
case u: UnresolvedAttribute =>
- resolveExpressionByPlanChildren(u, d.child)
+ resolveExpressionByPlanChildren(u, d)
case e => e
}
val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr)))
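A sketch of the DataFrame-level symptom this fixes, per the new AnalysisSuite test further down (column names illustrative). Resolving against `d` rather than `d.child` means the attribute is matched against the output of the node's immediate child, the Project that defines the column:

```scala
import org.apache.spark.sql.functions.{col, expr}

// "f" is defined by the Project directly under the DataFrameDropColumns node;
// previously it was resolved one level too deep and could fail to match.
val df = spark.range(3).withColumn("f", expr("trim(string(id))"))
df.drop(col("f")).columns  // Array("id")
```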
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index d7d53230470d9..f2f86a90d5172 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -87,7 +87,7 @@ object ExpressionEncoder {
}
constructProjection(row).get(0, anyObjectType).asInstanceOf[T]
} catch {
- case e: SparkRuntimeException if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" =>
+ case e: SparkRuntimeException if e.getCondition == "NOT_NULL_ASSERT_VIOLATION" =>
throw e
case e: Exception =>
throw QueryExecutionErrors.expressionDecodingError(e, expressions)
@@ -115,7 +115,7 @@ object ExpressionEncoder {
inputRow(0) = t
extractProjection(inputRow)
} catch {
- case e: SparkRuntimeException if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" =>
+ case e: SparkRuntimeException if e.getCondition == "NOT_NULL_ASSERT_VIOLATION" =>
throw e
case e: Exception =>
throw QueryExecutionErrors.expressionEncodingError(e, expressions)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index de15ec43c4f31..6a57ba2aaa569 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -383,10 +383,10 @@ abstract class Expression extends TreeNode[Expression] {
trait FoldableUnevaluable extends Expression {
override def foldable: Boolean = true
- final override def eval(input: InternalRow = null): Any =
+ override def eval(input: InternalRow = null): Any =
throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
- final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
index 433f8500fab1f..04d31b5797819 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
@@ -17,7 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -37,8 +41,21 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
abstract class PartitionTransformExpression extends Expression with Unevaluable
with UnaryLike[Expression] {
override def nullable: Boolean = true
-}
+ override def eval(input: InternalRow): Any =
+ throw new SparkException(
+ errorClass = "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY",
+ messageParameters = Map("expression" -> toSQLExpr(this)),
+ cause = null
+ )
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ throw new SparkException(
+ errorClass = "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY",
+ messageParameters = Map("expression" -> toSQLExpr(this)),
+ cause = null
+ )
+}
/**
* Expression for the v2 partition transform years.
*/
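Illustrative effect, assuming Spark 4.0's `functions.partitioning` helpers: evaluating a transform anywhere other than a `PARTITIONED BY` clause now fails with a dedicated condition instead of a generic Unevaluable error.

```scala
import org.apache.spark.sql.functions.{col, partitioning}

// Intended use: inside DataFrameWriterV2.partitionedBy.
df.writeTo("cat.db.events").partitionedBy(partitioning.years(col("ts"))).create()

// Misuse: forcing evaluation in a projection now raises
// [PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY] with the
// offending expression quoted in the message.
df.select(partitioning.years(col("ts"))).show()
```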
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index 220920a5a3198..d14c8cb675387 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
@@ -182,8 +183,8 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
ApplyFunctionExpression(scalarFunc, arguments)
case _ =>
throw new AnalysisException(
- errorClass = "_LEGACY_ERROR_TEMP_3055",
- messageParameters = Map("scalarFunc" -> scalarFunc.name()))
+ errorClass = "SCALAR_FUNCTION_NOT_FULLY_IMPLEMENTED",
+ messageParameters = Map("scalarFunc" -> toSQLId(scalarFunc.name())))
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index c593c8bfb8341..0a4882bfada17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.Growable
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.UnaryLike
@@ -118,7 +118,8 @@ case class CollectList(
override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
- override def prettyName: String = "collect_list"
+ override def prettyName: String =
+ getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("collect_list")
override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
new GenericArrayData(buffer.toArray)
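The alias-aware `prettyName` pattern used here (and mirrored in the `curdate`, `dateadd`, `current_database`, `to_varchar`, and `random` changes below) keeps the name the user actually typed. A quick sketch of the observable difference, assuming a SparkSession `spark`:

```scala
spark.sql("SELECT array_agg(a) FROM VALUES (1), (2) AS t(a)").columns
// Before: Array("collect_list(a)")
// After:  Array("array_agg(a)")
```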
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index b166d235557fc..764637b97a100 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -150,7 +150,8 @@ case class CurrentDate(timeZoneId: Option[String] = None)
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
- override def prettyName: String = "current_date"
+ override def prettyName: String =
+ getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_date")
}
// scalastyle:off line.size.limit
@@ -329,7 +330,7 @@ case class DateAdd(startDate: Expression, days: Expression)
})
}
- override def prettyName: String = "date_add"
+ override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("date_add")
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): DateAdd = copy(startDate = newLeft, days = newRight)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index cb846f606632b..0315c12b9bb8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -202,7 +202,8 @@ object AssertTrue {
case class CurrentDatabase() extends LeafExpression with Unevaluable {
override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = false
- override def prettyName: String = "current_schema"
+ override def prettyName: String =
+ getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_database")
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
index 5bd2ab6035e10..eefd21b236b7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Locale
-import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult}
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
@@ -307,7 +307,10 @@ case class ToCharacter(left: Expression, right: Expression)
inputTypeCheck
}
}
- override def prettyName: String = "to_char"
+
+ override def prettyName: String =
+ getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("to_char")
+
override def nullSafeEval(decimal: Any, format: Any): Any = {
val input = decimal.asInstanceOf[Decimal]
numberFormatter.format(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index ada0a73a67958..3cec83facd01d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
@@ -128,8 +128,12 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi
}
override def flatArguments: Iterator[Any] = Iterator(child)
+
+ override def prettyName: String =
+ getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rand")
+
override def sql: String = {
- s"rand(${if (hideSeed) "" else child.sql})"
+ s"$prettyName(${if (hideSeed) "" else child.sql})"
}
override protected def withNewChildInternal(newChild: Expression): Rand = copy(child = newChild)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 2fcc689b9df2b..776efbed273e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -134,7 +134,17 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case (name, i) => Seq(Literal(name), normalize(GetStructField(expr, i)))
}
val struct = CreateNamedStruct(fields.flatten.toImmutableArraySeq)
- KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))
+ // For nested structs (and other complex types), this branch is called again with either a
+ // `GetStructField` or a `NamedLambdaVariable` expression. Even if the field for which this
+ // has been recursively called might have `nullable = false`, directly creating an `If`
+ // predicate would end up creating an expression with `nullable = true` (as the trueBranch is
+ // nullable). Hence, use the `expr.nullable` to create an `If` predicate only when the column
+ // is nullable.
+ if (expr.nullable) {
+ KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))
+ } else {
+ KnownFloatingPointNormalized(struct)
+ }
case _ if expr.dataType.isInstanceOf[ArrayType] =>
val ArrayType(et, containsNull) = expr.dataType
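A minimal sketch of the invariant being preserved, mirroring the new test below (assumes access within Spark's catalyst package scope):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
import org.apache.spark.sql.types.{DoubleType, StructType}

// A non-nullable struct of doubles: wrapping it in If(IsNull(...), ...) would
// have reported nullable = true even though the input can never be null.
val attr = AttributeReference("s", new StructType().add("d", DoubleType), nullable = false)()
assert(NormalizeFloatingNumbers.normalize(attr).nullable == attr.nullable)
```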
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 1601d798283c9..c0cd976b9e9b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -260,19 +260,32 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) {
case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
- if (foldables.size > 1) {
+ if (foldables.nonEmpty) {
val foldableExpr = foldables.reduce((x, y) => Add(x, y, f))
- val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
- if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y, f)), c, f)
+ val foldableValue = foldableExpr.eval(EmptyRow)
+ if (others.isEmpty) {
+ Literal.create(foldableValue, a.dataType)
+ } else if (foldableValue == 0) {
+ others.reduce((x, y) => Add(x, y, f))
+ } else {
+ Add(others.reduce((x, y) => Add(x, y, f)), Literal.create(foldableValue, a.dataType), f)
+ }
} else {
a
}
case m @ Multiply(_, _, f) if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
- if (foldables.size > 1) {
+ if (foldables.nonEmpty) {
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y, f))
- val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
- if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y, f)), c, f)
+ val foldableValue = foldableExpr.eval(EmptyRow)
+ if (others.isEmpty || (foldableValue == 0 && !m.nullable)) {
+ Literal.create(foldableValue, m.dataType)
+ } else if (foldableValue == 1) {
+ others.reduce((x, y) => Multiply(x, y, f))
+ } else {
+ Multiply(others.reduce((x, y) => Multiply(x, y, f)),
+ Literal.create(foldableValue, m.dataType), f)
+ }
} else {
m
}
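End-to-end, the optimizer can now fold identity and absorbing elements even when only one foldable operand is present. A sketch, assuming a SparkSession `spark`:

```scala
spark.sql("SELECT id + 0, id * 1, id * 0 FROM range(3)").queryExecution.optimizedPlan
// id + 0  -> id   (a folded 0 addend is dropped)
// id * 1  -> id   (a folded 1 factor is dropped)
// id * 0  -> 0    (safe because range's id is non-nullable; for a nullable
//                  column the multiply is kept to preserve NULL semantics)
```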
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index f1d211f517789..3ecb680cf6427 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1407,10 +1407,13 @@ class AstBuilder extends DataTypeAstBuilder
* - INTERSECT [DISTINCT | ALL]
*/
override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) {
- val left = plan(ctx.left)
- val right = plan(ctx.right)
val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
- ctx.operator.getType match {
+ visitSetOperationImpl(plan(ctx.left), plan(ctx.right), all, ctx.operator.getType)
+ }
+
+ private def visitSetOperationImpl(
+ left: LogicalPlan, right: LogicalPlan, all: Boolean, operatorType: Int): LogicalPlan = {
+ operatorType match {
case SqlBaseParser.UNION if all =>
Union(left, right)
case SqlBaseParser.UNION =>
@@ -3253,7 +3256,7 @@ class AstBuilder extends DataTypeAstBuilder
} catch {
case e: SparkArithmeticException =>
throw new ParseException(
- errorClass = e.getErrorClass,
+ errorClass = e.getCondition,
messageParameters = e.getMessageParameters.asScala.toMap,
ctx)
}
@@ -3549,7 +3552,7 @@ class AstBuilder extends DataTypeAstBuilder
// Keep error class of SparkIllegalArgumentExceptions and enrich it with query context
case se: SparkIllegalArgumentException =>
val pe = new ParseException(
- errorClass = se.getErrorClass,
+ errorClass = se.getCondition,
messageParameters = se.getMessageParameters.asScala.toMap,
ctx)
pe.setStackTrace(se.getStackTrace)
@@ -5916,7 +5919,12 @@ class AstBuilder extends DataTypeAstBuilder
withUnpivot(c, left)
}.getOrElse(Option(ctx.sample).map { c =>
withSample(c, left)
- }.get))))
+ }.getOrElse(Option(ctx.joinRelation()).map { c =>
+ withJoinRelation(c, left)
+ }.getOrElse(Option(ctx.operator).map { c =>
+ val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
+ visitSetOperationImpl(left, plan(ctx.right), all, c.getType)
+ }.get))))))
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala
index 46f14876be363..8d88b05546ed2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala
@@ -127,7 +127,7 @@ object GeneratedColumn {
} catch {
case ex: AnalysisException =>
// Improve error message if possible
- if (ex.getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") {
+ if (ex.getCondition == "UNRESOLVED_COLUMN.WITH_SUGGESTION") {
ex.messageParameters.get("objectName").foreach { unresolvedCol =>
val resolver = SQLConf.get.resolver
// Whether `col` = `unresolvedCol` taking into account case-sensitivity
@@ -144,7 +144,7 @@ object GeneratedColumn {
}
}
}
- if (ex.getErrorClass == "UNRESOLVED_ROUTINE") {
+ if (ex.getCondition == "UNRESOLVED_ROUTINE") {
// Cannot resolve function using built-in catalog
ex.messageParameters.get("routineName").foreach { fnName =>
throw unsupportedExpressionError(s"failed to resolve $fnName to a built-in function")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 22cc001c0c78e..0e02e4249addd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -358,6 +358,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
)
}
+ def trailingCommaInSelectError(origin: Origin): Throwable = {
+ new AnalysisException(
+ errorClass = "TRAILING_COMMA_IN_SELECT",
+ messageParameters = Map.empty,
+ origin = origin
+ )
+ }
+
def unresolvedUsingColForJoinError(
colName: String, suggestion: String, side: String): Throwable = {
new AnalysisException(
@@ -3380,8 +3388,9 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
def cannotModifyValueOfStaticConfigError(key: String): Throwable = {
new AnalysisException(
- errorClass = "_LEGACY_ERROR_TEMP_1325",
- messageParameters = Map("key" -> key))
+ errorClass = "CANNOT_MODIFY_CONFIG",
+ messageParameters = Map("key" -> toSQLConf(key), "docroot" -> SPARK_DOC_ROOT)
+ )
}
def cannotModifyValueOfSparkConfigError(key: String, docroot: String): Throwable = {
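The user-visible effect of the `CANNOT_MODIFY_CONFIG` migration, assuming a SparkSession `spark` (message wording approximate):

```scala
spark.conf.set("spark.sql.warehouse.dir", "/tmp/elsewhere")
// Before: [_LEGACY_ERROR_TEMP_1325] with the bare key name.
// After:  [CANNOT_MODIFY_CONFIG] Cannot modify the value of the Spark config:
//         "spark.sql.warehouse.dir", with a docroot link to the migration guide.
```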
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index bc6c7681ea1a5..301880f1bfc61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2845,6 +2845,16 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
)
}
+ def conflictingDirectoryStructuresError(
+ discoveredBasePaths: Seq[String]): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "CONFLICTING_DIRECTORY_STRUCTURES",
+ messageParameters = Map(
+ "discoveredBasePaths" -> discoveredBasePaths.distinct.mkString("\n\t", "\n\t", "\n")
+ )
+ )
+ }
+
def conflictingPartitionColumnNamesError(
distinctPartColLists: Seq[String],
suspiciousPaths: Seq[Path]): SparkRuntimeException = {
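This error typically fires during partition discovery over mixed layouts; a sketch of a triggering layout (paths hypothetical):

```scala
// /data/events/part-0.parquet            <- unpartitioned file at the root
// /data/events/year=2024/part-1.parquet  <- partitioned subdirectory
spark.read.parquet("/data/events")
// [CONFLICTING_DIRECTORY_STRUCTURES] with the distinct base paths listed
// one per line, as formatted by the discoveredBasePaths parameter above.
```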
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala
index f0c28c95046eb..7602366c71a65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala
@@ -33,7 +33,7 @@ class SqlScriptingException (
cause)
with SparkThrowable {
- override def getErrorClass: String = errorClass
+ override def getCondition: String = errorClass
override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 969eee4d912e4..08002887135ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -782,7 +782,7 @@ object SQLConf {
CollationFactory.fetchCollation(collationName)
true
} catch {
- case e: SparkException if e.getErrorClass == "COLLATION_INVALID_NAME" => false
+ case e: SparkException if e.getCondition == "COLLATION_INVALID_NAME" => false
}
},
"DEFAULT_COLLATION",
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java
index 0db155e88aea5..339f16407ae60 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java
@@ -80,7 +80,7 @@ public void testLoadWithoutConfig() {
SparkException exc = Assertions.assertThrows(CatalogNotFoundException.class,
() -> Catalogs.load("missing", conf));
- Assertions.assertEquals(exc.getErrorClass(), "CATALOG_NOT_FOUND");
+ Assertions.assertEquals(exc.getCondition(), "CATALOG_NOT_FOUND");
Assertions.assertEquals(exc.getMessageParameters().get("catalogName"), "`missing`");
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index e23a753dafe8c..8409f454bfb88 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1832,4 +1832,14 @@ class AnalysisSuite extends AnalysisTest with Matchers {
preemptedError.clear()
assert(preemptedError.getErrorOpt().isEmpty)
}
+
+ test("SPARK-49782: ResolveDataFrameDropColumns rule resolves complex UnresolvedAttribute") {
+ val function = UnresolvedFunction("trim", Seq(UnresolvedAttribute("i")), isDistinct = false)
+ val addColumnF = Project(Seq(UnresolvedAttribute("i"), Alias(function, "f")()), testRelation5)
+ // Drop column "f" via ResolveDataFrameDropColumns rule.
+ val inputPlan = DataFrameDropColumns(Seq(UnresolvedAttribute("f")), addColumnF)
+ // The expected Project (root node) should only have column "i".
+ val expectedPlan = Project(Seq(UnresolvedAttribute("i")), addColumnF).analyze
+ checkAnalysis(inputPlan, expectedPlan)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 33b9fb488c94f..71744f4d15105 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -205,7 +205,7 @@ trait AnalysisTest extends PlanTest {
assert(e.message.contains(message))
}
if (condition.isDefined) {
- assert(e.getErrorClass == condition.get)
+ assert(e.getCondition == condition.get)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 3e9a93dc743df..6ee19bab5180a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -1133,7 +1133,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
}
}
if (!condition.isEmpty) {
- assert(e.getErrorClass == condition)
+ assert(e.getCondition == condition)
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala
index e8239c7523948..f3817e4dd1a8b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala
@@ -106,7 +106,7 @@ class CSVExprUtilsSuite extends SparkFunSuite {
} catch {
case e: SparkIllegalArgumentException =>
assert(separatorStr.isEmpty)
- assert(e.getErrorClass === expectedErrorClass.get)
+ assert(e.getCondition === expectedErrorClass.get)
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 35a27f41da80a..6bd5b457ea24e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -173,7 +173,7 @@ class EncoderResolutionSuite extends PlanTest {
val exception = intercept[SparkRuntimeException] {
fromRow(InternalRow(new GenericArrayData(Array(1, null))))
}
- assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
+ assert(exception.getCondition == "NOT_NULL_ASSERT_VIOLATION")
}
test("the real number of fields doesn't match encoder schema: tuple encoder") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index f73911d344d96..79c6d07d6d218 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -279,7 +279,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
// Check the error class only since the parameters may change depending on how we are running
// this test case.
val exception = intercept[SparkRuntimeException](toRow(encoder, null))
- assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
+ assert(exception.getCondition == "NOT_NULL_ASSERT_VIOLATION")
}
test("RowEncoder should validate external type") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala
index 3aeb0c882ac3c..891e2d048b7a8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala
@@ -64,7 +64,7 @@ object BufferHolderSparkSubmitSuite extends Assertions {
val e1 = intercept[SparkIllegalArgumentException] {
holder.grow(-1)
}
- assert(e1.getErrorClass === "_LEGACY_ERROR_TEMP_3198")
+ assert(e1.getCondition === "_LEGACY_ERROR_TEMP_3198")
// while to reuse a buffer may happen, this test checks whether the buffer can be grown
holder.grow(ARRAY_MAX / 2)
@@ -82,6 +82,6 @@ object BufferHolderSparkSubmitSuite extends Assertions {
val e2 = intercept[SparkIllegalArgumentException] {
holder.grow(ARRAY_MAX + 1 - holder.totalSize())
}
- assert(e2.getErrorClass === "_LEGACY_ERROR_TEMP_3199")
+ assert(e2.getCondition === "_LEGACY_ERROR_TEMP_3199")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
index 454619a2133d9..21049ca3546dc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
@@ -124,5 +124,13 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
comparePlans(doubleOptimized, correctAnswer)
}
+
+ test("SPARK-49863: NormalizeFloatingNumbers preserves nullability for nested struct") {
+ val relation = LocalRelation($"a".double, $"b".string)
+ val nestedExpr = namedStruct("struct", namedStruct("double", relation.output.head))
+ .as("nestedExpr").toAttribute
+ val normalizedExpr = NormalizeFloatingNumbers.normalize(nestedExpr)
+ assert(nestedExpr.dataType == normalizedExpr.dataType)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
index f4b2fce74dc49..9090e0c7fc104 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -74,4 +75,35 @@ class ReorderAssociativeOperatorSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("SPARK-49915: Handle zero and one in associative operators") {
+ val originalQuery =
+ testRelation.select(
+ $"a" + 0,
+ Literal(-3) + $"a" + 3,
+ $"b" * 0 * 1 * 2 * 3,
+ Count($"b") * 0,
+ $"b" * 1 * 1,
+ ($"b" + 0) * 1 * 2 * 3 * 4,
+ $"a" + 0 + $"b" + 0 + $"c" + 0,
+ $"a" + 0 + $"b" * 1 + $"c" + 0
+ )
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer =
+ testRelation
+ .select(
+ $"a".as("(a + 0)"),
+ $"a".as("((-3 + a) + 3)"),
+ ($"b" * 0).as("((((b * 0) * 1) * 2) * 3)"),
+ Literal(0L).as("(count(b) * 0)"),
+ $"b".as("((b * 1) * 1)"),
+ ($"b" * 24).as("(((((b + 0) * 1) * 2) * 3) * 4)"),
+ ($"a" + $"b" + $"c").as("""(((((a + 0) + b) + 0) + c) + 0)"""),
+ ($"a" + $"b" + $"c").as("((((a + 0) + (b * 1)) + c) + 0)")
+ ).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index b7e2490b552cc..926beacc592a5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -3065,7 +3065,7 @@ class DDLParserSuite extends AnalysisTest {
s"(id BIGINT GENERATED ALWAYS AS IDENTITY $identitySpecStr, val INT) USING foo"
)
}
- assert(exception.getErrorClass === "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION")
+ assert(exception.getCondition === "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index 2972ba2db21de..2e702e5642a92 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -50,7 +50,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
val e = intercept[ParseException] {
parseScript(sqlScriptText)
}
- assert(e.getErrorClass === "PARSE_SYNTAX_ERROR")
+ assert(e.getCondition === "PARSE_SYNTAX_ERROR")
assert(e.getMessage.contains("Syntax error"))
assert(e.getMessage.contains("SELECT"))
}
@@ -90,7 +90,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
val e = intercept[ParseException] {
parseScript(sqlScriptText)
}
- assert(e.getErrorClass === "PARSE_SYNTAX_ERROR")
+ assert(e.getCondition === "PARSE_SYNTAX_ERROR")
assert(e.getMessage.contains("Syntax error"))
assert(e.getMessage.contains("at or near ';'"))
}
@@ -105,7 +105,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
val e = intercept[ParseException] {
parseScript(sqlScriptText)
}
- assert(e.getErrorClass === "PARSE_SYNTAX_ERROR")
+ assert(e.getCondition === "PARSE_SYNTAX_ERROR")
assert(e.getMessage.contains("Syntax error"))
assert(e.getMessage.contains("at or near end of input"))
}
@@ -367,7 +367,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
val e = intercept[ParseException] {
parseScript(sqlScriptText)
}
- assert(e.getErrorClass === "PARSE_SYNTAX_ERROR")
+ assert(e.getCondition === "PARSE_SYNTAX_ERROR")
assert(e.getMessage.contains("Syntax error"))
}
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/commands.proto b/sql/connect/common/src/main/protobuf/spark/connect/commands.proto
index 71189a3c43a19..a01d4369a7aed 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -507,6 +507,9 @@ message CheckpointCommand {
// (Required) Whether to checkpoint this dataframe immediately.
bool eager = 3;
+
+ // (Optional) For local checkpoint, the storage level to use.
+ optional StorageLevel storage_level = 4;
}
message MergeIntoTableCommand {
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain
index 102f736c62ef6..6668692f6cf1d 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain
@@ -1,2 +1,2 @@
-Aggregate [collect_list(a#0, 0, 0) AS collect_list(a)#0]
+Aggregate [array_agg(a#0, 0, 0) AS array_agg(a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain
index 5305b346c4f2d..be039d62a5494 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain
@@ -1,2 +1,2 @@
-Project [current_date(Some(America/Los_Angeles)) AS current_date()#0]
+Project [curdate(Some(America/Los_Angeles)) AS curdate()#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain
index 481c0a478c8df..93dfac524d9a1 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain
@@ -1,2 +1,2 @@
-Project [current_schema() AS current_schema()#0]
+Project [current_database() AS current_database()#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain
index 66325085b9c14..319428541760d 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain
@@ -1,2 +1,2 @@
-Project [date_add(d#0, 2) AS date_add(d, 2)#0]
+Project [dateadd(d#0, 2) AS dateadd(d, 2)#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain
index 81c81e95c2bdd..5854d2c7fa6be 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain
@@ -1,2 +1,2 @@
-Project [random(1) AS rand(1)#0]
+Project [random(1) AS random(1)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain
index f0d9cacc61ac5..cc5149bfed863 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain
@@ -1,2 +1,2 @@
-Project [to_char(cast(b#0 as decimal(30,15)), $99.99) AS to_char(b, $99.99)#0]
+Project [to_varchar(cast(b#0 as decimal(30,15)), $99.99) AS to_varchar(b, $99.99)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml
index d0d982934d2c7..f2a7f1b1da9d9 100644
--- a/sql/connect/server/pom.xml
+++ b/sql/connect/server/pom.xml
@@ -52,6 +52,10 @@
       <artifactId>spark-connect-common_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
+    </dependency>
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 231e54ff77d29..4e6994f9c2f8b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -3118,7 +3118,7 @@ class SparkConnectPlanner(
.newBuilder()
exception_builder
.setExceptionMessage(e.toString())
- .setErrorClass(e.getErrorClass)
+ .setErrorClass(e.getCondition)
val stackTrace = Option(ExceptionUtils.getStackTrace(e))
stackTrace.foreach { s =>
@@ -3354,9 +3354,18 @@ class SparkConnectPlanner(
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
val target = Dataset
.ofRows(session, transformRelation(checkpointCommand.getRelation))
- val checkpointed = target.checkpoint(
- eager = checkpointCommand.getEager,
- reliableCheckpoint = !checkpointCommand.getLocal)
+ val checkpointed = if (checkpointCommand.getLocal) {
+ if (checkpointCommand.hasStorageLevel) {
+ target.localCheckpoint(
+ eager = checkpointCommand.getEager,
+ storageLevel =
+ StorageLevelProtoConverter.toStorageLevel(checkpointCommand.getStorageLevel))
+ } else {
+ target.localCheckpoint(eager = checkpointCommand.getEager)
+ }
+ } else {
+ target.checkpoint(eager = checkpointCommand.getEager)
+ }
val dfId = UUID.randomUUID().toString
logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")
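Client-side, this wires the new optional `storage_level` proto field through to `Dataset.localCheckpoint`; a sketch of the call shapes the server now distinguishes (per the planner code above, assuming a DataFrame `df`):

```scala
import org.apache.spark.storage.StorageLevel

// Reliable checkpoint: storage level is not applicable.
df.checkpoint(eager = true)

// Local checkpoint: when storage_level is set on CheckpointCommand, it is
// converted via StorageLevelProtoConverter and passed through.
df.localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY)
```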
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 0468a55e23027..e62c19b66c8e5 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -345,7 +345,7 @@ object SparkConnectService extends Logging {
val kvStore = sc.statusStore.store.asInstanceOf[ElementTrackingStore]
listener = new SparkConnectServerListener(kvStore, sc.conf)
sc.listenerBus.addToStatusQueue(listener)
- uiTab = if (sc.getConf.get(UI_ENABLED)) {
+ uiTab = if (sc.conf.get(UI_ENABLED)) {
Some(
new SparkConnectServerTab(
new SparkConnectServerAppStatusStore(kvStore),
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index f1636ed1ef092..837d4a4d3ee78 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -114,8 +114,8 @@ private[connect] object ErrorUtils extends Logging {
case sparkThrowable: SparkThrowable =>
val sparkThrowableBuilder = FetchErrorDetailsResponse.SparkThrowable
.newBuilder()
- if (sparkThrowable.getErrorClass != null) {
- sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass)
+ if (sparkThrowable.getCondition != null) {
+ sparkThrowableBuilder.setErrorClass(sparkThrowable.getCondition)
}
for (queryCtx <- sparkThrowable.getQueryContext) {
val builder = FetchErrorDetailsResponse.QueryContext
@@ -193,7 +193,7 @@ private[connect] object ErrorUtils extends Logging {
if (state != null && state.nonEmpty) {
errorInfo.putMetadata("sqlState", state)
}
- val errorClass = e.getErrorClass
+ val errorClass = e.getCondition
if (errorClass != null && errorClass.nonEmpty) {
val messageParameters = JsonMethods.compact(
JsonMethods.render(map2jvalue(e.getMessageParameters.asScala.toMap)))
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
index 42bb93de05e26..1f522ea28b761 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -37,7 +37,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
val exGetOrCreate = intercept[SparkSQLException] {
SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
}
- assert(exGetOrCreate.getErrorClass == "INVALID_HANDLE.FORMAT")
+ assert(exGetOrCreate.getCondition == "INVALID_HANDLE.FORMAT")
}
test(
@@ -72,7 +72,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
key,
Some(sessionHolder.session.sessionUUID + "invalid"))
}
- assert(exGet.getErrorClass == "INVALID_HANDLE.SESSION_CHANGED")
+ assert(exGet.getCondition == "INVALID_HANDLE.SESSION_CHANGED")
}
test(
@@ -85,12 +85,12 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
val exGetOrCreate = intercept[SparkSQLException] {
SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
}
- assert(exGetOrCreate.getErrorClass == "INVALID_HANDLE.SESSION_CLOSED")
+ assert(exGetOrCreate.getCondition == "INVALID_HANDLE.SESSION_CLOSED")
val exGet = intercept[SparkSQLException] {
SparkConnectService.sessionManager.getIsolatedSession(key, None)
}
- assert(exGet.getErrorClass == "INVALID_HANDLE.SESSION_CLOSED")
+ assert(exGet.getCondition == "INVALID_HANDLE.SESSION_CLOSED")
val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key)
assert(sessionGetIfPresent.isEmpty)
@@ -102,7 +102,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
val exGet = intercept[SparkSQLException] {
SparkConnectService.sessionManager.getIsolatedSession(key, None)
}
- assert(exGet.getErrorClass == "INVALID_HANDLE.SESSION_NOT_FOUND")
+ assert(exGet.getCondition == "INVALID_HANDLE.SESSION_NOT_FOUND")
val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key)
assert(sessionGetIfPresent.isEmpty)
diff --git a/sql/connect/shims/README.md b/sql/connect/shims/README.md
new file mode 100644
index 0000000000000..07b593dd04b4b
--- /dev/null
+++ b/sql/connect/shims/README.md
@@ -0,0 +1 @@
+This module defines shims used by the interface defined in sql/api.
diff --git a/sql/connect/shims/pom.xml b/sql/connect/shims/pom.xml
new file mode 100644
index 0000000000000..d177b4a9971f5
--- /dev/null
+++ b/sql/connect/shims/pom.xml
@@ -0,0 +1,48 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.13</artifactId>
+    <version>4.0.0-SNAPSHOT</version>
+    <relativePath>../../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-connect-shims_2.13</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project Connect Shims</name>
+  <url>https://spark.apache.org/</url>
+  <properties>
+    <sbt.project.name>connect-shims</sbt.project.name>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+  </build>
+</project>
diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala
new file mode 100644
index 0000000000000..45fae00247485
--- /dev/null
+++ b/sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.api.java
+
+class JavaRDD[T]
diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala
new file mode 100644
index 0000000000000..b23f83fa9185c
--- /dev/null
+++ b/sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rdd
+
+class RDD[T]
diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala
new file mode 100644
index 0000000000000..813b8e4859c28
--- /dev/null
+++ b/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark
+
+class SparkContext
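These empty stubs exist so that code under sql/api can compile against core types without depending on spark-core; at runtime the real classes take their place (sql/core excludes the shims, below). A hypothetical consumer:

```scala
// Compiles against the shim's empty RDD[T]; in a classic Spark runtime the
// real org.apache.spark.rdd.RDD is on the classpath instead of the shim.
import org.apache.spark.rdd.RDD

trait DatasetLike[T] {
  def rdd: RDD[T] // signature only; no RDD methods are available via the shim
}
```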
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 972cf76d27535..16236940fe072 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -79,6 +79,12 @@
       <version>${project.version}</version>
       <type>test-jar</type>
       <scope>test</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
+        </exclusion>
+      </exclusions>