Skip to content

Commit

Permalink
adding param allowPrecisionLoss
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Zhou <[email protected]>
  • Loading branch information
zhouyuan committed May 6, 2024
1 parent aa8edf9 commit 201adf1
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 13 deletions.
7 changes: 7 additions & 0 deletions velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class QueryConfig {
static constexpr const char* kCastMatchStructByName =
"cast_match_struct_by_name";

// This flags forces to bound the decimal precision.
static constexpr const char* kAllowPrecisionLoss = "allow_precision_loss";

/// Used for backpressure to block local exchange producers when the local
/// exchange buffer reaches or exceeds this size.
static constexpr const char* kMaxLocalExchangeBufferSize =
Expand Down Expand Up @@ -474,6 +477,10 @@ class QueryConfig {
return get<bool>(kCastMatchStructByName, false);
}

bool isAllowPrecisionLoss() const {
return get<bool>(kAllowPrecisionLoss, true);
}

bool adjustTimestampToTimezone() const {
return get<bool>(kAdjustTimestampToTimezone, false);
}
Expand Down
46 changes: 33 additions & 13 deletions velox/functions/sparksql/DecimalArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,14 @@ class Addition {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
uint8_t bScale,
bool allowPrecisionLoss) {
auto precision = std::max(aPrecision - aScale, bPrecision - bScale) +
std::max(aScale, bScale) + 1;
auto scale = std::max(aScale, bScale);
return DecimalUtil::adjustPrecisionScale(precision, scale);
return allowPrecisionLoss
? DecimalUtil::adjustPrecisionScale(precision, scale)
: DecimalUtil::bounded(precision, scale);
}
};

Expand Down Expand Up @@ -464,9 +467,10 @@ class Subtraction {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
uint8_t bScale,
bool allowPrecisionLoss) {
return Addition::computeResultPrecisionScale(
aPrecision, aScale, bPrecision, bScale);
aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss);
}
};

Expand Down Expand Up @@ -566,9 +570,12 @@ class Multiply {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
return DecimalUtil::adjustPrecisionScale(
aPrecision + bPrecision + 1, aScale + bScale);
uint8_t bScale,
const bool allowPrecisionLoss) {
return allowPrecisionLoss
? DecimalUtil::adjustPrecisionScale(
aPrecision + bPrecision + 1, aScale + bScale)
: DecimalUtil::bounded(aPrecision + bPrecision + 1, aScale + bScale);
}

private:
Expand Down Expand Up @@ -616,10 +623,22 @@ class Divide {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
auto scale = std::max(6, aScale + bPrecision + 1);
auto precision = aPrecision - aScale + bScale + scale;
return DecimalUtil::adjustPrecisionScale(precision, scale);
uint8_t bScale,
bool allowPrecisionLoss) {
if (allowPrecisionLoss) {
auto scale = std::max(6, aScale + bPrecision + 1);
auto precision = aPrecision - aScale + bScale + scale;
return DecimalUtil::adjustPrecisionScale(precision, scale);
} else {
auto intDig = std::min(38, aPrecision - aScale + bScale);
auto decDig = std::min(38, std::max(6, aScale + bPrecision + 1));
auto diff = (intDig + decDig) - 38;
if (diff > 0) {
decDig -= diff / 2 + 1;
intDig = 38 - decDig;
}
return DecimalUtil::bounded(intDig + decDig, decDig);
}
}
};

Expand Down Expand Up @@ -689,13 +708,14 @@ template <typename Operation>
std::shared_ptr<exec::VectorFunction> createDecimalFunction(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
const core::QueryConfig& config) {
const auto& aType = inputArgs[0].type;
const auto& bType = inputArgs[1].type;
const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
const bool allowPrecisionLoss = config.isAllowPrecisionLoss();
const auto [rPrecision, rScale] = Operation::computeResultPrecisionScale(
aPrecision, aScale, bPrecision, bScale);
aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss);
const uint8_t aRescale =
Operation::computeRescaleFactor(aScale, bScale, rScale);
const uint8_t bRescale =
Expand Down
10 changes: 10 additions & 0 deletions velox/functions/sparksql/DecimalUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ class DecimalUtil {
}
}

/// This method is used when
/// `spark.sql.decimalOperations.allowPrecisionLoss` is set to false.
inline static std::pair<uint8_t, uint8_t> bounded(
uint8_t rPrecision,
uint8_t rScale) {
return {
std::min(static_cast<int32_t>(rPrecision), 38),
std::min(static_cast<int32_t>(rScale), 38)};
}

/// @brief Convert int256 value to int64 or int128, set overflow to true if
/// value cannot convert to specific type.
/// @return The converted value.
Expand Down

0 comments on commit 201adf1

Please sign in to comment.