From 10addd03c1b7f36a9809d41f7dbf4e3a758f653f Mon Sep 17 00:00:00 2001 From: Rui Mo Date: Fri, 21 Apr 2023 11:15:31 +0800 Subject: [PATCH] Add more validations for functions (#198) --- velox/common/file/FileSystems.h | 5 +- .../sparksql/aggregates/DecimalAvgAggregate.h | 15 ++---- velox/row/UnsafeRowDeserializer.h | 3 +- .../tests/UnsafeRowBatchDeserializerTest.cpp | 6 ++- .../SubstraitToVeloxPlanValidator.cpp | 48 ++++++++++++++++++- 5 files changed, 61 insertions(+), 16 deletions(-) diff --git a/velox/common/file/FileSystems.h b/velox/common/file/FileSystems.h index cb867dd3f73c..4e997dba8968 100644 --- a/velox/common/file/FileSystems.h +++ b/velox/common/file/FileSystems.h @@ -89,8 +89,9 @@ std::shared_ptr getFileSystem( // and a lambda that generates the actual file system. void registerFileSystem( std::function schemeMatcher, - std::function(std::shared_ptr, std::string_view)> - fileSystemGenerator); + std::function( + std::shared_ptr, + std::string_view)> fileSystemGenerator); // Register the local filesystem. void registerLocalFileSystem(); diff --git a/velox/functions/sparksql/aggregates/DecimalAvgAggregate.h b/velox/functions/sparksql/aggregates/DecimalAvgAggregate.h index 99ff5c0a8d14..bc8a082fe801 100644 --- a/velox/functions/sparksql/aggregates/DecimalAvgAggregate.h +++ b/velox/functions/sparksql/aggregates/DecimalAvgAggregate.h @@ -297,12 +297,12 @@ class DecimalAverageAggregate : public exec::Aggregate { if ((accumulator->overflow == 1 && accumulator->sum < 0) || (accumulator->overflow == -1 && accumulator->sum > 0)) { sum = static_cast( - DecimalUtil::kOverflowMultiplier * accumulator->overflow + - accumulator->sum); + DecimalUtil::kOverflowMultiplier * accumulator->overflow + + accumulator->sum); } else { VELOX_CHECK( - accumulator->overflow == 0, - "overflow: decimal avg struct overflow not eq 0"); + accumulator->overflow == 0, + "overflow: decimal avg struct overflow not eq 0"); } auto [resultPrecision, resultScale] = @@ -336,12 +336,7 @@ class DecimalAverageAggregate : public exec::Aggregate { UnscaledLongDecimal, UnscaledLongDecimal, UnscaledLongDecimal>( - avg, - UnscaledLongDecimal(sum), - countDecimal, - false, - sumRescale, - 0); + avg, UnscaledLongDecimal(sum), countDecimal, false, sumRescale, 0); } auto castedAvg = DecimalUtil::rescaleWithRoundUp( diff --git a/velox/row/UnsafeRowDeserializer.h b/velox/row/UnsafeRowDeserializer.h index e7e7fac4cfbe..8b180efcec00 100644 --- a/velox/row/UnsafeRowDeserializer.h +++ b/velox/row/UnsafeRowDeserializer.h @@ -1037,7 +1037,8 @@ struct UnsafeRowDynamicVectorDeserializer { static VectorPtr convertPrimitiveIteratorsToVectors( std::vector::iterator dataIterators, memory::MemoryPool* pool, - size_t numIteratorsToProcess,int32_t numFields = 1, + size_t numIteratorsToProcess, + int32_t numFields = 1, int fieldsIdx = 0) { TypePtr type = (*dataIterators)->type(); assert(type->isPrimitiveType()); diff --git a/velox/row/tests/UnsafeRowBatchDeserializerTest.cpp b/velox/row/tests/UnsafeRowBatchDeserializerTest.cpp index 3d524c1005d4..47a8f1f2909c 100644 --- a/velox/row/tests/UnsafeRowBatchDeserializerTest.cpp +++ b/velox/row/tests/UnsafeRowBatchDeserializerTest.cpp @@ -665,7 +665,11 @@ class UnsafeRowComplexBatchDeserializerTests return StringView::makeInline("str" + std::to_string(row + index)); }); return makeRowVector( - {intVector, stringVector, decimalVector, intArrayVector, stringArrayVector}); + {intVector, + stringVector, + decimalVector, + intArrayVector, + stringArrayVector}); } std::shared_ptr pool_ = memory::getDefaultMemoryPool(); diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.cpp b/velox/substrait/SubstraitToVeloxPlanValidator.cpp index 8597ec33579a..411fda08acc4 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.cpp +++ b/velox/substrait/SubstraitToVeloxPlanValidator.cpp @@ -95,11 +95,44 @@ bool SubstraitToVeloxPlanValidator::validateRound( bool SubstraitToVeloxPlanValidator::validateScalarFunction( const ::substrait::Expression::ScalarFunction& scalarFunction, const RowTypePtr& inputType) { - const auto& veloxFunction = subParser_->findVeloxFunction( + const auto& function = subParser_->findSubstraitFuncSpec( planConverter_->getFunctionMap(), scalarFunction.function_reference()); - if (veloxFunction == "round") { + const auto& name = subParser_->getSubFunctionName(function); + std::vector types; + subParser_->getSubFunctionTypes(function, types); + if (name == "round") { return validateRound(scalarFunction, inputType); } + if (name == "char_length") { + VELOX_CHECK(types.size() == 1); + if (types[0] == "vbin") { + VLOG(1) << "Binary type is not supported in " << name << "."; + return false; + } + } + std::unordered_set functions = { + "regexp_replace", + "split", + "split_part", + "factorial", + "concat_ws", + "rand", + "json_array_length", + "from_unixtime", + "to_unix_timestamp", + "unix_timestamp", + "repeat", + "translate", + "add_months", + "date_format", + "trunc", + "sequence", + "posexplode"}; + if (functions.find(name) != functions.end()) { + VLOG(1) << "Function is not supported: " << name << "."; + return false; + } + return true; } @@ -330,6 +363,17 @@ bool SubstraitToVeloxPlanValidator::validate( } } + // Validate supported aggregate functions. + std::unordered_set unsupportedFuncs = {"collect_list"}; + for (const auto& funcSpec : funcSpecs) { + auto funcName = subParser_->getSubFunctionName(funcSpec); + if (unsupportedFuncs.find(funcName) != unsupportedFuncs.end()) { + std::cout << "Validation failed due to " << funcName + << " was not supported in WindowRel." << std::endl; + return false; + } + } + // Validate groupby expression const auto& groupByExprs = windowRel.partition_expressions(); std::vector> expressions;