diff --git a/velox/functions/prestosql/aggregates/CovarianceAggregates.cpp b/velox/functions/prestosql/aggregates/CovarianceAggregates.cpp index f5c349ba4e63..9111a0fd18f6 100644 --- a/velox/functions/prestosql/aggregates/CovarianceAggregates.cpp +++ b/velox/functions/prestosql/aggregates/CovarianceAggregates.cpp @@ -229,9 +229,8 @@ struct CorrResultAccessor { } static double result(const CorrAccumulator& accumulator) { - double stddevX = std::sqrt(accumulator.m2X()); - double stddevY = std::sqrt(accumulator.m2Y()); - return accumulator.c2() / stddevX / stddevY; + // Need to modify the calculation order to maintain the same accuracy as spark + return accumulator.c2() / std::sqrt(accumulator.m2X() * accumulator.m2Y()); } }; diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.cpp b/velox/substrait/SubstraitToVeloxPlanValidator.cpp index fc57d2e400a3..94f0c2a9e436 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.cpp +++ b/velox/substrait/SubstraitToVeloxPlanValidator.cpp @@ -789,7 +789,10 @@ bool SubstraitToVeloxPlanValidator::validate( "var_samp", "var_pop", "bitwise_and_agg", - "bitwise_or_agg"}; + "bitwise_or_agg", + "corr", + "covar_pop", + "covar_samp"}; for (const auto& funcSpec : funcSpecs) { auto funcName = subParser_->getSubFunctionName(funcSpec); if (supportedFuncs.find(funcName) == supportedFuncs.end()) {