Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Apr 10, 2024
1 parent e9cd7a6 commit 8203b08
Show file tree
Hide file tree
Showing 7 changed files with 515 additions and 1 deletion.
15 changes: 15 additions & 0 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ Unless specified otherwise, all functions return NULL if at least one of the arg
If ``n < 0``, the result is an empty string.
If ``n >= 256``, the result is equivalent to chr(``n % 256``).

.. spark:function:: concat_ws(separator, [string]/[array<string>], ...) -> varchar
Returns the concatenation of all ``string`` arguments and all elements of the ``array<string>``
arguments, separated by ``separator``. ``separator`` must be a constant. The function accepts a
variable number of remaining arguments, and ``string`` and ``array<string>`` arguments can be
mixed. If ``separator`` is NULL, returns NULL regardless of the other inputs. If only a non-NULL
``separator`` is provided, or all remaining inputs are NULL, returns an empty string. ::

SELECT concat_ws('~', 'a', 'b', 'c'); -- 'a~b~c'
SELECT concat_ws('~', ['a', 'b', 'c'], ['d']); -- 'a~b~c~d'
SELECT concat_ws('~', 'a', ['b', 'c']); -- 'a~b~c'
SELECT concat_ws(NULL, 'a'); -- NULL
SELECT concat_ws('~'); -- ''
SELECT concat_ws('~', NULL, NULL); -- ''

.. spark:function:: contains(left, right) -> boolean
Returns true if 'right' is found in 'left'. Otherwise, returns false. ::
Expand Down
6 changes: 5 additions & 1 deletion velox/expression/tests/SparkExpressionFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ int main(int argc, char** argv) {
"chr",
"replace",
"might_contain",
"unix_timestamp"};
"unix_timestamp",
// Skip concat_ws as it triggers a test failure due to an incorrect
// expression generation from fuzzer:
// https://github.com/facebookincubator/velox/issues/6590
"concat_ws"};

