Skip to content

Commit

Permalink
Add more validations for functions (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo authored Apr 21, 2023
1 parent f265f03 commit 10addd0
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 16 deletions.
5 changes: 3 additions & 2 deletions velox/common/file/FileSystems.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ std::shared_ptr<FileSystem> getFileSystem(
// and a lambda that generates the actual file system.
void registerFileSystem(
std::function<bool(std::string_view)> schemeMatcher,
std::function<std::shared_ptr<FileSystem>(std::shared_ptr<const Config>, std::string_view)>
fileSystemGenerator);
std::function<std::shared_ptr<FileSystem>(
std::shared_ptr<const Config>,
std::string_view)> fileSystemGenerator);

// Register the local filesystem.
void registerLocalFileSystem();
Expand Down
15 changes: 5 additions & 10 deletions velox/functions/sparksql/aggregates/DecimalAvgAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,12 @@ class DecimalAverageAggregate : public exec::Aggregate {
if ((accumulator->overflow == 1 && accumulator->sum < 0) ||
(accumulator->overflow == -1 && accumulator->sum > 0)) {
sum = static_cast<int128_t>(
DecimalUtil::kOverflowMultiplier * accumulator->overflow +
accumulator->sum);
DecimalUtil::kOverflowMultiplier * accumulator->overflow +
accumulator->sum);
} else {
VELOX_CHECK(
accumulator->overflow == 0,
"overflow: decimal avg struct overflow not eq 0");
accumulator->overflow == 0,
"overflow: decimal avg struct overflow not eq 0");
}

auto [resultPrecision, resultScale] =
Expand Down Expand Up @@ -336,12 +336,7 @@ class DecimalAverageAggregate : public exec::Aggregate {
UnscaledLongDecimal,
UnscaledLongDecimal,
UnscaledLongDecimal>(
avg,
UnscaledLongDecimal(sum),
countDecimal,
false,
sumRescale,
0);
avg, UnscaledLongDecimal(sum), countDecimal, false, sumRescale, 0);
}
auto castedAvg =
DecimalUtil::rescaleWithRoundUp<UnscaledLongDecimal, TResultType>(
Expand Down
3 changes: 2 additions & 1 deletion velox/row/UnsafeRowDeserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,8 @@ struct UnsafeRowDynamicVectorDeserializer {
static VectorPtr convertPrimitiveIteratorsToVectors(
std::vector<DataIteratorPtr>::iterator dataIterators,
memory::MemoryPool* pool,
size_t numIteratorsToProcess,int32_t numFields = 1,
size_t numIteratorsToProcess,
int32_t numFields = 1,
int fieldsIdx = 0) {
TypePtr type = (*dataIterators)->type();
assert(type->isPrimitiveType());
Expand Down
6 changes: 5 additions & 1 deletion velox/row/tests/UnsafeRowBatchDeserializerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,11 @@ class UnsafeRowComplexBatchDeserializerTests
return StringView::makeInline("str" + std::to_string(row + index));
});
return makeRowVector(
{intVector, stringVector, decimalVector, intArrayVector, stringArrayVector});
{intVector,
stringVector,
decimalVector,
intArrayVector,
stringArrayVector});
}

std::shared_ptr<memory::MemoryPool> pool_ = memory::getDefaultMemoryPool();
Expand Down
48 changes: 46 additions & 2 deletions velox/substrait/SubstraitToVeloxPlanValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,44 @@ bool SubstraitToVeloxPlanValidator::validateRound(
bool SubstraitToVeloxPlanValidator::validateScalarFunction(
const ::substrait::Expression::ScalarFunction& scalarFunction,
const RowTypePtr& inputType) {
const auto& veloxFunction = subParser_->findVeloxFunction(
const auto& function = subParser_->findSubstraitFuncSpec(
planConverter_->getFunctionMap(), scalarFunction.function_reference());
if (veloxFunction == "round") {
const auto& name = subParser_->getSubFunctionName(function);
std::vector<std::string> types;
subParser_->getSubFunctionTypes(function, types);
if (name == "round") {
return validateRound(scalarFunction, inputType);
}
if (name == "char_length") {
VELOX_CHECK(types.size() == 1);
if (types[0] == "vbin") {
VLOG(1) << "Binary type is not supported in " << name << ".";
return false;
}
}
std::unordered_set<std::string> functions = {
"regexp_replace",
"split",
"split_part",
"factorial",
"concat_ws",
"rand",
"json_array_length",
"from_unixtime",
"to_unix_timestamp",
"unix_timestamp",
"repeat",
"translate",
"add_months",
"date_format",
"trunc",
"sequence",
"posexplode"};
if (functions.find(name) != functions.end()) {
VLOG(1) << "Function is not supported: " << name << ".";
return false;
}

return true;
}

Expand Down Expand Up @@ -330,6 +363,17 @@ bool SubstraitToVeloxPlanValidator::validate(
}
}

// Validate supported aggregate functions.
std::unordered_set<std::string> unsupportedFuncs = {"collect_list"};
for (const auto& funcSpec : funcSpecs) {
auto funcName = subParser_->getSubFunctionName(funcSpec);
if (unsupportedFuncs.find(funcName) != unsupportedFuncs.end()) {
std::cout << "Validation failed due to " << funcName
<< " was not supported in WindowRel." << std::endl;
return false;
}
}

// Validate groupby expression
const auto& groupByExprs = windowRel.partition_expressions();
std::vector<std::shared_ptr<const core::ITypedExpr>> expressions;
Expand Down

0 comments on commit 10addd0

Please sign in to comment.