[GLUTEN-3450][CH] Support allowPrecisionLoss=false #3463

Merged 9 commits on Oct 24, 2023
Changes from all commits
@@ -203,4 +203,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {

override def shuffleSupportedCodec(): Set[String] = GLUTEN_CLICKHOUSE_SHUFFLE_SUPPORTED_CODEC
override def needOutputSchemaForPlan(): Boolean = true

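// CH can offload decimal arithmetic only when precision loss is disabled (GLUTEN-3450).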
override def allowDecimalArithmetic: Boolean = !SQLConf.get.decimalOperationsAllowPrecisionLoss
}
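
For context, a minimal usage sketch of what this override enables (assumes a SparkSession `spark` already running on the Gluten ClickHouse backend):

    // With precision loss disabled, decimal arithmetic can now be offloaded to CH.
    spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
    spark
      .sql("select cast(1.23456789 as decimal(18,8)) * cast(2 as decimal(18,8))")
      .show()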

@@ -19,9 +19,12 @@ package io.glutenproject.execution
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.{col, rand, when}
import org.apache.spark.sql.types._

import java.io.File
import java.util

case class DataTypesWithNonPrimitiveType(
string_field: String,
int_field: java.lang.Integer,

@@ -40,11 +43,17 @@ class GlutenClickHouseDecimalSuite
rootPath + "../../../../gluten-core/src/test/resources/tpch-queries"
override protected val queriesResults: String = rootPath + "queries-output"

-  override protected def createTPCHNullableTables(): Unit = {}

   override protected def createTPCHNotNullTables(): Unit = {}

-  override protected def sparkConf: SparkConf = super.sparkConf
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.shuffle.manager", "sort")
+      .set("spark.io.compression.codec", "snappy")
+      .set("spark.sql.shuffle.partitions", "5")
+      .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
+      .set("spark.gluten.sql.columnar.backend.ch.use.v2", "false")
+      .set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
+  }

override def beforeAll(): Unit = {
super.beforeAll()

@@ -54,6 +63,246 @@ class GlutenClickHouseDecimalSuite
}

private val decimalTable: String = "decimal_table"
private val decimalTPCHTables: Seq[DecimalType] = Seq(DecimalType(18, 8))

override protected val createNullableTables = true

override protected def createTPCHNullableTables(): Unit = {
decimalTPCHTables.foreach(createDecimalTables)
}

private def createDecimalTables(dataType: DecimalType): Unit = {
spark.sql(s"DROP database IF EXISTS decimal_${dataType.precision}_${dataType.scale}")
spark.sql(s"create database IF not EXISTS decimal_${dataType.precision}_${dataType.scale}")
spark.sql(s"use decimal_${dataType.precision}_${dataType.scale}")

// first process the parquet data to:
// 1. make every column nullable in schema (optional rather than required)
// 2. salt some null values randomly
val saltedTablesPath = tablesPath + s"-decimal_${dataType.precision}_${dataType.scale}"
withSQLConf(vanillaSparkConfs(): _*) {
Seq("customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier")
.map(
tableName => {
val originTablePath = tablesPath + "/" + tableName
spark.read.parquet(originTablePath).createTempView(tableName + "_ori")

val sql = tableName match {
case "customer" =>
s"""
|select
| c_custkey,c_name,c_address,c_nationkey,c_phone,
| cast(c_acctbal as decimal(${dataType.precision},${dataType.scale})),
| c_mktsegment,c_comment
|from ${tableName}_ori""".stripMargin
case "lineitem" =>
s"""
|select
| l_orderkey,l_partkey,l_suppkey,l_linenumber,
| cast(l_quantity as decimal(${dataType.precision},${dataType.scale})),
| cast(l_extendedprice as decimal(${dataType.precision},${dataType.scale})),
| cast(l_discount as decimal(${dataType.precision},${dataType.scale})),
| cast(l_tax as decimal(${dataType.precision},${dataType.scale})),
| l_returnflag,l_linestatus,l_shipdate,l_commitdate,l_receiptdate,
| l_shipinstruct,l_shipmode,l_comment
|from ${tableName}_ori """.stripMargin
case "orders" =>
s"""
|select
| o_orderkey,o_custkey,o_orderstatus,
| cast(o_totalprice as decimal(${dataType.precision},${dataType.scale})),
| o_orderdate,
| o_orderpriority,o_clerk,o_shippriority,o_comment
|from ${tableName}_ori
|""".stripMargin
case "part" =>
s"""
|select
| p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,
| cast(p_retailprice as decimal(${dataType.precision},${dataType.scale})),
| p_comment
|from ${tableName}_ori
|""".stripMargin
case "partsupp" =>
s"""
|select
| ps_partkey,ps_suppkey,ps_availqty,
| cast(ps_supplycost as decimal(${dataType.precision},${dataType.scale})),
| ps_comment
|from ${tableName}_ori
|""".stripMargin
case "supplier" =>
s"""
|select
| s_suppkey,s_name,s_address,s_nationkey,s_phone,
| cast(s_acctbal as decimal(${dataType.precision},${dataType.scale})),s_comment
|from ${tableName}_ori
|""".stripMargin
case _ => s"select * from ${tableName}_ori"
}

val df = spark.sql(sql).toDF()
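// Fold over every column, replacing roughly 10% of its values with nulls.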
var salted_df: Option[DataFrame] = None
for (c <- df.schema) {
salted_df = Some((salted_df match {
case Some(x) => x
case None => df
}).withColumn(c.name, when(rand() < 0.1, null).otherwise(col(c.name))))
}

val currentSaltedTablePath = saltedTablesPath + "/" + tableName
val file = new File(currentSaltedTablePath)
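// Note: File.delete() removes only files and empty directories; a stale
// salted directory from a previous run would need a recursive delete.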
if (file.exists()) {
file.delete()
}

salted_df.get.write.parquet(currentSaltedTablePath)
})
}

val customerData = saltedTablesPath + "/customer"
spark.sql(s"DROP TABLE IF EXISTS customer")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS customer (
| c_custkey bigint,
| c_name string,
| c_address string,
| c_nationkey bigint,
| c_phone string,
| c_acctbal decimal(${dataType.precision},${dataType.scale}),
| c_mktsegment string,
| c_comment string)
| USING PARQUET LOCATION '$customerData'
|""".stripMargin)

val lineitemData = saltedTablesPath + "/lineitem"
spark.sql(s"DROP TABLE IF EXISTS lineitem")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS lineitem (
| l_orderkey bigint,
| l_partkey bigint,
| l_suppkey bigint,
| l_linenumber bigint,
| l_quantity decimal(${dataType.precision},${dataType.scale}),
| l_extendedprice decimal(${dataType.precision},${dataType.scale}),
| l_discount decimal(${dataType.precision},${dataType.scale}),
| l_tax decimal(${dataType.precision},${dataType.scale}),
| l_returnflag string,
| l_linestatus string,
| l_shipdate date,
| l_commitdate date,
| l_receiptdate date,
| l_shipinstruct string,
| l_shipmode string,
| l_comment string)
| USING PARQUET LOCATION '$lineitemData'
|""".stripMargin)

val nationData = saltedTablesPath + "/nation"
spark.sql(s"DROP TABLE IF EXISTS nation")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS nation (
| n_nationkey bigint,
| n_name string,
| n_regionkey bigint,
| n_comment string)
| USING PARQUET LOCATION '$nationData'
|""".stripMargin)

val regionData = saltedTablesPath + "/region"
spark.sql(s"DROP TABLE IF EXISTS region")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS region (
| r_regionkey bigint,
| r_name string,
| r_comment string)
| USING PARQUET LOCATION '$regionData'
|""".stripMargin)

val ordersData = saltedTablesPath + "/orders"
spark.sql(s"DROP TABLE IF EXISTS orders")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS orders (
| o_orderkey bigint,
| o_custkey bigint,
| o_orderstatus string,
| o_totalprice decimal(${dataType.precision},${dataType.scale}),
| o_orderdate date,
| o_orderpriority string,
| o_clerk string,
| o_shippriority bigint,
| o_comment string)
| USING PARQUET LOCATION '$ordersData'
|""".stripMargin)

val partData = saltedTablesPath + "/part"
spark.sql(s"DROP TABLE IF EXISTS part")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS part (
| p_partkey bigint,
| p_name string,
| p_mfgr string,
| p_brand string,
| p_type string,
| p_size bigint,
| p_container string,
| p_retailprice decimal(${dataType.precision},${dataType.scale}),
| p_comment string)
| USING PARQUET LOCATION '$partData'
|""".stripMargin)

val partsuppData = saltedTablesPath + "/partsupp"
spark.sql(s"DROP TABLE IF EXISTS partsupp")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS partsupp (
| ps_partkey bigint,
| ps_suppkey bigint,
| ps_availqty bigint,
| ps_supplycost decimal(${dataType.precision},${dataType.scale}),
| ps_comment string)
| USING PARQUET LOCATION '$partsuppData'
|""".stripMargin)

val supplierData = saltedTablesPath + "/supplier"
spark.sql(s"DROP TABLE IF EXISTS supplier")
spark.sql(s"""
| CREATE TABLE IF NOT EXISTS supplier (
| s_suppkey bigint,
| s_name string,
| s_address string,
| s_nationkey bigint,
| s_phone string,
| s_acctbal decimal(${dataType.precision},${dataType.scale}),
| s_comment string)
| USING PARQUET LOCATION '$supplierData'
|""".stripMargin)

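// Expect 16 entries: the 8 salted tables plus the 8 *_ori temp views registered above.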
val result = spark
.sql(s"""
| show tables;
|""".stripMargin)
.collect()
assert(result.size == 16)
spark.sql(s"use default")
}

override protected def runTPCHQuery(
queryNum: Int,
tpchQueries: String = tpchQueries,
queriesResults: String = queriesResults,
compareResult: Boolean = true,
noFallBack: Boolean = true)(customCheck: DataFrame => Unit): Unit = {
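// Run the query once per decimal database, comparing results against vanilla Spark.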
decimalTPCHTables.foreach(
decimalType => {
spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
compareTPCHQueryAgainstVanillaSpark(queryNum, tpchQueries, customCheck, noFallBack)
spark.sql(s"use default")
})
}

test("TPCH Q20") {
runTPCHQuery(20)(_ => {})
}

test("fix decimal precision overflow") {
val sql =

@@ -176,6 +176,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
.set("spark.gluten.sql.enable.native.validation", "false")
.set("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")
.set("spark.sql.warehouse.dir", warehouse)
.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
/* .set("spark.sql.catalogImplementation", "hive")
.set("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=${
metaStorePathAbsolute + "/metastore_db"};create=true") */

@@ -520,6 +520,7 @@ abstract class GlutenClickHouseTPCHAbstractSuite
.set("spark.gluten.sql.enable.native.validation", "false")
.set("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")
.set("spark.sql.warehouse.dir", warehouse)
.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
/* .set("spark.sql.catalogImplementation", "hive")
.set("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=${
metaStorePathAbsolute + "/metastore_db"};create=true") */

@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.expression.UDFResolver
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import scala.util.control.Breaks.breakable

@@ -322,4 +323,6 @@ object BackendSettings extends BackendSettingsApi
override def requiredChildOrderingForWindow(): Boolean = true

override def staticPartitionWriteOnly(): Boolean = true

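// Velox offloads decimal arithmetic only when precision loss is allowed (Spark's default).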
override def allowDecimalArithmetic: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss
}
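
Taken together, the two backend overrides are complementary; a condensed view (names as in this diff):

    // CHBackendSettings.allowDecimalArithmetic = !SQLConf.get.decimalOperationsAllowPrecisionLoss
    // BackendSettings.allowDecimalArithmetic   =  SQLConf.get.decimalOperationsAllowPrecisionLoss
    // For any given setting, exactly one backend accepts decimal arithmetic.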

@@ -87,6 +87,8 @@ trait BackendSettingsApi {
/** Get the config prefix for each backend */
def getBackendConfigPrefix: String

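/** Whether the backend supports decimal arithmetic under the current precision-loss setting. */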
def allowDecimalArithmetic: Boolean = true

def rescaleDecimalIntegralExpression(): Boolean = false

def shuffleSupportedCodec(): Set[String]

@@ -402,9 +402,13 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformer(expr.children.head, attributeSeq),
expr)
case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
-        if (!conf.decimalOperationsAllowPrecisionLoss) {
+        // allowPrecisionLoss=true: supported by Velox but not by CH
+        // allowPrecisionLoss=false: supported by CH but not by Velox
+        // TODO: add CH support for allowPrecisionLoss=true
+        if (!BackendsApiManager.getSettings.allowDecimalArithmetic) {
           throw new UnsupportedOperationException(
-            s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} false mode")
+            s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " +
+              s"${conf.decimalOperationsAllowPrecisionLoss} mode")
}
val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) {
DecimalArithmeticUtil.rescaleLiteral(b)
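
When the active backend rejects the current mode, the guard above surfaces at expression-conversion time as an UnsupportedOperationException, rendered roughly as "Not support spark.sql.decimalOperations.allowPrecisionLoss true mode" (illustrative: the CH backend with Spark's default setting).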

@@ -117,6 +117,7 @@ trait GlutenTestsTrait extends GlutenTestsCommonTrait {
.config(GlutenConfig.GLUTEN_LIB_PATH, SystemParameters.getClickHouseLibPath)
.config("spark.unsafe.exceptionOnMemoryLeak", "true")
.config(GlutenConfig.UT_STATISTIC.key, "true")
.config("spark.sql.decimalOperations.allowPrecisionLoss", "false")
.getOrCreate()
} else {
sparkBuilder