// Required by spark_partition_id function.
std::unordered_map<std::string, std::string> queryConfigs = {
Expand Down
5 changes: 5 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ void registerFunctions(const std::string& prefix) {
prefix + "length", lengthSignatures(), makeLength);
registerFunction<SubstringIndexFunction, Varchar, Varchar, Varchar, int32_t>(
{prefix + "substring_index"});
exec::registerStatefulVectorFunction(
prefix + "concat_ws",
concatWsSignatures(),
makeConcatWs,
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build());

registerFunction<Md5Function, Varchar, Varbinary>({prefix + "md5"});
registerFunction<Sha1HexStringFunction, Varchar, Varbinary>(
Expand Down
236 changes: 236 additions & 0 deletions velox/functions/sparksql/String.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,202 @@ class Length : public exec::VectorFunction {
}
};

void doApply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
exec::EvalCtx& context,
const std::string& separator,
FlatVector<StringView>& flatResult) {
std::vector<column_index_t> argMapping;
std::vector<std::string> constantStrings;
auto numArgs = args.size();
// Save constant values to constantStrings_.
// Identify and combine consecutive constant inputs.
argMapping.reserve(numArgs - 1);
constantStrings.reserve(numArgs - 1);
// For each array arg, save rawSizes, rawOffsets, indices, and elements
// BaseBector.
std::vector<const vector_size_t*> rawSizesVector;
std::vector<const vector_size_t*> rawOffsetsVector;
std::vector<const vector_size_t*> indicesVector;
std::vector<DecodedVector> decodedVectors;

for (auto i = 1; i < numArgs; ++i) {
if (args[i] && args[i]->typeKind() == TypeKind::ARRAY) {
exec::LocalDecodedVector arrayHolder(context, *args[i], rows);
auto& arrayDecoded = *arrayHolder.get();
auto baseArray = arrayDecoded.base()->as<ArrayVector>();
rawSizesVector.push_back(baseArray->rawSizes());
rawOffsetsVector.push_back(baseArray->rawOffsets());
indicesVector.push_back(arrayDecoded.indices());
auto elements = baseArray->elements();
exec::LocalSelectivityVector nestedRows(context, elements->size());
nestedRows.get()->setAll();
exec::LocalDecodedVector elementsHolder(
context, *elements, *nestedRows.get());
auto& elementsDecoded = *elementsHolder.get();
decodedVectors.push_back(std::move(elementsDecoded));
continue;
}
// Handles string arg.
argMapping.push_back(i);
if (args[i] && args[i]->as<ConstantVector<StringView>>() &&
!args[i]->as<ConstantVector<StringView>>()->isNullAt(0)) {
std::ostringstream out;
out << args[i]->as<ConstantVector<StringView>>()->valueAt(0);
column_index_t j = i + 1;
// Concat constant string args.
for (; j < numArgs; ++j) {
if (!args[j] || args[j]->typeKind() == TypeKind::ARRAY ||
!args[j]->as<ConstantVector<StringView>>() ||
args[j]->as<ConstantVector<StringView>>()->isNullAt(0)) {
break;
}
out << separator
<< args[j]->as<ConstantVector<StringView>>()->valueAt(0);
}
constantStrings.emplace_back(out.str());
i = j - 1;
} else {
constantStrings.push_back("");
}
}

// Number of string columns after combined constant ones.
auto numStringCols = constantStrings.size();
// For column string arg decoding.
std::vector<exec::LocalDecodedVector> decodedStringArgs;
decodedStringArgs.reserve(numStringCols);

for (auto i = 0; i < numStringCols; ++i) {
if (constantStrings[i].empty()) {
auto index = argMapping[i];
decodedStringArgs.emplace_back(context, *args[index], rows);
}
}

// Calculate the total number of bytes in the result.
size_t totalResultBytes = 0;
rows.applyToSelected([&](auto row) {
int32_t allElements = 0;
for (int i = 0; i < rawSizesVector.size(); i++) {
auto size = rawSizesVector[i][indicesVector[i][row]];
auto offset = rawOffsetsVector[i][indicesVector[i][row]];
for (int j = 0; j < size; ++j) {
if (!decodedVectors[i].isNullAt(offset + j)) {
auto element = decodedVectors[i].valueAt<StringView>(offset + j);
if (!element.empty()) {
allElements++;
totalResultBytes += element.size();
}
}
}
}
auto it = decodedStringArgs.begin();
for (int i = 0; i < numStringCols; i++) {
auto value = constantStrings[i].empty()
? (*it++)->valueAt<StringView>(row)
: StringView(constantStrings[i]);
if (!value.empty()) {
allElements++;
totalResultBytes += value.size();
}
}
if (allElements > 1) {
totalResultBytes += (allElements - 1) * separator.size();
}
});

// Allocate a string buffer.
auto rawBuffer = flatResult.getRawStringBufferWithSpace(totalResultBytes);
size_t bufferOffset = 0;
rows.applyToSelected([&](int row) {
const char* start = rawBuffer + bufferOffset;
size_t combinedSize = 0;
auto isFirst = true;
// For array arg.
int32_t i = 0;
// For string arg.
int32_t j = 0;
auto it = decodedStringArgs.begin();

auto copyToBuffer = [&](StringView value) {
if (value.empty()) {
return;
}
if (isFirst) {
isFirst = false;
} else {
// Add separator before the current value.
memcpy(rawBuffer + bufferOffset, separator.data(), separator.size());
bufferOffset += separator.size();
combinedSize += separator.size();
}
memcpy(rawBuffer + bufferOffset, value.data(), value.size());
combinedSize += value.size();
bufferOffset += value.size();
};

for (auto itArgs = args.begin() + 1; itArgs != args.end(); ++itArgs) {
if ((*itArgs)->typeKind() == TypeKind::ARRAY) {
auto size = rawSizesVector[i][indicesVector[i][row]];
auto offset = rawOffsetsVector[i][indicesVector[i][row]];
for (int k = 0; k < size; ++k) {
if (!decodedVectors[i].isNullAt(offset + k)) {
auto element = decodedVectors[i].valueAt<StringView>(offset + k);
copyToBuffer(element);
}
}
i++;
continue;
}
if (j >= numStringCols) {
continue;
}
StringView value;
if (constantStrings[j].empty()) {
value = (*it++)->valueAt<StringView>(row);
} else {
value = StringView(constantStrings[j]);
}
copyToBuffer(value);
j++;
}
flatResult.setNoCopy(row, StringView(start, combinedSize));
});
}

/// Vector function implementing concat_ws for a separator whose value is
/// fixed at construction time.
class ConcatWs : public exec::VectorFunction {
 public:
  explicit ConcatWs(const std::string& separator) : separator_(separator) {}

