Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
[SQL-DS-CACHE-89][POAE7-1016] fix aggregation with null value issue, …
Browse files Browse the repository at this point in the history
…remove avg node (#90)
  • Loading branch information
jikunshang authored Apr 15, 2021
1 parent 3af1eb5 commit d1b1b50
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 37 deletions.
22 changes: 21 additions & 1 deletion oap-ape/ape-native/src/utils/AggExpression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,28 @@ int AggExpression::ExecuteWithParam(int batchSize,
std::vector<int8_t>& outBuffers) {
if (!done) {
child->ExecuteWithParam(batchSize, dataBuffers, nullBuffers, outBuffers);
done = true;
}
return 0;
}

/// Appends the COUNT aggregate to `result.data` as a single decimal value and
/// tags the result as ResultType::LongType.
///
/// - count(*) / count(1): the child is a LiteralExpression, so the stored
///   batch size (batchSize_) is reported.
/// - count(expr): a row is counted only when the child's null vector holds a
///   non-zero entry for it. The tally is computed while `done` is unset and
///   the stored `count` is reused on later calls.
///
/// @param result  Output vector; exactly one element is pushed onto
///                `result.data` and `result.type` is set.
void Count::getResult(DecimalVector& result) {
  // Both paths produce a plain row count, so tag the type before branching;
  // previously the literal path returned with `result.type` left untouched.
  result.type = ResultType::LongType;

  if (typeid(*child) == typeid(LiteralExpression)) {  // for count(*) or count(1)
    result.data.push_back(arrow::BasicDecimal128(batchSize_));
    return;
  }

  // NOTE(review): the tally below only runs while `done` is unset, and
  // Count::ExecuteWithParam appears to set `done` as well -- confirm that
  // reset() happens in between, otherwise `count` stays at 0.
  if (!done) {
    done = true;
    auto tmp = DecimalVector();
    child->getResult(tmp);
    ARROW_LOG(INFO) << "count node child size: " << tmp.data.size();

    const size_t rows = tmp.data.size();
    if (tmp.nullVector == nullptr) {
      // The child supplied no null vector (it defaults to nullptr), so there
      // is nothing to filter on: count every row instead of dereferencing a
      // null pointer. NOTE(review): confirm this fallback is the wanted one.
      count += static_cast<int>(rows);
    } else {
      for (size_t i = 0; i < rows; ++i) {
        // at() keeps the bounds check: a null vector shorter than the data
        // still fails loudly rather than reading out of range.
        if (tmp.nullVector->at(i)) count++;
      }
    }
  }
  result.data.push_back(arrow::BasicDecimal128(count));
}

int ArithmeticExpression::ExecuteWithParam(int batchSize,
const std::vector<int64_t>& dataBuffers,
const std::vector<int64_t>& nullBuffers,
Expand All @@ -89,6 +106,9 @@ int AttributeReferenceExpression::ExecuteWithParam(
done = true;
int64_t dataPtr = dataBuffers[columnIndex];
int64_t nullPtr = nullBuffers[columnIndex];
std::vector<uint8_t> nullVec(batchSize);
std::memcpy(nullVec.data(), (uint8_t*)nullPtr, batchSize);
result.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
parquet::Type::type columnType = (*schema)[columnIndex].getColType();
if (isDecimalType(dataType)) {
int precision, scale;
Expand Down
101 changes: 65 additions & 36 deletions oap-ape/ape-native/src/utils/AggExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class RootAggExpression : public WithResultExpression {
child->setSchema(schema);
}

void getResult(DecimalVector& result) { child->getResult(result); }
void getResult(DecimalVector& result) override { child->getResult(result); }

private:
bool isDistinct;
Expand All @@ -101,6 +101,18 @@ class AggExpression : public WithResultExpression {
done = false;
child->reset();
}
// Hands back the aggregate result, building it at most once: the first call
// made while `done` is unset fills resultCache through getResultInternal();
// every call then copies the cached value into `result`.
void getResult(DecimalVector& result) override {
  const bool cacheIsStale = !done;
  if (cacheIsStale) {
    // Flip the flag before delegating, same order as the flag/compute pair
    // has always had here.
    done = true;
    getResultInternal(resultCache);
  }

  result = resultCache;
}

// Hook that computes the aggregate into `result` (getResult() passes the
// cached DecimalVector resultCache). This base version computes nothing and
// only logs, since it is not expected to be reached.
virtual void getResultInternal(DecimalVector& result) {
  ARROW_LOG(INFO) << "should never be called";
}

void setSchema(std::shared_ptr<std::vector<Schema>> schema_) {
schema = schema_;
Expand All @@ -109,18 +121,22 @@ class AggExpression : public WithResultExpression {

protected:
std::shared_ptr<WithResultExpression> child;
DecimalVector resultCache;
};

class Sum : public AggExpression {
public:
~Sum() {}
void getResult(DecimalVector& result) override {
void getResultInternal(DecimalVector& result) override {
auto tmp = DecimalVector();
child->getResult(tmp);
arrow::BasicDecimal128 out;
for (auto e : tmp.data) {
out += e;
for (int i = 0; i < tmp.data.size(); i++) {
if (tmp.nullVector->at(i)) {
out += tmp.data[i];
}
}
result.data.clear();
result.data.push_back(out);
result.precision = 38; // tmp.precision;
result.scale = tmp.scale;
Expand All @@ -131,11 +147,17 @@ class Sum : public AggExpression {
class Min : public AggExpression {
public:
~Min() {}
void getResult(DecimalVector& result) override {

void getResultInternal(DecimalVector& result) override {
auto tmp = DecimalVector();
child->getResult(tmp);
arrow::BasicDecimal128 out(tmp.data[0]);
for (auto e : tmp.data) out = out < e ? out : e;
for (int i = 0; i < tmp.data.size(); i++) {
if (tmp.nullVector->at(i)) {
out = out < tmp.data[i] ? out : tmp.data[i];
}
}
result.data.clear();
result.data.push_back(out);
result.precision = tmp.precision;
result.scale = tmp.scale;
Expand All @@ -146,11 +168,16 @@ class Min : public AggExpression {
class Max : public AggExpression {
public:
~Max() {}
void getResult(DecimalVector& result) override {
void getResultInternal(DecimalVector& result) override {
auto tmp = DecimalVector();
child->getResult(tmp);
arrow::BasicDecimal128 out(tmp.data[0]);
for (auto e : tmp.data) out = out > e ? out : e;
for (int i = 0; i < tmp.data.size(); i++) {
if (tmp.nullVector->at(i)) {
out = out > tmp.data[i] ? out : tmp.data[i];
}
}
result.data.clear();
result.data.push_back(out);
result.precision = tmp.precision;
result.scale = tmp.scale;
Expand All @@ -161,42 +188,21 @@ class Max : public AggExpression {
class Count : public AggExpression {
public:
~Count() {}
void getResult(DecimalVector& result) override {
result.data.push_back(arrow::BasicDecimal128(count));
result.type = ResultType::LongType;
}

void getResult(DecimalVector& result) override;
int ExecuteWithParam(int batchSize, const std::vector<int64_t>& dataBuffers,
const std::vector<int64_t>& nullBuffers,
std::vector<int8_t>& outBuffers) override {
if (!done) {
done = true;
count = batchSize;
count = 0;
batchSize_ = batchSize; // for count(*)
child->ExecuteWithParam(batchSize, dataBuffers, nullBuffers, outBuffers);
}
return 0;
}

private:
int count = 0;
};

class Avg : public AggExpression {
public:
~Avg() {}
void getResult(DecimalVector& result) override {
// should never be called
auto tmp = DecimalVector();
child->getResult(tmp);
arrow::BasicDecimal128 sum;
for (auto e : tmp.data) {
sum += e;
}
result.data.push_back(sum);
result.data.push_back(arrow::BasicDecimal128(tmp.data.size()));
result.precision = 38; // tmp.precision;
result.scale = tmp.scale;
result.type = GetResultType(dataType);
}
int batchSize_ = 0;
};

class ArithmeticExpression : public WithResultExpression {
Expand Down Expand Up @@ -284,16 +290,23 @@ class Add : public ArithmeticExpression {
arrow::BasicDecimal128 out = left.data[0] + right.data[i];
result.data.push_back(out);
}
result.nullVector = right.nullVector;
} else if (right.data.size() == 1) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] + right.data[0];
result.data.push_back(out);
}
result.nullVector = left.nullVector;
} else if (left.data.size() == right.data.size()) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] + right.data[i];
result.data.push_back(out);
}
std::vector<uint8_t> nullVec(left.data.size());
for (int i = 0; i < left.data.size(); i++) {
nullVec[i] = left.nullVector->at(i) & right.nullVector->at(i);
}
result.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
} else {
ARROW_LOG(ERROR) << "Oops...why left and right has different size?";
}
Expand Down Expand Up @@ -331,16 +344,23 @@ class Sub : public ArithmeticExpression {
arrow::BasicDecimal128 out = left.data[0] - right.data[i];
result.data.push_back(out);
}
result.nullVector = right.nullVector;
} else if (right.data.size() == 1) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] - right.data[0];
result.data.push_back(out);
}
result.nullVector = left.nullVector;
} else if (left.data.size() == right.data.size()) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] - right.data[i];
result.data.push_back(out);
}
std::vector<uint8_t> nullVec(left.data.size());
for (int i = 0; i < left.data.size(); i++) {
nullVec[i] = left.nullVector->at(i) & right.nullVector->at(i);
}
result.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
} else {
ARROW_LOG(ERROR) << "Oops...why left and right has different size?";
}
Expand Down Expand Up @@ -378,22 +398,30 @@ class Multiply : public ArithmeticExpression {
arrow::BasicDecimal128 out = left.data[0] * right.data[i];
result.data.push_back(out);
}
result.nullVector = right.nullVector;
} else if (right.data.size() == 1) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] * right.data[0];
result.data.push_back(out);
}
result.nullVector = left.nullVector;
} else if (left.data.size() == right.data.size()) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] * right.data[i];
result.data.push_back(out);
}
std::vector<uint8_t> nullVec(left.data.size());
for (int i = 0; i < left.data.size(); i++) {
nullVec[i] = left.nullVector->at(i) & right.nullVector->at(i);
}
result.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
} else {
ARROW_LOG(ERROR) << "Oops...why left and right has different size?";
}
}
};

