[GLUTEN-3450][CH] Support allowPrecisionLoss=false (#3463)
Support allowPrecisionLoss=false
loneylee authored Oct 24, 2023
1 parent 2aaf07b commit 7d5e8fb
Showing 8 changed files with 268 additions and 5 deletions.
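
For context, spark.sql.decimalOperations.allowPrecisionLoss is the vanilla Spark flag that decides which result type decimal arithmetic gets when the exact result would not fit in decimal(38, _). A minimal sketch of the difference, outside Gluten (the adjusted types in the comments follow Spark 3.x's DecimalPrecision rules as I recall them, so treat them as approximate):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()

// Default: precision loss allowed, fractional digits are dropped so the result still fits.
spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "true")
spark.sql("select cast(1 as decimal(38,18)) / cast(3 as decimal(38,18))").printSchema()
// ~ decimal(38,6)

// The mode this commit adds support for on the ClickHouse backend:
// the full scale is kept, and results that no longer fit become null.
spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
spark.sql("select cast(1 as decimal(38,18)) / cast(3 as decimal(38,18))").printSchema()
// ~ decimal(38,38)
```

The hunks below wire that flag into the backend capability checks and into the test suites.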

@@ -203,4 +203,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {

override def shuffleSupportedCodec(): Set[String] = GLUTEN_CLICKHOUSE_SHUFFLE_SUPPORTED_CODEC
override def needOutputSchemaForPlan(): Boolean = true

override def allowDecimalArithmetic: Boolean = !SQLConf.get.decimalOperationsAllowPrecisionLoss
}
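
The Velox backend gets the mirror-image override further down in this commit. A standalone sketch of the relationship the two overrides encode (plain Scala with made-up helper names, not code from the patch):

```scala
import org.apache.spark.sql.internal.SQLConf

// Made-up helpers; the bodies mirror the CH and Velox overrides in this commit.
def chAllowsDecimalArithmetic(conf: SQLConf): Boolean =
  !conf.decimalOperationsAllowPrecisionLoss

def veloxAllowsDecimalArithmetic(conf: SQLConf): Boolean =
  conf.decimalOperationsAllowPrecisionLoss

// For any session conf, exactly one of the two backends reports decimal
// arithmetic as supported:
//   chAllowsDecimalArithmetic(c) != veloxAllowsDecimalArithmetic(c)
```

The comment added in ExpressionConverter below records the same support matrix, with a TODO to eventually support allowPrecisionLoss=true on CH as well.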

@@ -19,9 +19,12 @@ package io.glutenproject.execution
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.{col, rand, when}
import org.apache.spark.sql.types._

import java.io.File
import java.util

case class DataTypesWithNonPrimitiveType(
string_field: String,
int_field: java.lang.Integer,
@@ -40,11 +43,17 @@ class GlutenClickHouseDecimalSuite
rootPath + "../../../../gluten-core/src/test/resources/tpch-queries"
override protected val queriesResults: String = rootPath + "queries-output"

-  override protected def createTPCHNullableTables(): Unit = {}

override protected def createTPCHNotNullTables(): Unit = {}

-  override protected def sparkConf: SparkConf = super.sparkConf
override protected def sparkConf: SparkConf = {
super.sparkConf
.set("spark.shuffle.manager", "sort")
.set("spark.io.compression.codec", "snappy")
.set("spark.sql.shuffle.partitions", "5")
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.gluten.sql.columnar.backend.ch.use.v2", "false")
.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
}

override def beforeAll(): Unit = {
super.beforeAll()
@@ -54,6 +63,246 @@ class GlutenClickHouseDecimalSuite
}

private val decimalTable: String = "decimal_table"
private val decimalTPCHTables: Seq[DecimalType] = Seq.apply(DecimalType.apply(18, 8))

override protected val createNullableTables = true

override protected def createTPCHNullableTables(): Unit = {
decimalTPCHTables.foreach(createDecimalTables)
}

private def createDecimalTables(dataType: DecimalType): Unit = {
spark.sql(s"DROP database IF EXISTS decimal_${dataType.precision}_${dataType.scale}")
spark.sql(s"create database IF not EXISTS decimal_${dataType.precision}_${dataType.scale}")
spark.sql(s"use decimal_${dataType.precision}_${dataType.scale}")

// first process the parquet data to:
// 1. make every column nullable in schema (optional rather than required)
// 2. salt some null values randomly
val saltedTablesPath = tablesPath + s"-decimal_${dataType.precision}_${dataType.scale}"
withSQLConf(vanillaSparkConfs(): _*) {
Seq("customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier")
.map(
tableName => {
val originTablePath = tablesPath + "/" + tableName
spark.read.parquet(originTablePath).createTempView(tableName + "_ori")

val sql = tableName match {
case "customer" =>
s"""
|select
| c_custkey,c_name,c_address,c_nationkey,c_phone,
| cast(c_acctbal as decimal(${dataType.precision},${dataType.scale})),
| c_mktsegment,c_comment
|from ${tableName}_ori""".stripMargin
case "lineitem" =>
s"""
|select
| l_orderkey,l_partkey,l_suppkey,l_linenumber,
| cast(l_quantity as decimal(${dataType.precision},${dataType.scale})),
| cast(l_extendedprice as decimal(${dataType.precision},${dataType.scale})),
| cast(l_discount as decimal(${dataType.precision},${dataType.scale})),
| cast(l_tax as decimal(${dataType.precision},${dataType.scale})),
| l_returnflag,l_linestatus,l_shipdate,l_commitdate,l_receiptdate,
| l_shipinstruct,l_shipmode,l_comment
|from ${tableName}_ori """.stripMargin
case "orders" =>
s"""
|select
| o_orderkey,o_custkey,o_orderstatus,
| cast(o_totalprice as decimal(${dataType.precision},${dataType.scale})),
| o_orderdate,
| o_orderpriority,o_clerk,o_shippriority,o_comment
|from ${tableName}_ori
|""".stripMargin
case "part" =>
s"""
|select
| p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,
| cast(p_retailprice as decimal(${dataType.precision},${dataType.scale})),
| p_comment
|from ${tableName}_ori
|""".stripMargin
case "partsupp" =>
s"""
|select
| ps_partkey,ps_suppkey,ps_availqty,
| cast(ps_supplycost as decimal(${dataType.precision},${dataType.scale})),
| ps_comment
|from ${tableName}_ori
|""".stripMargin
case "supplier" =>
s"""
|select
| s_suppkey,s_name,s_address,s_nationkey,s_phone,
| cast(s_acctbal as decimal(${dataType.precision},${dataType.scale})),s_comment
|from ${tableName}_ori
|""".stripMargin
case _ => s"select * from ${tableName}_ori"
}

val df = spark.sql(sql).toDF()
var salted_df: Option[DataFrame] = None
for (c <- df.schema) {
salted_df = Some((salted_df match {
case Some(x) => x
case None => df
}).withColumn(c.name, when(rand() < 0.1, null).otherwise(col(c.name))))
}

val currentSaltedTablePath = saltedTablesPath + "/" + tableName
val file = new File(currentSaltedTablePath)
if (file.exists()) {
file.delete()
}

salted_df.get.write.parquet(currentSaltedTablePath)
})
}

val customerData = saltedTablesPath + "/customer"
spark.sql(s"DROP TABLE IF EXISTS customer")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS customer (
| c_custkey bigint,
| c_name string,
| c_address string,
| c_nationkey bigint,
| c_phone string,
| c_acctbal decimal(${dataType.precision},${dataType.scale}),
| c_mktsegment string,
| c_comment string)
| USING PARQUET LOCATION '$customerData'
|""".stripMargin)

val lineitemData = saltedTablesPath + "/lineitem"
spark.sql(s"DROP TABLE IF EXISTS lineitem")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS lineitem (
| l_orderkey bigint,
| l_partkey bigint,
| l_suppkey bigint,
| l_linenumber bigint,
| l_quantity decimal(${dataType.precision},${dataType.scale}),
| l_extendedprice decimal(${dataType.precision},${dataType.scale}),
| l_discount decimal(${dataType.precision},${dataType.scale}),
| l_tax decimal(${dataType.precision},${dataType.scale}),
| l_returnflag string,
| l_linestatus string,
| l_shipdate date,
| l_commitdate date,
| l_receiptdate date,
| l_shipinstruct string,
| l_shipmode string,
| l_comment string)
| USING PARQUET LOCATION '$lineitemData'
|""".stripMargin)

val nationData = saltedTablesPath + "/nation"
spark.sql(s"DROP TABLE IF EXISTS nation")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS nation (
| n_nationkey bigint,
| n_name string,
| n_regionkey bigint,
| n_comment string)
| USING PARQUET LOCATION '$nationData'
|""".stripMargin)

val regionData = saltedTablesPath + "/region"
spark.sql(s"DROP TABLE IF EXISTS region")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS region (
| r_regionkey bigint,
| r_name string,
| r_comment string)
| USING PARQUET LOCATION '$regionData'
|""".stripMargin)

val ordersData = saltedTablesPath + "/orders"
spark.sql(s"DROP TABLE IF EXISTS orders")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS orders (
| o_orderkey bigint,
| o_custkey bigint,
| o_orderstatus string,
| o_totalprice decimal(${dataType.precision},${dataType.scale}),
| o_orderdate date,
| o_orderpriority string,
| o_clerk string,
| o_shippriority bigint,
| o_comment string)
| USING PARQUET LOCATION '$ordersData'
|""".stripMargin)

val partData = saltedTablesPath + "/part"
spark.sql(s"DROP TABLE IF EXISTS part")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS part (
| p_partkey bigint,
| p_name string,
| p_mfgr string,
| p_brand string,
| p_type string,
| p_size bigint,
| p_container string,
| p_retailprice decimal(${dataType.precision},${dataType.scale}),
| p_comment string)
| USING PARQUET LOCATION '$partData'
|""".stripMargin)

val partsuppData = saltedTablesPath + "/partsupp"
spark.sql(s"DROP TABLE IF EXISTS partsupp")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS partsupp (
| ps_partkey bigint,
| ps_suppkey bigint,
| ps_availqty bigint,
| ps_supplycost decimal(${dataType.precision},${dataType.scale}),
| ps_comment string)
| USING PARQUET LOCATION '$partsuppData'
|""".stripMargin)

val supplierData = saltedTablesPath + "/supplier"
spark.sql(s"DROP TABLE IF EXISTS supplier")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS supplier (
| s_suppkey bigint,
| s_name string,
| s_address string,
| s_nationkey bigint,
| s_phone string,
| s_acctbal decimal(${dataType.precision},${dataType.scale}),
| s_comment string)
| USING PARQUET LOCATION '$supplierData'
|""".stripMargin)

val result = spark
.sql(s"""
| show tables;
|""".stripMargin)
.collect()
assert(result.size == 16)
spark.sql(s"use default")
}

override protected def runTPCHQuery(
queryNum: Int,
tpchQueries: String = tpchQueries,
queriesResults: String = queriesResults,
compareResult: Boolean = true,
noFallBack: Boolean = true)(customCheck: DataFrame => Unit): Unit = {
decimalTPCHTables.foreach(
decimalType => {
spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
compareTPCHQueryAgainstVanillaSpark(queryNum, tpchQueries, customCheck, noFallBack)
spark.sql(s"use default")
})
}

test("TPCH Q20") {
runTPCHQuery(20)(_ => {})
}

test("fix decimal precision overflow") {
val sql =

@@ -176,6 +176,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
.set("spark.gluten.sql.enable.native.validation", "false")
.set("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")
.set("spark.sql.warehouse.dir", warehouse)
.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
/* .set("spark.sql.catalogImplementation", "hive")
.set("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=${
metaStorePathAbsolute + "/metastore_db"};create=true") */

@@ -520,6 +520,7 @@ abstract class GlutenClickHouseTPCHAbstractSuite
.set("spark.gluten.sql.enable.native.validation", "false")
.set("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")
.set("spark.sql.warehouse.dir", warehouse)
.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
/* .set("spark.sql.catalogImplementation", "hive")
.set("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=${
metaStorePathAbsolute + "/metastore_db"};create=true") */

@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.expression.UDFResolver
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import scala.util.control.Breaks.breakable
@@ -322,4 +323,6 @@ object BackendSettings extends BackendSettingsApi {
override def requiredChildOrderingForWindow(): Boolean = true

override def staticPartitionWriteOnly(): Boolean = true

override def allowDecimalArithmetic: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss
}

@@ -87,6 +87,8 @@ trait BackendSettingsApi {
/** Get the config prefix for each backend */
def getBackendConfigPrefix: String

def allowDecimalArithmetic: Boolean = true

def rescaleDecimalIntegralExpression(): Boolean = false

def shuffleSupportedCodec(): Set[String]

@@ -402,9 +402,13 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformer(expr.children.head, attributeSeq),
expr)
case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
- if (!conf.decimalOperationsAllowPrecisionLoss) {
// PrecisionLoss=true: velox support / ch not support
// PrecisionLoss=false: velox not support / ch support
// TODO ch support PrecisionLoss=true
if (!BackendsApiManager.getSettings.allowDecimalArithmetic) {
throw new UnsupportedOperationException(
s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} false mode")
s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " +
s"${conf.decimalOperationsAllowPrecisionLoss} mode")
}
val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) {
DecimalArithmeticUtil.rescaleLiteral(b)
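
As a rough usage sketch of what the new check means at query time (spark-shell style; it assumes a session that already runs the ClickHouse backend, uses the column names of the suite's lineitem table, and the claim that the thrown UnsupportedOperationException ends in a fallback to vanilla Spark is my reading of the validation flow, not something this diff states):

```scala
// With allowPrecisionLoss=false, decimal arithmetic like this can now be
// offloaded to ClickHouse; with the default (true) the converter above throws
// and the operator is expected to fall back to vanilla Spark.
spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
spark
  .sql("select l_extendedprice * (1 - l_discount) as revenue from lineitem limit 5")
  .show()
```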

@@ -117,6 +117,7 @@ trait GlutenTestsTrait extends GlutenTestsCommonTrait {
.config(GlutenConfig.GLUTEN_LIB_PATH, SystemParameters.getClickHouseLibPath)
.config("spark.unsafe.exceptionOnMemoryLeak", "true")
.config(GlutenConfig.UT_STATISTIC.key, "true")
.config("spark.sql.decimalOperations.allowPrecisionLoss", "false")
.getOrCreate()
} else {
sparkBuilder