diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index 817eaad34ab29..be4f5d06fab0c 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -20,6 +20,21 @@ Unless specified otherwise, all functions return NULL if at least one of the arg If ``n < 0``, the result is an empty string. If ``n >= 256``, the result is equivalent to chr(``n % 256``). +.. spark:function:: concat_ws(separator, [string]/[array], ...) -> varchar + + Returns the concatenation for ``string`` & all elements in ``array``, separated by + ``separator``. Only accepts constant ``separator``. It takes variable number of remaining + arguments. And ``string`` & ``array`` can be used in combination. If ``separator`` + is NULL, returns NULL, regardless of the following inputs. If only ``separator`` (not a + NULL) is provided or all remaining inputs are NULL, returns an empty string. :: + + SELECT concat_ws('~', 'a', 'b', 'c'); -- 'a~b~c' + SELECT concat_ws('~', ['a', 'b', 'c'], ['d']); -- 'a~b~c~d' + SELECT concat_ws('~', 'a', ['b', 'c']); -- 'a~b~c' + SELECT concat_ws(NULL, 'a'); -- NULL + SELECT concat_ws('~'); -- '' + SELECT concat_ws('~', NULL, NULL); -- '' + .. spark:function:: contains(left, right) -> boolean Returns true if 'right' is found in 'left'. Otherwise, returns false. :: diff --git a/velox/expression/tests/SparkExpressionFuzzerTest.cpp b/velox/expression/tests/SparkExpressionFuzzerTest.cpp index c9531632f4137..9ac8381a66b8d 100644 --- a/velox/expression/tests/SparkExpressionFuzzerTest.cpp +++ b/velox/expression/tests/SparkExpressionFuzzerTest.cpp @@ -54,7 +54,11 @@ int main(int argc, char** argv) { "chr", "replace", "might_contain", - "unix_timestamp"}; + "unix_timestamp", + // Skip concat_ws as it triggers a test failure due to an incorrect + // expression generation from fuzzer: + // https://github.com/facebookincubator/velox/issues/6590 + "concat_ws"}; // Required by spark_partition_id function. std::unordered_map queryConfigs = { diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 72a8ca5275684..1a39be13c30c3 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -216,6 +216,11 @@ void registerFunctions(const std::string& prefix) { prefix + "length", lengthSignatures(), makeLength); registerFunction( {prefix + "substring_index"}); + exec::registerStatefulVectorFunction( + prefix + "concat_ws", + concatWsSignatures(), + makeConcatWs, + exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build()); registerFunction({prefix + "md5"}); registerFunction( diff --git a/velox/functions/sparksql/String.cpp b/velox/functions/sparksql/String.cpp index 8e72c8614ccac..cf73cd81b5b08 100644 --- a/velox/functions/sparksql/String.cpp +++ b/velox/functions/sparksql/String.cpp @@ -103,6 +103,203 @@ class Length : public exec::VectorFunction { } }; +void doApply( + const SelectivityVector& rows, + std::vector& args, + exec::EvalCtx& context, + const std::string& separator, + FlatVector& flatResult) { + std::vector argMapping; + std::vector constantStrings; + auto numArgs = args.size(); + // Save constant values to constantStrings_. + // Identify and combine consecutive constant inputs. + argMapping.reserve(numArgs - 1); + constantStrings.reserve(numArgs - 1); + // For each array arg, save rawSizes, rawOffsets, indices, and elements + // BaseBector. + std::vector rawSizesVector; + std::vector rawOffsetsVector; + std::vector indicesVector; + std::vector decodedVectors; + + for (auto i = 1; i < numArgs; ++i) { + // If array arg, continue. + if (args[i] && args[i]->typeKind() == TypeKind::ARRAY) { + exec::LocalDecodedVector arrayHolder(context, *args[i], rows); + auto& arrayDecoded = *arrayHolder.get(); + auto baseArray = arrayDecoded.base()->as(); + rawSizesVector.push_back(baseArray->rawSizes()); + rawOffsetsVector.push_back(baseArray->rawOffsets()); + indicesVector.push_back(arrayDecoded.indices()); + auto elements = baseArray->elements(); + exec::LocalSelectivityVector nestedRows(context, elements->size()); + nestedRows.get()->setAll(); + exec::LocalDecodedVector elementsHolder( + context, *elements, *nestedRows.get()); + auto& elementsDecoded = *elementsHolder.get(); + decodedVectors.push_back(std::move(elementsDecoded)); + continue; + } + // Handles string arg. + argMapping.push_back(i); + if (args[i] && args[i]->as>() && + !args[i]->as>()->isNullAt(0)) { + std::ostringstream out; + out << args[i]->as>()->valueAt(0); + column_index_t j = i + 1; + // Concat constant string args. + for (; j < numArgs; ++j) { + if (!args[j] || args[j]->typeKind() == TypeKind::ARRAY || + !args[j]->as>() || + args[j]->as>()->isNullAt(0)) { + break; + } + out << separator + << args[j]->as>()->valueAt(0); + } + constantStrings.emplace_back(out.str()); + i = j - 1; + } else { + constantStrings.push_back(""); + } + } + + // Number of string columns after combined constant ones. + auto numStringCols = constantStrings.size(); + // For column string arg decoding. + std::vector decodedStringArgs; + decodedStringArgs.reserve(numStringCols); + + for (auto i = 0; i < numStringCols; ++i) { + if (constantStrings[i].empty()) { + auto index = argMapping[i]; + decodedStringArgs.emplace_back(context, *args[index], rows); + } + } + + // Calculate the total number of bytes in the result. + size_t totalResultBytes = 0; + rows.applyToSelected([&](auto row) { + int32_t allElements = 0; + for (int i = 0; i < rawSizesVector.size(); i++) { + auto size = rawSizesVector[i][indicesVector[i][row]]; + auto offset = rawOffsetsVector[i][indicesVector[i][row]]; + for (int j = 0; j < size; ++j) { + if (!decodedVectors[i].isNullAt(offset + j)) { + auto element = decodedVectors[i].valueAt(offset + j); + if (!element.empty()) { + allElements++; + totalResultBytes += element.size(); + } + } + } + } + auto it = decodedStringArgs.begin(); + for (int i = 0; i < numStringCols; i++) { + auto value = constantStrings[i].empty() + ? (*it++)->valueAt(row) + : StringView(constantStrings[i]); + if (!value.empty()) { + allElements++; + totalResultBytes += value.size(); + } + } + if (allElements > 1) { + totalResultBytes += (allElements - 1) * separator.size(); + } + }); + + // Allocate a string buffer. + auto rawBuffer = flatResult.getRawStringBufferWithSpace(totalResultBytes); + size_t bufferOffset = 0; + rows.applyToSelected([&](int row) { + const char* start = rawBuffer + bufferOffset; + size_t combinedSize = 0; + auto isFirst = true; + // For array arg. + int32_t i = 0; + // For string arg. + int32_t j = 0; + auto it = decodedStringArgs.begin(); + + auto copyValue = [&](StringView value) { + if (value.empty()) { + return; + } + if (isFirst) { + isFirst = false; + } else { + // Add separator before the current value. + memcpy(rawBuffer + bufferOffset, separator.data(), separator.size()); + bufferOffset += separator.size(); + combinedSize += separator.size(); + } + memcpy(rawBuffer + bufferOffset, value.data(), value.size()); + combinedSize += value.size(); + bufferOffset += value.size(); + }; + + for (auto itArgs = args.begin() + 1; itArgs != args.end(); ++itArgs) { + if ((*itArgs)->typeKind() == TypeKind::ARRAY) { + auto size = rawSizesVector[i][indicesVector[i][row]]; + auto offset = rawOffsetsVector[i][indicesVector[i][row]]; + for (int k = 0; k < size; ++k) { + if (!decodedVectors[i].isNullAt(offset + k)) { + auto element = decodedVectors[i].valueAt(offset + k); + copyValue(element); + } + } + i++; + continue; + } + if (j >= numStringCols) { + continue; + } + StringView value; + if (constantStrings[j].empty()) { + value = (*it++)->valueAt(row); + } else { + value = StringView(constantStrings[j]); + } + copyValue(value); + j++; + } + flatResult.setNoCopy(row, StringView(start, combinedSize)); + }); +} + +class ConcatWs : public exec::VectorFunction { + public: + explicit ConcatWs(const std::string& separator) : separator_(separator) {} + + void apply( + const SelectivityVector& selected, + std::vector& args, + const TypePtr& /* outputType */, + exec::EvalCtx& context, + VectorPtr& result) const override { + context.ensureWritable(selected, VARCHAR(), result); + auto flatResult = result->asFlatVector(); + auto numArgs = args.size(); + // If separator is NULL, result is NULL. + if (args[0]->isNullAt(0)) { + selected.applyToSelected([&](int row) { result->setNull(row, true); }); + return; + } + // If only separator (not a NULL) is provided, result is an empty string. + if (numArgs == 1) { + selected.applyToSelected( + [&](int row) { flatResult->setNoCopy(row, StringView("")); }); + return; + } + doApply(selected, args, context, separator_, *flatResult); + } + + private: + const std::string separator_; +}; + } // namespace std::vector> instrSignatures() { @@ -144,6 +341,46 @@ std::shared_ptr makeLength( return kLengthFunction; } +std::vector> concatWsSignatures() { + return {// The second and folowing arguments are varchar or array(varchar). + // The argument type will be checked in makeConcatWs. + // varchar, [varchar], [array(varchar)], ... -> varchar. + exec::FunctionSignatureBuilder() + .returnType("varchar") + .constantArgumentType("varchar") + .argumentType("any") + .variableArity() + .build()}; +} + +std::shared_ptr makeConcatWs( + const std::string& name, + const std::vector& inputArgs, + const core::QueryConfig& /*config*/) { + auto numArgs = inputArgs.size(); + VELOX_USER_CHECK( + numArgs >= 1, + "concat_ws requires one arguments at least, but got {}.", + numArgs); + for (auto& arg : inputArgs) { + VELOX_USER_CHECK( + arg.type->isVarchar() || + (arg.type->isArray() && + arg.type->asArray().elementType()->isVarchar()), + "concat_ws requires varchar or array(varchar) arguments, but got {}.", + arg.type->toString()); + } + + BaseVector* constantPattern = inputArgs[0].constantValue.get(); + VELOX_USER_CHECK( + nullptr != constantPattern, + "concat_ws requires constant separator arguments."); + + auto separator = + constantPattern->as>()->valueAt(0).str(); + return std::make_shared(separator); +} + void encodeDigestToBase16(uint8_t* output, int digestSize) { static unsigned char const kHexCodes[] = "0123456789abcdef"; for (int i = digestSize - 1; i >= 0; --i) { diff --git a/velox/functions/sparksql/String.h b/velox/functions/sparksql/String.h index 75f9e90f3d3b9..265ecb8fda3ea 100644 --- a/velox/functions/sparksql/String.h +++ b/velox/functions/sparksql/String.h @@ -152,6 +152,13 @@ std::shared_ptr makeLength( const std::vector& inputArgs, const core::QueryConfig& config); +std::vector> concatWsSignatures(); + +std::shared_ptr makeConcatWs( + const std::string& name, + const std::vector& inputArgs, + const core::QueryConfig& config); + /// Expands each char of the digest data to two chars, /// representing the hex value of each digest char, in order. /// Note: digestSize must be one-half of outputSize. diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index e4881db973068..bb0ac8b5b399b 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -20,6 +20,7 @@ add_executable( ArraySortTest.cpp BitwiseTest.cpp ComparisonsTest.cpp + ConcatWsTest.cpp DateTimeFunctionsTest.cpp DecimalArithmeticTest.cpp DecimalCompareTest.cpp diff --git a/velox/functions/sparksql/tests/ConcatWsTest.cpp b/velox/functions/sparksql/tests/ConcatWsTest.cpp new file mode 100644 index 0000000000000..5084a903d56ea --- /dev/null +++ b/velox/functions/sparksql/tests/ConcatWsTest.cpp @@ -0,0 +1,246 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/type/Type.h" + +#include + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class ConcatWsTest : public SparkFunctionBaseTest { + protected: + std::string generateRandomString(size_t length) { + const std::string chars = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::string randomString; + for (std::size_t i = 0; i < length; ++i) { + randomString += chars[folly::Random::rand32() % chars.size()]; + } + return randomString; + } + + void testConcatWsFlatVector( + const std::vector>& inputTable, + const size_t argsCount, + const std::string& separator) { + std::vector inputVectors; + + for (int i = 0; i < argsCount; i++) { + inputVectors.emplace_back( + BaseVector::create(VARCHAR(), inputTable.size(), execCtx_.pool())); + } + + for (int row = 0; row < inputTable.size(); row++) { + for (int col = 0; col < argsCount; col++) { + std::static_pointer_cast>(inputVectors[col]) + ->set(row, StringView(inputTable[row][col])); + } + } + + auto buildConcatQuery = [&]() { + std::string output = "concat_ws('" + separator + "'"; + + for (int i = 0; i < argsCount; i++) { + output += ",c" + std::to_string(i); + } + output += ")"; + return output; + }; + auto result = evaluate>( + buildConcatQuery(), makeRowVector(inputVectors)); + + auto produceExpectedResult = [&](const std::vector& inputs) { + auto isFirst = true; + std::string output; + for (int i = 0; i < inputs.size(); i++) { + auto value = inputs[i]; + if (!value.empty()) { + if (isFirst) { + isFirst = false; + } else { + output += separator; + } + output += value; + } + } + return output; + }; + + for (int i = 0; i < inputTable.size(); ++i) { + EXPECT_EQ(result->valueAt(i), produceExpectedResult(inputTable[i])) + << "at " << i; + } + } +}; + +TEST_F(ConcatWsTest, columnStringArgs) { + // Test with constant args. + auto rows = makeRowVector(makeRowType({VARCHAR(), VARCHAR()}), 10); + auto c0 = generateRandomString(20); + auto c1 = generateRandomString(20); + auto result = evaluate>( + fmt::format("concat_ws('-', '{}', '{}')", c0, c1), rows); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(result->valueAt(i), c0 + "-" + c1); + } + + // Test with variable arguments. + size_t maxArgsCount = 10; + size_t rowCount = 100; + size_t maxStringLength = 100; + + std::vector> inputTable; + for (int argsCount = 1; argsCount <= maxArgsCount; argsCount++) { + inputTable.clear(); + inputTable.resize(rowCount, std::vector(argsCount)); + + for (int row = 0; row < rowCount; row++) { + for (int col = 0; col < argsCount; col++) { + inputTable[row][col] = + generateRandomString(folly::Random::rand32() % maxStringLength); + } + } + + SCOPED_TRACE(fmt::format("Number of arguments: {}", argsCount)); + testConcatWsFlatVector(inputTable, argsCount, "--testSep--"); + } +} + +TEST_F(ConcatWsTest, mixedConstantAndColumnStringArgs) { + size_t maxStringLength = 100; + std::string value; + auto data = makeRowVector({ + makeFlatVector( + 1'000, + [&](auto /* row */) { + value = + generateRandomString(folly::Random::rand32() % maxStringLength); + return StringView(value); + }), + makeFlatVector( + 1'000, + [&](auto /* row */) { + value = + generateRandomString(folly::Random::rand32() % maxStringLength); + return StringView(value); + }), + }); + + auto c0 = data->childAt(0)->as>()->rawValues(); + auto c1 = data->childAt(1)->as>()->rawValues(); + + // Test with consecutive constant inputs. + auto result = evaluate>( + "concat_ws('--', c0, c1, 'foo', 'bar')", data); + auto expected = makeFlatVector(1'000, [&](auto row) { + value = ""; + const std::string& s0 = c0[row].str(); + const std::string& s1 = c1[row].str(); + + if (s0.empty() && s1.empty()) { + value = "foo--bar"; + } else if (!s0.empty() && !s1.empty()) { + value = s0 + "--" + s1 + "--foo--bar"; + } else { + value = s0 + s1 + "--foo--bar"; + } + return StringView(value); + }); + velox::test::assertEqualVectors(expected, result); + + // Test with non-ASCII characters. + result = evaluate>( + "concat_ws('$*@', 'aaa', '测试', c0, 'eee', 'ddd', c1, '\u82f9\u679c', 'fff')", + data); + expected = makeFlatVector(1'000, [&](auto row) { + value = ""; + std::string delim = "$*@"; + const std::string& s0 = + c0[row].str().empty() ? c0[row].str() : delim + c0[row].str(); + const std::string& s1 = + c1[row].str().empty() ? c1[row].str() : delim + c1[row].str(); + + value = "aaa" + delim + "测试" + s0 + delim + "eee" + delim + "ddd" + s1 + + delim + "\u82f9\u679c" + delim + "fff"; + return StringView(value); + }); + velox::test::assertEqualVectors(expected, result); +} + +TEST_F(ConcatWsTest, arrayArgs) { + using S = StringView; + auto arrayVector = makeNullableArrayVector({ + {S("red"), S("blue")}, + {S("blue"), std::nullopt, S("yellow"), std::nullopt, S("orange")}, + {}, + {std::nullopt}, + {S("red"), S("purple"), S("green")}, + }); + + // One array arg. + auto result = evaluate>( + "concat_ws('----', c0)", makeRowVector({arrayVector})); + auto expected1 = { + S("red----blue"), + S("blue----yellow----orange"), + S(""), + S(""), + S("red----purple----green"), + }; + velox::test::assertEqualVectors( + makeFlatVector(expected1), result); + + // Two array args. + result = evaluate>( + "concat_ws('----', c0, c1)", makeRowVector({arrayVector, arrayVector})); + auto expected2 = { + S("red----blue----red----blue"), + S("blue----yellow----orange----blue----yellow----orange"), + S(""), + S(""), + S("red----purple----green----red----purple----green"), + }; + velox::test::assertEqualVectors( + makeFlatVector(expected2), result); +} + +TEST_F(ConcatWsTest, mixedStringArrayArgs) { + using S = StringView; + auto arrayVector = makeNullableArrayVector({ + {S("red"), S("blue")}, + {S("blue"), std::nullopt, S("yellow"), std::nullopt, S("orange")}, + {}, + {std::nullopt}, + {S("red"), S("purple"), S("green")}, + }); + + auto result = evaluate>( + "concat_ws('----', c0, 'foo', c1, 'bar', 'end')", + makeRowVector({arrayVector, arrayVector})); + auto expected = { + S("red----blue----foo----red----blue----bar----end"), + S("blue----yellow----orange----foo----blue----yellow----orange----bar----end"), + S("foo----bar----end"), + S("foo----bar----end"), + S("red----purple----green----foo----red----purple----green----bar----end"), + }; + velox::test::assertEqualVectors(makeFlatVector(expected), result); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test