Skip to content

Commit

Permalink
[GLUTEN-4652][VL] Fix min_by/max_by result mismatch (apache#5544)
Browse files Browse the repository at this point in the history
Fix min_by/max_by result mismatch. Taking max_by as an example, we need to keep an intermediate result row such as <null, 11> so that it can be compared with another result such as <5, 8>, ensuring that the final result is <null, 11>.
  • Loading branch information
yma11 authored Apr 29, 2024
1 parent 7dad958 commit 049a477
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,13 @@ object VeloxIntermediateData {
* row_constructor_with_null.
*/
def getRowConstructFuncName(aggFunc: AggregateFunction): String = aggFunc match {
  // Average/Sum whose result type is decimal use the plain row constructor.
  case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] =>
    "row_constructor"
  // For agg function min_by/max_by, it needs to keep rows with null value but non-null
  // comparison, such as <null, 5>. So we set the struct to null only when all of the
  // arguments are null.
  case _: MaxMinBy =>
    "row_constructor_with_all_null"
  // Every other aggregate function falls back to row_constructor_with_null.
  case _ => "row_constructor_with_null"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
override protected val resourcePath: String = "/tpch-data-parquet-velox"
override protected val fileFormat: String = "parquet"

import testImplicits._

override def beforeAll(): Unit = {
super.beforeAll()
createTPCHNotNullTables()
Expand Down Expand Up @@ -188,6 +190,22 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}
}

test("min_by/max_by") {
  // The data set includes rows whose value column is null while the comparison column is
  // not, which is the case this test is meant to exercise.
  withTempPath {
    tmpDir =>
      val location = tmpDir.getCanonicalPath
      val rows =
        Seq((5: Integer, 6: Integer), (null: Integer, 11: Integer), (null: Integer, 5: Integer))
      rows.toDF("a", "b").write.parquet(location)
      spark.read.parquet(location).createOrReplaceTempView("test")
      runQueryAndCompare("select min_by(a, b), max_by(a, b) from test") {
        checkGlutenOperatorMatch[HashAggregateExecTransformer]
      }
  }
}

test("groupby") {
val df = runQueryAndCompare(
"select l_orderkey, sum(l_partkey) as sum from lineitem " +
Expand Down
13 changes: 11 additions & 2 deletions cpp/velox/operators/functions/RegistrationAllFunctions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
#include "operators/functions/RegistrationAllFunctions.h"
#include "operators/functions/Arithmetic.h"
#include "operators/functions/RowConstructorWithAllNull.h"
#include "operators/functions/RowConstructorWithNull.h"
#include "operators/functions/RowFunctionWithNull.h"

Expand Down Expand Up @@ -47,11 +48,19 @@ void registerFunctionOverwrite() {
velox::exec::registerVectorFunction(
"row_constructor_with_null",
std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
std::make_unique<RowFunctionWithNull>(),
RowFunctionWithNull::metadata());
std::make_unique<RowFunctionWithNull</*allNull=*/false>>(),
RowFunctionWithNull</*allNull=*/false>::metadata());
velox::exec::registerFunctionCallToSpecialForm(
RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull,
std::make_unique<RowConstructorWithNullCallToSpecialForm>());
velox::exec::registerVectorFunction(
"row_constructor_with_all_null",
std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
std::make_unique<RowFunctionWithNull</*allNull=*/true>>(),
RowFunctionWithNull</*allNull=*/true>::metadata());
velox::exec::registerFunctionCallToSpecialForm(
RowConstructorWithAllNullCallToSpecialForm::kRowConstructorWithAllNull,
std::make_unique<RowConstructorWithAllNullCallToSpecialForm>());
velox::functions::sparksql::registerBitwiseFunctions("spark_");
}
} // namespace
Expand Down
37 changes: 37 additions & 0 deletions cpp/velox/operators/functions/RowConstructorWithAllNull.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "RowConstructorWithNull.h"

namespace gluten {
// Special-form hook for "row_constructor_with_all_null". It reuses
// RowConstructorWithNullCallToSpecialForm and only substitutes the function name.
class RowConstructorWithAllNullCallToSpecialForm : public RowConstructorWithNullCallToSpecialForm {
 public:
  static constexpr const char* kRowConstructorWithAllNull = "row_constructor_with_all_null";

 protected:
  // Builds the expression under kRowConstructorWithAllNull; the incoming `name` is ignored.
  facebook::velox::exec::ExprPtr constructSpecialForm(
      const std::string& name,
      const facebook::velox::TypePtr& type,
      std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
      bool trackCpuUsage,
      const facebook::velox::core::QueryConfig& config) {
    // The call must be qualified with the base class: this declaration hides the inherited
    // constructSpecialForm overloads, so an unqualified call with these arguments resolves
    // back to this very function and recurses without end.
    // NOTE(review): assumes the base class declares an accessible overload taking
    // (name, type, compiledChildren, trackCpuUsage, config) - confirm against
    // RowConstructorWithNull.h.
    return RowConstructorWithNullCallToSpecialForm::constructSpecialForm(
        kRowConstructorWithAllNull, type, std::move(compiledChildren), trackCpuUsage, config);
  }
};
} // namespace gluten
21 changes: 18 additions & 3 deletions cpp/velox/operators/functions/RowFunctionWithNull.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
namespace gluten {

/**
* A customized RowFunction to set struct as null when one of its argument is null.
* @tparam allNull If true, set the struct to null only when all of its arguments are null;
* otherwise set it to null when any one of its arguments is null.
*/
template <bool allNull>
class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction {
public:
void apply(
Expand All @@ -42,13 +44,26 @@ class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction {
rows.applyToSelected([&](facebook::velox::vector_size_t i) {
facebook::velox::bits::clearNull(nullsPtr, i);
if (!facebook::velox::bits::isBitNull(nullsPtr, i)) {
int argsNullCnt = 0;
for (size_t c = 0; c < argsCopy.size(); c++) {
auto arg = argsCopy[c].get();
if (arg->mayHaveNulls() && arg->isNullAt(i)) {
// If any argument of the struct is null, set the struct as null.
// For row_constructor_with_null, if any argument of the struct is null,
// set the struct as null.
if constexpr (!allNull) {
facebook::velox::bits::setNull(nullsPtr, i, true);
cntNull++;
break;
} else {
argsNullCnt++;
}
}
}
// For row_constructor_with_all_null, set the struct to null only when all arguments are null
if constexpr (allNull) {
if (argsNullCnt == argsCopy.size()) {
facebook::velox::bits::setNull(nullsPtr, i, true);
cntNull++;
break;
}
}
}
Expand Down

0 comments on commit 049a477

Please sign in to comment.