// TODO: Impl Divide and Mod.
class Divide : public ArithmeticExpression {
public:
~Divide() {}
Expand Down Expand Up @@ -427,6 +455,7 @@ class AttributeReferenceExpression : public WithResultExpression {
}
res.precision = result.precision;
res.scale = result.scale;
res.nullVector = result.nullVector;
}

void setAttribute(std::string columnName_, std::string dataType_, std::string castType_,
Expand Down Expand Up @@ -458,6 +487,8 @@ class LiteralExpression : public WithResultExpression {
res.data.push_back(value);
res.precision = precision_;
res.scale = scale_;
std::vector<uint8_t> nullVec{1};
res.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
}
void setAttribute(std::string dataType_, std::string valueString_) {
dataType = dataType_;
Expand Down Expand Up @@ -511,8 +542,6 @@ class Gen {
return std::make_shared<Max>();
else if (name.compare("Min") == 0)
return std::make_shared<Min>();
else if (name.compare("Average") == 0)
return std::make_shared<Avg>();
else if (name.compare("Count") == 0)
return std::make_shared<Count>();

Expand Down
2 changes: 2 additions & 0 deletions oap-ape/ape-native/src/utils/DecimalConvertor.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ struct DecimalVector {
int32_t precision;
int32_t scale;
ResultType type;
std::shared_ptr<std::vector<uint8_t>> nullVector = nullptr;
// Copy-assigns data, precision, scale, type and the null vector from `rhs`.
// Only the shared_ptr is copied for nullVector, so both objects end up
// referring to the same underlying null vector.
// Returns *this (instead of void) so assignment behaves like the built-in
// one and can be chained; existing `a = b;` call sites are unaffected.
DecimalVector& operator=(const DecimalVector& rhs) {
  this->data = rhs.data;
  this->precision = rhs.precision;
  this->scale = rhs.scale;
  this->type = rhs.type;
  this->nullVector = rhs.nullVector;
  return *this;
}
};

Expand Down

0 comments on commit d1b1b50

Please sign in to comment.