Skip to content

Commit

Permalink
Clean up the conversion from Substrait type to Velox type
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Oct 23, 2023
1 parent 5cec4f9 commit ef8fa45
Show file tree
Hide file tree
Showing 14 changed files with 274 additions and 478 deletions.
1 change: 0 additions & 1 deletion cpp/velox/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ set(VELOX_SRCS
substrait/SubstraitToVeloxPlan.cc
substrait/SubstraitToVeloxPlanValidator.cc
substrait/VariantToVectorConverter.cc
substrait/TypeUtils.cc
substrait/SubstraitExtensionCollector.cc
substrait/VeloxSubstraitSignature.cc
substrait/VeloxToSubstraitExpr.cc
Expand Down
8 changes: 2 additions & 6 deletions cpp/velox/compute/VeloxPlanConverter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ void VeloxPlanConverter::setInputPlanNode(const ::substrait::ReadRel& sread) {

// Get the input schema of this iterator.
uint64_t colNum = 0;
std::vector<std::shared_ptr<SubstraitParser::SubstraitType>> subTypeList;
std::vector<velox::TypePtr> veloxTypeList;
if (sread.has_base_schema()) {
const auto& baseSchema = sread.base_schema();
// Input names is not used. Instead, new input/output names will be created
// because the ValueStreamNode in Velox does not support name change.
colNum = baseSchema.names().size();
subTypeList = SubstraitParser::parseNamedStruct(baseSchema);
veloxTypeList = SubstraitParser::parseNamedStruct(baseSchema);
}

std::vector<std::string> outNames;
Expand All @@ -140,10 +140,6 @@ void VeloxPlanConverter::setInputPlanNode(const ::substrait::ReadRel& sread) {
outNames.emplace_back(colName);
}

std::vector<velox::TypePtr> veloxTypeList;
for (auto subType : subTypeList) {
veloxTypeList.push_back(toVeloxType(subType->type));
}
auto outputType = ROW(std::move(outNames), std::move(veloxTypeList));
auto vectorStream = std::make_shared<RowVectorStream>(pool_, std::move(inputIters_[iterIdx]), outputType);
auto valuesNode = std::make_shared<ValueStreamNode>(nextPlanNodeId(), outputType, std::move(vectorStream));
Expand Down
231 changes: 81 additions & 150 deletions cpp/velox/substrait/SubstraitParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,166 +19,85 @@
#include "TypeUtils.h"
#include "velox/common/base/Exceptions.h"

#include "VeloxSubstraitSignature.h"

namespace gluten {

SubstraitParser::SubstraitType SubstraitParser::parseType(const ::substrait::Type& substraitType) {
// The used type names should be aligned with those in Velox.
std::string typeName;
::substrait::Type_Nullability nullability;
TypePtr SubstraitParser::parseType(const ::substrait::Type& substraitType, bool asLowerCase) {
switch (substraitType.kind_case()) {
case ::substrait::Type::KindCase::kBool: {
typeName = "BOOLEAN";
nullability = substraitType.bool_().nullability();
break;
}
case ::substrait::Type::KindCase::kI8: {
typeName = "TINYINT";
nullability = substraitType.i8().nullability();
break;
}
case ::substrait::Type::KindCase::kI16: {
typeName = "SMALLINT";
nullability = substraitType.i16().nullability();
break;
}
case ::substrait::Type::KindCase::kI32: {
typeName = "INTEGER";
nullability = substraitType.i32().nullability();
break;
}
case ::substrait::Type::KindCase::kI64: {
typeName = "BIGINT";
nullability = substraitType.i64().nullability();
break;
}
case ::substrait::Type::KindCase::kFp32: {
typeName = "REAL";
nullability = substraitType.fp32().nullability();
break;
}
case ::substrait::Type::KindCase::kFp64: {
typeName = "DOUBLE";
nullability = substraitType.fp64().nullability();
break;
}
case ::substrait::Type::KindCase::kString: {
typeName = "VARCHAR";
nullability = substraitType.string().nullability();
break;
}
case ::substrait::Type::KindCase::kBinary: {
typeName = "VARBINARY";
nullability = substraitType.string().nullability();
break;
}
case ::substrait::Type::KindCase::kBool:
return BOOLEAN();
case ::substrait::Type::KindCase::kI8:
return TINYINT();
case ::substrait::Type::KindCase::kI16:
return SMALLINT();
case ::substrait::Type::KindCase::kI32:
return INTEGER();
case ::substrait::Type::KindCase::kI64:
return BIGINT();
case ::substrait::Type::KindCase::kFp32:
return REAL();
case ::substrait::Type::KindCase::kFp64:
return DOUBLE();
case ::substrait::Type::KindCase::kString:
return VARCHAR();
case ::substrait::Type::KindCase::kBinary:
return VARBINARY();
case ::substrait::Type::KindCase::kStruct: {
// The type name of struct is in the format of:
// ROW<type0:name0,type1:name1,ROW<type2:name2>,...typen:namen>.
typeName = "ROW<";
const auto& substraitStruct = substraitType.struct_();
const auto& structTypes = substraitStruct.types();
const auto& structNames = substraitStruct.names();
bool nameProvided = structTypes.size() == structNames.size();
std::vector<TypePtr> types;
std::vector<std::string> names;
for (int i = 0; i < structTypes.size(); i++) {
if (i > 0) {
typeName += ',';
}
typeName += parseType(structTypes[i]).type;
// Struct names could be empty.
if (nameProvided) {
typeName += (':' + structNames[i]);
types.emplace_back(parseType(structTypes[i]));
std::string fieldName = nameProvided ? structNames[i] : "col_" + std::to_string(i);
if (asLowerCase) {
folly::toLowerAscii(fieldName);
}
names.emplace_back(fieldName);
}
typeName += '>';
nullability = substraitType.struct_().nullability();
break;
return ROW(std::move(names), std::move(types));
}
case ::substrait::Type::KindCase::kList: {
// The type name of list is in the format of: ARRAY<T>.
const auto& sList = substraitType.list();
const auto& sType = sList.type();
typeName = "ARRAY<" + parseType(sType).type + ">";
nullability = substraitType.list().nullability();
break;
const auto& fieldType = substraitType.list().type();
return ARRAY(parseType(fieldType));
}
case ::substrait::Type::KindCase::kMap: {
// The type name of map is in the format of: MAP<K,V>.
const auto& sMap = substraitType.map();
const auto& keyType = sMap.key();
const auto& valueType = sMap.value();
typeName = "MAP<" + parseType(keyType).type + "," + parseType(valueType).type + ">";
nullability = substraitType.map().nullability();
break;
return MAP(parseType(keyType), parseType(valueType));
}
case ::substrait::Type::KindCase::kUserDefined: {
case ::substrait::Type::KindCase::kUserDefined:
// We only support UNKNOWN type to handle the null literal whose type is
// not known.
VELOX_CHECK_EQ(substraitType.user_defined().type_reference(), 0);
typeName = "UNKNOWN";
nullability = substraitType.string().nullability();
break;
}
case ::substrait::Type::KindCase::kDate: {
typeName = "DATE";
nullability = substraitType.date().nullability();
break;
}
case ::substrait::Type::KindCase::kTimestamp: {
typeName = "TIMESTAMP";
nullability = substraitType.timestamp().nullability();
break;
}
return UNKNOWN();
case ::substrait::Type::KindCase::kDate:
return DATE();
case ::substrait::Type::KindCase::kTimestamp:
return TIMESTAMP();
case ::substrait::Type::KindCase::kDecimal: {
auto precision = substraitType.decimal().precision();
auto scale = substraitType.decimal().scale();
if (precision <= 18) {
typeName = "SHORT_DECIMAL<" + std::to_string(precision) + "," + std::to_string(scale) + ">";
} else {
typeName = "HUGEINT<" + std::to_string(precision) + "," + std::to_string(scale) + ">";
}
nullability = substraitType.decimal().nullability();
break;
return DECIMAL(precision, scale);
}
default:
VELOX_NYI("Parsing for Substrait type not supported: {}", substraitType.DebugString());
}

bool nullable;
switch (nullability) {
case ::substrait::Type_Nullability::Type_Nullability_NULLABILITY_UNSPECIFIED:
nullable = true;
break;
case ::substrait::Type_Nullability::Type_Nullability_NULLABILITY_NULLABLE:
nullable = true;
break;
case ::substrait::Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED:
nullable = false;
break;
default:
VELOX_NYI("Substrait parsing for nullability {} not supported.", nullability);
}
return SubstraitType{typeName, nullable};
}

std::string SubstraitParser::parseType(const std::string& substraitType) {
auto it = typeMap_.find(substraitType);
if (it == typeMap_.end()) {
VELOX_NYI("Substrait parsing for type {} not supported.", substraitType);
}
return it->second;
};

std::vector<std::shared_ptr<SubstraitParser::SubstraitType>> SubstraitParser::parseNamedStruct(
const ::substrait::NamedStruct& namedStruct) {
std::vector<TypePtr> SubstraitParser::parseNamedStruct(const ::substrait::NamedStruct& namedStruct, bool asLowerCase) {
// Note that "names" are not used.

// Parse Struct.
const auto& substraitStruct = namedStruct.struct_();
const auto& substraitTypes = substraitStruct.types();
std::vector<std::shared_ptr<SubstraitParser::SubstraitType>> substraitTypeList;
std::vector<TypePtr> substraitTypeList;
substraitTypeList.reserve(substraitTypes.size());
for (const auto& type : substraitTypes) {
substraitTypeList.emplace_back(std::make_shared<SubstraitParser::SubstraitType>(parseType(type)));
substraitTypeList.emplace_back(parseType(type, asLowerCase));
}
return substraitTypeList;
}
Expand Down Expand Up @@ -262,55 +181,57 @@ std::string SubstraitParser::findFunctionSpec(
return x->second;
}

std::string SubstraitParser::getSubFunctionName(const std::string& subFuncSpec) {
// Get the position of ":" in the function name.
std::size_t pos = subFuncSpec.find(":");
// TODO Refactor using Bison.
std::string SubstraitParser::getNameBeforeDelimiter(const std::string& signature, const std::string& delimiter) {
std::size_t pos = signature.find(delimiter);
if (pos == std::string::npos) {
return subFuncSpec;
return signature;
}
return subFuncSpec.substr(0, pos);
return signature.substr(0, pos);
}

void SubstraitParser::getSubFunctionTypes(const std::string& subFuncSpec, std::vector<std::string>& types) {
std::vector<std::string> SubstraitParser::getSubFunctionTypes(const std::string& substraitFunction) {
// Get the position of ":" in the function name.
std::size_t pos = subFuncSpec.find(":");
size_t pos = substraitFunction.find(":");
// Get the parameter types.
std::string funcTypes;
if (pos == std::string::npos) {
funcTypes = subFuncSpec;
} else {
if (pos == subFuncSpec.size() - 1) {
return;
}
funcTypes = subFuncSpec.substr(pos + 1);
std::vector<std::string> types;
if (pos == std::string::npos || pos == substraitFunction.size() - 1) {
return types;
}
// Split the types with delimiter.
std::string delimiter = "_";
while ((pos = funcTypes.find(delimiter)) != std::string::npos) {
auto type = funcTypes.substr(0, pos);
if (type != "opt" && type != "req") {
types.emplace_back(type);
// Extract input types with delimiter.
for (;;) {
const size_t endPos = substraitFunction.find("_", pos + 1);
if (endPos == std::string::npos) {
std::string typeName = substraitFunction.substr(pos + 1);
if (typeName != "opt" && typeName != "req") {
types.emplace_back(typeName);
}
break;
}
funcTypes.erase(0, pos + delimiter.length());

const std::string typeName = substraitFunction.substr(pos + 1, endPos - pos - 1);
if (typeName != "opt" && typeName != "req") {
types.emplace_back(typeName);
}
pos = endPos;
}
types.emplace_back(funcTypes);
return types;
}

std::string SubstraitParser::findVeloxFunction(
const std::unordered_map<uint64_t, std::string>& functionMap,
uint64_t id) {
std::string funcSpec = findFunctionSpec(functionMap, id);
std::string_view funcName = getNameBeforeDelimiter(funcSpec, ":");
std::vector<std::string> types;
getSubFunctionTypes(funcSpec, types);
std::string funcName = getNameBeforeDelimiter(funcSpec);
std::vector<std::string> types = getSubFunctionTypes(funcSpec);
bool isDecimal = false;
for (auto& type : types) {
if (type.find("dec") != std::string::npos) {
isDecimal = true;
break;
}
}
return mapToVeloxFunction({funcName.begin(), funcName.end()}, isDecimal);
return mapToVeloxFunction(funcName, isDecimal);
}

std::string SubstraitParser::mapToVeloxFunction(const std::string& substraitFunction, bool isDecimal) {
Expand Down Expand Up @@ -347,6 +268,16 @@ bool SubstraitParser::configSetInOptimization(
return false;
}

std::vector<TypePtr> SubstraitParser::sigToTypes(const std::string& signature) {
std::vector<std::string> typeStrs = SubstraitParser::getSubFunctionTypes(signature);
std::vector<TypePtr> types;
types.reserve(typeStrs.size());
for (const auto& typeStr : typeStrs) {
types.emplace_back(VeloxSubstraitSignature::fromSubstraitSignature(typeStr));
}
return types;
}

std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunctionMap_ = {
{"is_not_null", "isnotnull"}, /*Spark functions.*/
{"is_null", "isnull"},
Expand Down
Loading

0 comments on commit ef8fa45

Please sign in to comment.