Skip to content

Commit

Permalink
adding param allowPrecisionLoss
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Zhou <[email protected]>
  • Loading branch information
zhouyuan committed Feb 28, 2024
1 parent bdd09e4 commit 9cbea26
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 18 deletions.
7 changes: 7 additions & 0 deletions velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class QueryConfig {
static constexpr const char* kCastMatchStructByName =
"cast_match_struct_by_name";

// This flags forces to bound the decimal precision.
static constexpr const char* kAllowPrecisionLoss = "allow_precision_loss";

/// Used for backpressure to block local exchange producers when the local
/// exchange buffer reaches or exceeds this size.
static constexpr const char* kMaxLocalExchangeBufferSize =
Expand Down Expand Up @@ -496,6 +499,10 @@ class QueryConfig {
return get<bool>(kCastMatchStructByName, false);
}

bool isAllowPrecisionLoss() const {
return get<bool>(kAllowPrecisionLoss, true);
}

bool codegenEnabled() const {
return get<bool>(kCodegenEnabled, false);
}
Expand Down
56 changes: 38 additions & 18 deletions velox/functions/sparksql/DecimalArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,14 +413,17 @@ class Addition {
}

inline static std::pair<uint8_t, uint8_t> computeResultPrecisionScale(
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
const uint8_t aPrecision,
const uint8_t aScale,
const uint8_t bPrecision,
const uint8_t bScale,
const bool allowPrecisionLoss) {
auto precision = std::max(aPrecision - aScale, bPrecision - bScale) +
std::max(aScale, bScale) + 1;
auto scale = std::max(aScale, bScale);
return DecimalUtil::adjustPrecisionScale(precision, scale);
return allowPrecisionLoss
? DecimalUtil::adjustPrecisionScale(precision, scale)
: DecimalUtil::bounded(precision, scale);
}
};

Expand Down Expand Up @@ -461,12 +464,13 @@ class Subtraction {
}

inline static std::pair<uint8_t, uint8_t> computeResultPrecisionScale(
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
const uint8_t aPrecision,
const uint8_t aScale,
const uint8_t bPrecision,
const uint8_t bScale,
const bool allowPrecisionLoss) {
return Addition::computeResultPrecisionScale(
aPrecision, aScale, bPrecision, bScale);
aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss);
}
};

Expand Down Expand Up @@ -566,9 +570,12 @@ class Multiply {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
return DecimalUtil::adjustPrecisionScale(
aPrecision + bPrecision + 1, aScale + bScale);
uint8_t bScale,
const bool allowPrecisionLoss) {
return allowPrecisionLoss
? DecimalUtil::adjustPrecisionScale(
aPrecision + bPrecision + 1, aScale + bScale)
: DecimalUtil::bounded(aPrecision + bPrecision + 1, aScale + bScale);
}

private:
Expand Down Expand Up @@ -616,10 +623,22 @@ class Divide {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
auto scale = std::max(6, aScale + bPrecision + 1);
auto precision = aPrecision - aScale + bScale + scale;
return DecimalUtil::adjustPrecisionScale(precision, scale);
uint8_t bScale,
bool allowPrecisionLoss) {
if (allowPrecisionLoss) {
auto scale = std::max(6, aScale + bPrecision + 1);
auto precision = aPrecision - aScale + bScale + scale;
return DecimalUtil::adjustPrecisionScale(precision, scale);
} else {
auto intDig = std::min(38, aPrecision - aScale + bScale);
auto decDig =
std::min(38, std::max(6, aScale + bPrecision + 1)) auto diff =
(intDig + decDig) - 38;
if (diff > 0) {
decDig -= diff / 2 + 1 intDig = 38 - decDig
}
return DecimalUtil::bounded(intDig + decDig, decDig);
}
}
};

Expand Down Expand Up @@ -694,8 +713,9 @@ std::shared_ptr<exec::VectorFunction> createDecimalFunction(
const auto& bType = inputArgs[1].type;
const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
const bool allowPrecisionLoss = config.isAllowPrecisionLoss();
const auto [rPrecision, rScale] = Operation::computeResultPrecisionScale(
aPrecision, aScale, bPrecision, bScale);
aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss);
const uint8_t aRescale =
Operation::computeRescaleFactor(aScale, bScale, rScale);
const uint8_t bRescale =
Expand Down
6 changes: 6 additions & 0 deletions velox/functions/sparksql/DecimalUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ class DecimalUtil {
}
}

inline static std::pair<uint8_t, uint8_t> bounded(
const uint8_t rPrecision,
const uint8_t rScale) {
return {std::min(rPrecision, 38), std::min(rScale, 38)};
}

/// @brief Convert int256 value to int64 or int128, set overflow to true if
/// value cannot convert to specific type.
/// @return The converted value.
Expand Down

0 comments on commit 9cbea26

Please sign in to comment.