  /// Produces a VARCHAR result for the selected rows:
  ///  - NULL for every row when the separator argument is NULL;
  ///  - an empty string for every row when the separator is the only
  ///    argument;
  ///  - otherwise whatever doApply computes from the remaining arguments.
  void apply(
      const SelectivityVector& selected,
      std::vector<VectorPtr>& args,
      const TypePtr& /* outputType */,
      exec::EvalCtx& context,
      VectorPtr& result) const override {
    context.ensureWritable(selected, VARCHAR(), result);
    auto* output = result->asFlatVector<StringView>();

    if (args[0]->isNullAt(0)) {
      // NULL separator: every selected row is NULL.
      selected.applyToSelected([&](int row) { result->setNull(row, true); });
    } else if (args.size() == 1) {
      // Nothing to concatenate: every selected row is an empty string.
      selected.applyToSelected(
          [&](int row) { output->setNoCopy(row, StringView("")); });
    } else {
      doApply(selected, args, context, separator_, *output);
    }
  }

 private:
  // Separator value placed between concatenated values.
  const std::string separator_;
};

} // namespace

std::vector<std::shared_ptr<exec::FunctionSignature>> instrSignatures() {
Expand Down Expand Up @@ -144,6 +340,46 @@ std::shared_ptr<exec::VectorFunction> makeLength(
return kLengthFunction;
}

/// Builds the list of signatures under which concat_ws is registered.
std::vector<std::shared_ptr<exec::FunctionSignature>> concatWsSignatures() {
  // Shape: varchar, [varchar | array(varchar)], ... -> varchar.
  // The first argument is a constant varchar. The second and following
  // arguments are declared as "any" here; makeConcatWs checks that each of
  // them is varchar or array(varchar).
  std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
  signatures.push_back(exec::FunctionSignatureBuilder()
                           .returnType("varchar")
                           .constantArgumentType("varchar")
                           .argumentType("any")
                           .variableArity()
                           .build());
  return signatures;
}

/// Creates a ConcatWs instance bound to the constant separator found in
/// 'inputArgs[0]'.
///
/// Raises a user error when no argument is given, when any argument
/// (including the separator) is neither varchar nor array(varchar), or when
/// the first argument has no constant value. 'name' is not used.
std::shared_ptr<exec::VectorFunction> makeConcatWs(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
auto numArgs = inputArgs.size();
// The separator itself counts as an argument, so one argument is the minimum.
VELOX_USER_CHECK(
numArgs >= 1,
"concat_ws requires one arguments at least, but got {}.",
numArgs);
// The registered signature declares the trailing arguments as "any", so the
// actual type restriction is enforced here.
for (auto& arg : inputArgs) {
VELOX_USER_CHECK(
arg.type->isVarchar() ||
(arg.type->isArray() &&
arg.type->asArray().elementType()->isVarchar()),
"concat_ws requires varchar or array(varchar) arguments, but got {}.",
arg.type->toString());
}

// A null constantValue means the separator is not a constant expression.
BaseVector* constantPattern = inputArgs[0].constantValue.get();
VELOX_USER_CHECK(
nullptr != constantPattern,
"concat_ws requires constant separator arguments.");

// Copy the separator value so that the function owns it.
// NOTE(review): a NULL constant separator is not rejected here and its value
// is read without a null check; ConcatWs::apply checks args[0] for NULL
// before the separator is used. Confirm valueAt(0) is safe for a NULL
// constant.
auto separator =
constantPattern->as<ConstantVector<StringView>>()->valueAt(0).str();
return std::make_shared<ConcatWs>(separator);
}

void encodeDigestToBase16(uint8_t* output, int digestSize) {
static unsigned char const kHexCodes[] = "0123456789abcdef";
for (int i = digestSize - 1; i >= 0; --i) {
Expand Down
7 changes: 7 additions & 0 deletions velox/functions/sparksql/String.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ std::shared_ptr<exec::VectorFunction> makeLength(
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& config);

/// Returns the signatures of concat_ws: a constant varchar separator followed
/// by a variable number of arguments declared as "any", returning varchar.
/// The concrete argument types are validated by makeConcatWs.
std::vector<std::shared_ptr<exec::FunctionSignature>> concatWsSignatures();

/// Creates the concat_ws vector function for the given arguments.
/// Raises a user error unless at least one argument is given, every argument
/// is varchar or array(varchar), and the first argument (the separator) is a
/// constant.
std::shared_ptr<exec::VectorFunction> makeConcatWs(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& config);

/// Expands each char of the digest data to two chars,
/// representing the hex value of each digest char, in order.
/// Note: digestSize must be one-half of outputSize.
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_executable(
ArraySortTest.cpp
BitwiseTest.cpp
ComparisonsTest.cpp
ConcatWsTest.cpp
DateTimeFunctionsTest.cpp
DecimalArithmeticTest.cpp
DecimalCompareTest.cpp
Expand Down
Loading

0 comments on commit 8203b08

Please sign in to comment.