diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala
index b7774575c1d1..4714215343de 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala
@@ -203,4 +203,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
   override def shuffleSupportedCodec(): Set[String] = GLUTEN_CLICKHOUSE_SHUFFLE_SUPPORTED_CODEC
 
   override def needOutputSchemaForPlan(): Boolean = true
+
+  override def allowDecimalArithmetic: Boolean = !SQLConf.get.decimalOperationsAllowPrecisionLoss
 }
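Read together with the Velox change later in this diff, the two overrides form a complementary support matrix: ClickHouse offloads decimal arithmetic only when spark.sql.decimalOperations.allowPrecisionLoss is false, Velox only when it is true (the default). A minimal standalone sketch of that matrix; the trait and object names here are simplified placeholders, not the real Gluten classes:

    import org.apache.spark.sql.internal.SQLConf

    trait Settings { def allowDecimalArithmetic: Boolean }

    // ClickHouse-style: offload decimal arithmetic only when precision loss is disabled.
    object ChSettings extends Settings {
      override def allowDecimalArithmetic: Boolean =
        !SQLConf.get.decimalOperationsAllowPrecisionLoss
    }

    // Velox-style: offload decimal arithmetic only in the default precision-loss mode.
    object VxSettings extends Settings {
      override def allowDecimalArithmetic: Boolean =
        SQLConf.get.decimalOperationsAllowPrecisionLoss
    }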
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseDecimalSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseDecimalSuite.scala
index 698ffaeaacc3..d7b4a0c57d48 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseDecimalSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseDecimalSuite.scala
@@ -19,9 +19,12 @@ package io.glutenproject.execution
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions.{col, rand, when}
 import org.apache.spark.sql.types._
 
+import java.io.File
 import java.util
+
 case class DataTypesWithNonPrimitiveType(
     string_field: String,
     int_field: java.lang.Integer,
@@ -40,11 +43,17 @@ class GlutenClickHouseDecimalSuite
     rootPath + "../../../../gluten-core/src/test/resources/tpch-queries"
   override protected val queriesResults: String = rootPath + "queries-output"
 
-  override protected def createTPCHNullableTables(): Unit = {}
-
-  override protected def createTPCHNotNullTables(): Unit = {}
-
-  override protected def sparkConf: SparkConf = super.sparkConf
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.shuffle.manager", "sort")
+      .set("spark.io.compression.codec", "snappy")
+      .set("spark.sql.shuffle.partitions", "5")
+      .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
+      .set("spark.gluten.sql.columnar.backend.ch.use.v2", "false")
+      .set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
+  }
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -54,6 +63,246 @@ class GlutenClickHouseDecimalSuite
   }
 
   private val decimalTable: String = "decimal_table"
+  private val decimalTPCHTables: Seq[DecimalType] = Seq.apply(DecimalType.apply(18, 8))
+
+  override protected val createNullableTables = true
+
+  override protected def createTPCHNullableTables(): Unit = {
+    decimalTPCHTables.foreach(createDecimalTables)
+  }
+
+  private def createDecimalTables(dataType: DecimalType): Unit = {
+    spark.sql(s"DROP DATABASE IF EXISTS decimal_${dataType.precision}_${dataType.scale}")
+    spark.sql(s"CREATE DATABASE IF NOT EXISTS decimal_${dataType.precision}_${dataType.scale}")
+    spark.sql(s"use decimal_${dataType.precision}_${dataType.scale}")
+
+    // First rewrite the parquet data to:
+    // 1. make every column nullable in the schema (optional rather than required)
+    // 2. salt in some null values at random
+    val saltedTablesPath = tablesPath + s"-decimal_${dataType.precision}_${dataType.scale}"
+    withSQLConf(vanillaSparkConfs(): _*) {
+      Seq("customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier")
+        .map(
+          tableName => {
+            val originTablePath = tablesPath + "/" + tableName
+            spark.read.parquet(originTablePath).createTempView(tableName + "_ori")
+
+            val sql = tableName match {
+              case "customer" =>
+                s"""
+                   |select
+                   |  c_custkey,c_name,c_address,c_nationkey,c_phone,
+                   |  cast(c_acctbal as decimal(${dataType.precision},${dataType.scale})),
+                   |  c_mktsegment,c_comment
+                   |from ${tableName}_ori""".stripMargin
+              case "lineitem" =>
+                s"""
+                   |select
+                   |  l_orderkey,l_partkey,l_suppkey,l_linenumber,
+                   |  cast(l_quantity as decimal(${dataType.precision},${dataType.scale})),
+                   |  cast(l_extendedprice as decimal(${dataType.precision},${dataType.scale})),
+                   |  cast(l_discount as decimal(${dataType.precision},${dataType.scale})),
+                   |  cast(l_tax as decimal(${dataType.precision},${dataType.scale})),
+                   |  l_returnflag,l_linestatus,l_shipdate,l_commitdate,l_receiptdate,
+                   |  l_shipinstruct,l_shipmode,l_comment
+                   |from ${tableName}_ori""".stripMargin
+              case "orders" =>
+                s"""
+                   |select
+                   |  o_orderkey,o_custkey,o_orderstatus,
+                   |  cast(o_totalprice as decimal(${dataType.precision},${dataType.scale})),
+                   |  o_orderdate,
+                   |  o_orderpriority,o_clerk,o_shippriority,o_comment
+                   |from ${tableName}_ori
+                   |""".stripMargin
+              case "part" =>
+                s"""
+                   |select
+                   |  p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,
+                   |  cast(p_retailprice as decimal(${dataType.precision},${dataType.scale})),
+                   |  p_comment
+                   |from ${tableName}_ori
+                   |""".stripMargin
+              case "partsupp" =>
+                s"""
+                   |select
+                   |  ps_partkey,ps_suppkey,ps_availqty,
+                   |  cast(ps_supplycost as decimal(${dataType.precision},${dataType.scale})),
+                   |  ps_comment
+                   |from ${tableName}_ori
+                   |""".stripMargin
+              case "supplier" =>
+                s"""
+                   |select
+                   |  s_suppkey,s_name,s_address,s_nationkey,s_phone,
+                   |  cast(s_acctbal as decimal(${dataType.precision},${dataType.scale})),s_comment
+                   |from ${tableName}_ori
+                   |""".stripMargin
+              case _ => s"select * from ${tableName}_ori"
+            }
+
+            val df = spark.sql(sql).toDF()
+            var salted_df: Option[DataFrame] = None
+            for (c <- df.schema) {
+              salted_df = Some(
+                (salted_df match {
+                  case Some(x) => x
+                  case None => df
+                }).withColumn(c.name, when(rand() < 0.1, null).otherwise(col(c.name))))
+            }
+
+            val currentSaltedTablePath = saltedTablesPath + "/" + tableName
+            val file = new File(currentSaltedTablePath)
+            if (file.exists()) {
+              file.delete()
+            }
+
+            salted_df.get.write.parquet(currentSaltedTablePath)
+          })
+    }
+
+    val customerData = saltedTablesPath + "/customer"
+    spark.sql(s"DROP TABLE IF EXISTS customer")
+    spark.sql(s"""
+                 | CREATE TABLE IF NOT EXISTS customer (
+                 | c_custkey bigint,
+                 | c_name string,
+                 | c_address string,
+                 | c_nationkey bigint,
+                 | c_phone string,
+                 | c_acctbal decimal(${dataType.precision},${dataType.scale}),
+                 | c_mktsegment string,
+                 | c_comment string)
+                 | USING PARQUET LOCATION '$customerData'
+                 |""".stripMargin)
+
+    val lineitemData = saltedTablesPath + "/lineitem"
+    spark.sql(s"DROP TABLE IF EXISTS lineitem")
+    spark.sql(s"""
+                 | CREATE TABLE IF NOT EXISTS lineitem (
+                 | l_orderkey bigint,
+                 | l_partkey bigint,
+                 | l_suppkey bigint,
+                 | l_linenumber bigint,
+                 | l_quantity decimal(${dataType.precision},${dataType.scale}),
+                 | l_extendedprice decimal(${dataType.precision},${dataType.scale}),
+                 | l_discount decimal(${dataType.precision},${dataType.scale}),
+                 | l_tax decimal(${dataType.precision},${dataType.scale}),
+                 | l_returnflag string,
+                 | l_linestatus string,
+                 | l_shipdate date,
+                 | l_commitdate date,
+                 | l_receiptdate date,
+                 | l_shipinstruct string,
+                 | l_shipmode string,
+                 | l_comment string)
+                 | USING PARQUET LOCATION '$lineitemData'
+                 |""".stripMargin)
+
+    val nationData = saltedTablesPath + "/nation"
+    spark.sql(s"DROP TABLE IF EXISTS nation")
+    spark.sql(s"""
+                 | CREATE TABLE IF NOT EXISTS nation (
+                 | n_nationkey bigint,
+                 | n_name string,
+                 | n_regionkey bigint,
+                 | n_comment string)
+                 | USING PARQUET LOCATION '$nationData'
+                 |""".stripMargin)
+
+    val regionData = saltedTablesPath + "/region"
+    spark.sql(s"DROP TABLE IF EXISTS region")
+    spark.sql(s"""
+                 | CREATE TABLE IF NOT EXISTS region (
+                 | r_regionkey bigint,
+                 | r_name string,
+                 | r_comment string)
+                 | USING PARQUET LOCATION '$regionData'
+                 |""".stripMargin)
+
+    val ordersData = saltedTablesPath + "/orders"
+    spark.sql(s"DROP TABLE IF EXISTS orders")
+    spark.sql(s"""
+                 | CREATE TABLE IF NOT EXISTS orders (
+                 | o_orderkey bigint,
+                 | o_custkey bigint,
+                 | o_orderstatus string,
+                 | o_totalprice decimal(${dataType.precision},${dataType.scale}),
+                 | o_orderdate date,
+                 | o_orderpriority string,
+                 | o_clerk string,
+                 | o_shippriority bigint,
+                 | o_comment string)
+                 | USING PARQUET LOCATION '$ordersData'
+                 |""".stripMargin)
+
+    val partData = saltedTablesPath + "/part"
+    spark.sql(s"DROP TABLE IF EXISTS part")
+    spark.sql(s"""
+                 | CREATE TABLE IF NOT EXISTS part (
+                 | p_partkey bigint,
+                 | p_name string,
+                 | p_mfgr string,
+                 | p_brand string,
+                 | p_type string,
+                 | p_size bigint,
+                 | p_container string,
+                 | p_retailprice decimal(${dataType.precision},${dataType.scale}),
+                 | p_comment string)
+                 | USING PARQUET LOCATION '$partData'
+                 |""".stripMargin)
+
+    val partsuppData = saltedTablesPath + "/partsupp"
+    spark.sql(s"DROP TABLE IF EXISTS partsupp")
+    spark.sql(s"""
+                 | CREATE TABLE IF NOT EXISTS partsupp (
+                 | ps_partkey bigint,
+                 | ps_suppkey bigint,
+                 | ps_availqty bigint,
+                 | ps_supplycost decimal(${dataType.precision},${dataType.scale}),
+                 | ps_comment string)
+                 | USING PARQUET LOCATION '$partsuppData'
+                 |""".stripMargin)
+
+    val supplierData = saltedTablesPath + "/supplier"
+    spark.sql(s"DROP TABLE IF EXISTS supplier")
+    spark.sql(s"""
+                 | CREATE TABLE IF NOT EXISTS supplier (
+                 | s_suppkey bigint,
+                 | s_name string,
+                 | s_address string,
+                 | s_nationkey bigint,
+                 | s_phone string,
+                 | s_acctbal decimal(${dataType.precision},${dataType.scale}),
+                 | s_comment string)
+                 | USING PARQUET LOCATION '$supplierData'
+                 |""".stripMargin)
+
+    val result = spark
+      .sql(s"""
+              | show tables;
+              |""".stripMargin)
+      .collect()
+    assert(result.size == 16)
+    spark.sql(s"use default")
+  }
+
+  override protected def runTPCHQuery(
+      queryNum: Int,
+      tpchQueries: String = tpchQueries,
+      queriesResults: String = queriesResults,
+      compareResult: Boolean = true,
+      noFallBack: Boolean = true)(customCheck: DataFrame => Unit): Unit = {
+    decimalTPCHTables.foreach(
+      decimalType => {
+        spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
+        compareTPCHQueryAgainstVanillaSpark(queryNum, tpchQueries, customCheck, noFallBack)
+        spark.sql(s"use default")
+      })
+  }
+
+  test("TPCH Q20") {
+    runTPCHQuery(20)(_ => {})
+  }
 
   test("fix decimal precision overflow") {
     val sql =
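The null-salting loop in createDecimalTables threads an Option[DataFrame] through a for loop; the same transformation reads more directly as a foldLeft over the schema. A sketch of that equivalent formulation — saltNulls is a hypothetical helper name, not part of this change:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, rand, when}

    // Replace roughly `fraction` of each column's values with null; rewriting
    // every column this way also makes each column nullable in the output schema.
    def saltNulls(df: DataFrame, fraction: Double = 0.1): DataFrame =
      df.schema.fieldNames.foldLeft(df) {
        (acc, name) => acc.withColumn(name, when(rand() < fraction, null).otherwise(col(name)))
      }

The assert(result.size == 16) above counts the eight tables created in the decimal database plus the eight *_ori temp views registered during salting, since SHOW TABLES lists session temp views alongside the current database's tables.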
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index 949622193235..4eee83c6e997 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -176,6 +176,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
       .set("spark.gluten.sql.enable.native.validation", "false")
       .set("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")
       .set("spark.sql.warehouse.dir", warehouse)
+      .set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
     /* .set("spark.sql.catalogImplementation", "hive")
     .set("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=${
         metaStorePathAbsolute + "/metastore_db"};create=true") */
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHAbstractSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHAbstractSuite.scala
index 6e1bc17fce48..79fdc0349dc8 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHAbstractSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHAbstractSuite.scala
@@ -520,6 +520,7 @@ abstract class GlutenClickHouseTPCHAbstractSuite
       .set("spark.gluten.sql.enable.native.validation", "false")
       .set("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")
       .set("spark.sql.warehouse.dir", warehouse)
+      .set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
     /* .set("spark.sql.catalogImplementation", "hive")
     .set("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=${
         metaStorePathAbsolute + "/metastore_db"};create=true") */
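Both abstract suites pin the flag suite-wide, so every ClickHouse test plans in the one mode CH currently supports. An individual test could still flip it locally with Spark's standard test helper; a sketch (such a query would be expected to hit the new gate and fall back rather than offload on CH):

    withSQLConf("spark.sql.decimalOperations.allowPrecisionLoss" -> "true") {
      // a decimal multiply/divide planned here should trip the
      // allowDecimalArithmetic check and fall back to vanilla Spark
    }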
diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
index 2790820dd8da..e94efb536c76 100644
--- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.expression.UDFResolver
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import scala.util.control.Breaks.breakable
@@ -322,4 +323,6 @@ object BackendSettings extends BackendSettingsApi {
   override def requiredChildOrderingForWindow(): Boolean = true
 
   override def staticPartitionWriteOnly(): Boolean = true
+
+  override def allowDecimalArithmetic: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss
 }
diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala
index 461faccdb13f..502ccc1669e3 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala
@@ -87,6 +87,8 @@ trait BackendSettingsApi {
   /** Get the config prefix for each backend */
   def getBackendConfigPrefix: String
 
+  def allowDecimalArithmetic: Boolean = true
+
   def rescaleDecimalIntegralExpression(): Boolean = false
 
   def shuffleSupportedCodec(): Set[String]
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
index 37b137da028f..ff13709613a9 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
@@ -402,9 +402,13 @@ object ExpressionConverter extends SQLConfHelper with Logging {
         replaceWithExpressionTransformer(expr.children.head, attributeSeq),
         expr)
     case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
-      if (!conf.decimalOperationsAllowPrecisionLoss) {
+      // allowPrecisionLoss = true: supported by Velox, not by CH.
+      // allowPrecisionLoss = false: supported by CH, not by Velox.
+      // TODO: support allowPrecisionLoss = true in CH.
+      if (!BackendsApiManager.getSettings.allowDecimalArithmetic) {
         throw new UnsupportedOperationException(
-          s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} false mode")
+          s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " +
+            s"${conf.decimalOperationsAllowPrecisionLoss} mode")
       }
       val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) {
         DecimalArithmeticUtil.rescaleLiteral(b)
diff --git a/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsTrait.scala b/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsTrait.scala
index 7ca393f7666b..fa8c146d35a6 100644
--- a/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsTrait.scala
+++ b/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsTrait.scala
@@ -117,6 +117,7 @@ trait GlutenTestsTrait extends GlutenTestsCommonTrait {
         .config(GlutenConfig.GLUTEN_LIB_PATH, SystemParameters.getClickHouseLibPath)
         .config("spark.unsafe.exceptionOnMemoryLeak", "true")
         .config(GlutenConfig.UT_STATISTIC.key, "true")
+        .config("spark.sql.decimalOperations.allowPrecisionLoss", "false")
         .getOrCreate()
     } else {
       sparkBuilder
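For reference, the result-type difference behind the whole matrix, worked through Spark's decimal rules (this is vanilla Spark behavior, not Gluten code): decimal(38,18) * decimal(38,18) has the exact result type decimal(77,36). With allowPrecisionLoss=true Spark trims the scale to fit 38 digits and plans decimal(38,6); with false it keeps the scale, plans decimal(38,36), and yields NULL when the value overflows. Quick to check in spark-shell:

    spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
    spark.sql("select cast(1 as decimal(38,18)) * cast(1 as decimal(38,18))").printSchema()
    // expect decimal(38,36) here; with the flag at its default of true, decimal(38,6)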