Skip to content

Commit

Permalink
Fix column separator with NULLs
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Jun 24, 2024
1 parent 36d2eb4 commit 874a61f
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 26 deletions.
72 changes: 53 additions & 19 deletions velox/functions/sparksql/ConcatWs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class ConcatWs : public exec::VectorFunction {
// Builds the function from an optional separator and stores it in
// separator_. Whether that optional holds a value is what
// isConstantSeparator() reports; when it is std::nullopt the visible code
// reads the separator per row from the decoded first argument instead.
explicit ConcatWs(const std::optional<std::string>& separator)
: separator_(separator) {}

// Reports whether separator_ holds a value, i.e. whether the separator was
// captured at construction time rather than being read per row.
bool isConstantSeparator() const {
  return static_cast<bool>(separator_);
}

// Calculate the total number of bytes in the result.
size_t calculateTotalResultBytes(
const SelectivityVector& rows,
Expand All @@ -47,6 +51,10 @@ class ConcatWs : public exec::VectorFunction {

size_t totalResultBytes = 0;
rows.applyToSelected([&](auto row) {
// NULL separator produces NULL result.
if (!isConstantSeparator() && decodedSeparator->isNullAt(row)) {
return;
}
int32_t allElements = 0;
// Calculate size for array columns data.
for (int i = 0; i < arrayArgNum; i++) {
Expand Down Expand Up @@ -87,7 +95,7 @@ class ConcatWs : public exec::VectorFunction {
totalResultBytes += value.size();
}

int32_t separatorSize = separator_.has_value()
int32_t separatorSize = isConstantSeparator()
? separator_.value().size()
: decodedSeparator->valueAt<StringView>(row).size();

Expand All @@ -113,7 +121,7 @@ class ConcatWs : public exec::VectorFunction {
}
// Handles string arg.
argMapping.push_back(i);
if (!separator_.has_value()) {
if (!isConstantSeparator()) {
// Cannot concat consecutive constant string args in advance.
constantStrings.push_back("");
continue;
Expand Down Expand Up @@ -159,7 +167,8 @@ class ConcatWs : public exec::VectorFunction {
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
exec::EvalCtx& context,
FlatVector<StringView>& flatResult) const {
VectorPtr& result) const {
auto& flatResult = *result->asFlatVector<StringView>();
std::vector<column_index_t> argMapping;
std::vector<std::string> constantStrings;
auto numArgs = args.size();
Expand All @@ -182,7 +191,7 @@ class ConcatWs : public exec::VectorFunction {
constantStrings,
decodedStringArgs);
exec::LocalDecodedVector decodedSeparator(context);
if (!separator_.has_value()) {
if (!isConstantSeparator()) {
decodedSeparator = exec::LocalDecodedVector(context, *args[0], rows);
}

Expand Down Expand Up @@ -210,6 +219,11 @@ class ConcatWs : public exec::VectorFunction {
auto rawBuffer =
flatResult.getRawStringBufferWithSpace(totalResultBytes, true);
rows.applyToSelected([&](auto row) {
// NULL separator produces NULL result.
if (!isConstantSeparator() && decodedSeparator->isNullAt(row)) {
result->setNull(row, true);
return;
}
const char* start = rawBuffer;
auto isFirst = true;
// For array arg.
Expand Down Expand Up @@ -249,7 +263,7 @@ class ConcatWs : public exec::VectorFunction {
auto element = elementsDecoded->valueAt<StringView>(offset + k);
copyToBuffer(
element,
separator_.has_value()
isConstantSeparator()
? StringView(separator_.value())
: decodedSeparator->valueAt<StringView>(row));
}
Expand All @@ -275,7 +289,7 @@ class ConcatWs : public exec::VectorFunction {
}
copyToBuffer(
value,
separator_.has_value()
isConstantSeparator()
? StringView(separator_.value())
: decodedSeparator->valueAt<StringView>(row));
j++;
Expand All @@ -287,24 +301,43 @@ class ConcatWs : public exec::VectorFunction {
// Entry point of the vector function: evaluates concat_ws for the selected
// `rows` of `args` and writes the strings (or NULLs) into `result`.
//
// NOTE(review): this span is a rendered commit diff, so the removed and the
// added version of several lines appear back to back: the two `outputType`
// parameter lines, the old `args[0]->isNullAt(0)` check followed by the new
// `isConstantSeparator()` check, the old unconditional empty-string loop
// followed by the new NULL-aware one, and the two `doApply` calls.
// Presumably the first line of each pair is the removed one — confirm
// against the repository before treating this text as compilable code.
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /* outputType */,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
context.ensureWritable(rows, VARCHAR(), result);
auto flatResult = result->asFlatVector<StringView>();
auto numArgs = args.size();
// If separator is NULL, result is NULL.
if (args[0]->isNullAt(0)) {
rows.applyToSelected([&](auto row) { result->setNull(row, true); });
return;
// Constant separator: inspect the single constant value. When it is NULL, a
// NULL constant vector of the output type is created and moved or copied
// into `result` for the selected rows.
if (isConstantSeparator()) {
auto constant = args[0]->as<ConstantVector<StringView>>();
if (constant->isNullAt(0)) {
auto localResult = BaseVector::createNullConstant(outputType, rows.end(), context.pool());
context.moveOrCopyResult(localResult, rows, result);
// rows.applyToSelected([&](auto row) { result->setNull(row, true); });
return;
}
}
// If only separator (not a NULL) is provided, result is an empty string.
if (numArgs == 1) {
rows.applyToSelected(
[&](auto row) { flatResult->setNoCopy(row, StringView("")); });
auto decodedSeparator = exec::LocalDecodedVector(context, *args[0], rows);
// Fast path, taken when no selected row can have a NULL separator:
// 1. Separator is a constant and is not NULL.
// 2. Separator is a column that has no NULLs.
if (isConstantSeparator() || !decodedSeparator->mayHaveNulls()) {
rows.applyToSelected(
[&](auto row) { flatResult->setNoCopy(row, StringView("")); });
} else {
// The separator column may contain NULLs, so decide row by row: a NULL
// separator gives a NULL result, anything else gives an empty string.
rows.applyToSelected(
[&](auto row) {
if (decodedSeparator->isNullAt(row)) {
result->setNull(row, true);
} else {
flatResult->setNoCopy(row, StringView(""));
}
});
}
return;
}
// Two or more arguments: hand over to doApply for the concatenation.
doApply(rows, args, context, *flatResult);
doApply(rows, args, context, result);
}

private:
Expand All @@ -327,13 +360,14 @@ exec::ExprPtr ConcatWsCallToSpecialForm::constructSpecialForm(
1,
"concat_ws requires one arguments at least, but got {}.",
numArgs);
for (const auto& arg : args) {
VELOX_USER_CHECK(args[0]->type()->isVarchar(), "The first argument of concat_ws must be a varchar.");
for (size_t i = 1; i < args.size(); i++){
VELOX_USER_CHECK(
arg->type()->isVarchar() ||
(arg->type()->isArray() &&
arg->type()->asArray().elementType()->isVarchar()),
"concat_ws requires varchar or array(varchar) arguments, but got {}.",
arg->type()->toString());
args[i]->type()->isVarchar() ||
(args[i]->type()->isArray() &&
args[i]->type()->asArray().elementType()->isVarchar()),
"The 2nd and following arguments for concat_ws should be varchar or array(varchar), but got {}.",
args[i]->type()->toString());
}

std::optional<std::string> separator = std::nullopt;
Expand Down
29 changes: 22 additions & 7 deletions velox/functions/sparksql/tests/ConcatWsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,26 +183,26 @@ TEST_F(ConcatWsTest, arrayArgs) {
// One array arg.
auto result = evaluate<SimpleVector<StringView>>(
"concat_ws('--', c0)", makeRowVector({arrayVector}));
auto expected1 = makeFlatVector<StringView>({
auto expected = makeFlatVector<StringView>({
"red--blue",
"blue--yellow--orange",
"",
"",
"red--purple--green",
});
velox::test::assertEqualVectors(expected1, result);
velox::test::assertEqualVectors(expected, result);

// Two array args.
result = evaluate<SimpleVector<StringView>>(
"concat_ws('--', c0, c1)", makeRowVector({arrayVector, arrayVector}));
auto expected2 = makeFlatVector<StringView>({
expected = makeFlatVector<StringView>({
"red--blue--red--blue",
"blue--yellow--orange--blue--yellow--orange",
"",
"",
"red--purple--green--red--purple--green",
});
velox::test::assertEqualVectors(expected2, result);
velox::test::assertEqualVectors(expected, result);
}

TEST_F(ConcatWsTest, mixedStringAndArrayArgs) {
Expand Down Expand Up @@ -235,7 +235,7 @@ TEST_F(ConcatWsTest, mixedStringAndArrayArgs) {

TEST_F(ConcatWsTest, nonconstantSeparator) {
auto separatorVector =
makeFlatVector<StringView>({"##", "--", "~~", "**", "++"});
makeNullableFlatVector<StringView>({"##", "--", "~~", "**", std::nullopt});
auto arrayVector = makeNullableArrayVector<StringView>({
{"red", "blue"},
{"blue", std::nullopt, "yellow", std::nullopt, "orange"},
Expand All @@ -246,12 +246,27 @@ TEST_F(ConcatWsTest, nonconstantSeparator) {

auto result = evaluate<SimpleVector<StringView>>(
"concat_ws(c0, c1, '|')", makeRowVector({separatorVector, arrayVector}));
auto expected = makeFlatVector<StringView>({
auto expected = makeNullableFlatVector<StringView>({
"red##blue##|",
"blue--yellow--orange--|",
"red~~blue~~|",
"blue**yellow**orange**|",
"red++purple++green++|",
std::nullopt,
});
velox::test::assertEqualVectors(expected, result);
}

// concat_ws called with nothing but a separator column: rows whose separator
// is NULL are expected to be NULL, all other rows an empty string.
TEST_F(ConcatWsTest, separatorOnly) {
  auto separators = makeNullableFlatVector<StringView>(
      {"##", std::nullopt, "~~", "**", std::nullopt});
  auto input = makeRowVector({separators});

  auto actual = evaluate<SimpleVector<StringView>>("concat_ws(c0)", input);

  auto expected = makeNullableFlatVector<StringView>(
      {"", std::nullopt, "", "", std::nullopt});
  velox::test::assertEqualVectors(expected, actual);
}
Expand Down

0 comments on commit 874a61f

Please sign in to comment.