Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support IN operator in filter. #1871

Merged
merged 6 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/common/stl.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ export namespace std {

using std::chrono::high_resolution_clock;
} // namespace chrono

using std::format;
using std::cout;
using std::cerr;
Expand Down Expand Up @@ -374,8 +374,8 @@ namespace infinity {
template<typename S, typename T, typename H = std::hash<S>>
using MultiHashMap = std::unordered_multimap<S, T, H>;

template<typename S>
using HashSet = std::unordered_set<S>;
template<typename S, typename T = std::hash<S>, typename Eq = std::equal_to<S>>
using HashSet = std::unordered_set<S, T, Eq>;

template<typename T>
using MaxHeap = std::priority_queue<T>;
Expand Down
22 changes: 19 additions & 3 deletions src/executor/expression/expression_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,25 @@ void ExpressionEvaluator::Execute(const SharedPtr<ReferenceExpression> &expr,
output_column_vector = input_data_block_->column_vectors[column_index];
}

void ExpressionEvaluator::Execute(const SharedPtr<InExpression> &, SharedPtr<ExpressionState> &, SharedPtr<ColumnVector> &) {
Status status = Status::NotSupport("IN execution isn't implemented yet.");
RecoverableError(status);
void ExpressionEvaluator::Execute(const SharedPtr<InExpression> &expr, SharedPtr<ExpressionState> &state, SharedPtr<ColumnVector> &output_column_vector) {
SharedPtr<BaseExpression> &left_expression = expr->left_operand();
SharedPtr<ExpressionState> &left_state = state->Children()[0];
SharedPtr<ColumnVector> &left_state_output = left_state->OutputColumnVector();
Execute(left_expression, left_state, left_state_output);

SizeT left_result_count = left_state_output->Size();
if(expr->in_type() == InType::kIn) {
for(SizeT idx = 0; idx < left_result_count; idx++) {
output_column_vector->buffer_->SetCompactBit(idx, expr->Exists(left_state_output->GetValue(idx)));
}
return;
}
if (expr->in_type() == InType::kNotIn) {
for(SizeT idx = 0; idx < left_result_count; idx++) {
output_column_vector->buffer_->SetCompactBit(idx, !expr->Exists(left_state_output->GetValue(idx)));
}
return;
}
}

void ExpressionEvaluator::Execute(const SharedPtr<FilterFulltextExpression> &expr,
Expand Down
15 changes: 4 additions & 11 deletions src/executor/expression/expression_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ SharedPtr<ExpressionState> ExpressionState::CreateState(const SharedPtr<BaseExpr
return nullptr;
}

SharedPtr<ExpressionState> ExpressionState::CreateState(const SharedPtr<AggregateExpression> &agg_expr, char *agg_state, const AggregateFlag agg_flag) {
SharedPtr<ExpressionState>
ExpressionState::CreateState(const SharedPtr<AggregateExpression> &agg_expr, char *agg_state, const AggregateFlag agg_flag) {
if (agg_expr->arguments().size() != 1) {
Status status = Status::FunctionArgsError(agg_expr->ToString());
RecoverableError(status);
Expand Down Expand Up @@ -127,7 +128,6 @@ SharedPtr<ExpressionState> ExpressionState::CreateState(const SharedPtr<CastExpr
SharedPtr<ExpressionState> result = MakeShared<ExpressionState>();
result->AddChild(cast_expr->arguments()[0]);


ColumnVectorType result_column_vector_type = ColumnVectorType::kFlat;
if (result->Children()[0]->OutputColumnVector()) {
result_column_vector_type = result->Children()[0]->OutputColumnVector()->vector_type();
Expand Down Expand Up @@ -188,16 +188,9 @@ SharedPtr<ExpressionState> ExpressionState::CreateState(const SharedPtr<InExpres

result->AddChild(in_expr->left_operand());

for (auto &argument_expr : in_expr->arguments()) {
result->AddChild(argument_expr);
}

ColumnVectorType result_column_vector_type = ColumnVectorType::kConstant;
for (SizeT idx = 0; idx < result->Children().size(); ++idx) {
if (result->Children()[idx]->OutputColumnVector()->vector_type() != ColumnVectorType::kConstant) {
result_column_vector_type = ColumnVectorType::kFlat;
break;
}
if (auto &column_ptr = result->Children()[0]->OutputColumnVector(); !column_ptr || column_ptr->vector_type() != ColumnVectorType::kConstant) {
result_column_vector_type = ColumnVectorType::kFlat;
}

result->column_vector_ = MakeShared<ColumnVector>(in_expr_data_type);
Expand Down
4 changes: 2 additions & 2 deletions src/expression/in_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ module in_expression;

namespace infinity {

InExpression::InExpression(InType in_type, SharedPtr<BaseExpression> left_operand, const Vector<SharedPtr<BaseExpression>> &value_list)
: BaseExpression(ExpressionType::kIn, value_list), left_operand_ptr_(std::move(left_operand)), in_type_(in_type) {}
InExpression::InExpression(InType in_type, SharedPtr<BaseExpression> left_operand, Vector<SharedPtr<BaseExpression>> arguments)
: BaseExpression(ExpressionType::kIn, arguments), left_operand_ptr_(std::move(left_operand)), in_type_(in_type), set_(left_operand_ptr_->Type().type()) {}

String InExpression::ToString() const {

Expand Down
106 changes: 104 additions & 2 deletions src/expression/in_expression.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,115 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module in_expression;

import column_binding;
import base_expression;
import data_type;
import value;
import infinity_exception;
import stl;
import logical_type;
import internal_types;

namespace infinity {

// in_expression supported types
// kBoolean = 0,
// kTinyInt,
// kSmallInt,
// kInteger,
// kBigInt,
// kHugeInt,
// kDecimal,
// kFloat,
// kDouble,
// kVarchar,

class ValueSet {
public:
void TryPut(Value &&val) {
if (val.type().type() != data_type_.type()) {
UnrecoverableError(std::format("Mismatched type in ValueSet : {}, {}", val.type().ToString(), data_type_.ToString()));
return;
}
set_.emplace(std::move(val));
}

inline bool Exist(const Value &val) const { return set_.contains(val); }
inline DataType Type() const { return data_type_; }

// constructor will throw when illegal type is passed
ValueSet(LogicalType logical_type) : data_type_(logical_type) {
switch (logical_type) {
case LogicalType::kBoolean:
break;
case LogicalType::kTinyInt:
break;
case LogicalType::kSmallInt:
break;
case LogicalType::kInteger:
break;
case LogicalType::kBigInt:
break;
case LogicalType::kHugeInt:
break;
case LogicalType::kDecimal:
break;
case LogicalType::kFloat:
break;
case LogicalType::kDouble:
break;
case LogicalType::kVarchar:
break;
default:
UnrecoverableError(std::format("Not supported type in ValueSet for InExpression: {}", LogicalType2Str(logical_type)));
return;
}
}

private:
struct ValueComparator {
bool operator()(const Value &lhs, const Value &rhs) const { return lhs == rhs; }
};

struct ValueHasher {
u64 operator()(const Value &val) const {
switch (val.type().type()) {
case LogicalType::kBoolean:
return std::hash<BooleanT>{}(val.GetValue<BooleanT>());
case LogicalType::kTinyInt:
return std::hash<TinyIntT>{}(val.GetValue<TinyIntT>());
case LogicalType::kSmallInt:
return std::hash<SmallIntT>{}(val.GetValue<SmallIntT>());
case LogicalType::kInteger:
return std::hash<IntegerT>{}(val.GetValue<IntegerT>());
case LogicalType::kBigInt:
return std::hash<BigIntT>{}(val.GetValue<BigIntT>());
case LogicalType::kHugeInt:
return val.GetValue<HugeIntT>().GetHash();
case LogicalType::kDecimal:
return val.GetValue<DecimalT>().GetHash();
case LogicalType::kFloat:
return std::hash<FloatT>{}(val.GetValue<FloatT>());
case LogicalType::kDouble:
return std::hash<DoubleT>{}(val.GetValue<DoubleT>());
case LogicalType::kVarchar:
return std::hash<String>{}(val.GetVarchar());
default:
String error_message = std::format("Not supported type : {}", val.type().ToString());
UnrecoverableError(error_message);
break;
}
return 0;
}
};
DataType data_type_;
HashSet<Value, ValueHasher, ValueComparator> set_;
};

export enum class InType {
kInvalid,
kIn,
Expand All @@ -33,7 +128,7 @@ export enum class InType {

export class InExpression : public BaseExpression {
public:
InExpression(InType in_type, SharedPtr<BaseExpression> left_operand, const Vector<SharedPtr<BaseExpression>> &value_list);
InExpression(InType in_type, SharedPtr<BaseExpression> left_operand, Vector<SharedPtr<BaseExpression>> arguments);

String ToString() const override;

Expand All @@ -45,9 +140,16 @@ public:

inline InType in_type() const { return in_type_; }

inline void TryPut(Value &&val) { set_.TryPut(std::move(val)); }

inline bool Exists(const Value &val) const { return set_.Exist(val); }

inline DataType TypeOfArguments() const { return set_.Type(); }

private:
SharedPtr<BaseExpression> left_operand_ptr_;
InType in_type_;
ValueSet set_;
};

} // namespace infinity
6 changes: 3 additions & 3 deletions src/parser/expr/in_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ InExpr::~InExpr() {
std::string InExpr::ToString() const {
std::stringstream ss;
ss << left_->ToString();
if (not_in_) {
ss << "NOT IN (";
} else {
if (in_) {
ss << "IN (";
} else {
ss << "NOT IN (";
}
if (arguments_ != nullptr) {
for (ParsedExpr *expr_ptr : *arguments_) {
Expand Down
4 changes: 2 additions & 2 deletions src/parser/expr/in_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace infinity {

class InExpr : public ParsedExpr {
public:
explicit InExpr(bool not_in = false) : ParsedExpr(ParsedExprType::kIn), not_in_(not_in) {}
explicit InExpr(bool in = true) : ParsedExpr(ParsedExprType::kIn), in_(in) {}

~InExpr() override;

Expand All @@ -31,7 +31,7 @@ class InExpr : public ParsedExpr {
public:
ParsedExpr *left_{nullptr};
std::vector<ParsedExpr *> *arguments_{nullptr};
bool not_in_{false};
bool in_{false};
};

} // namespace infinity
62 changes: 58 additions & 4 deletions src/planner/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ import function_set;
import scalar_function_set;
import scalar_function;
import special_function;
import cast_function;
import bound_cast_func;
import status;

import query_context;
Expand All @@ -81,6 +83,7 @@ import data_type;
import expression_type;
import catalog;
import table_entry;
import column_vector;

namespace infinity {

Expand Down Expand Up @@ -542,18 +545,69 @@ SharedPtr<BaseExpression> ExpressionBinder::BuildInExpr(const InExpr &expr, Bind
Vector<SharedPtr<BaseExpression>> arguments;
arguments.reserve(argument_count);

//in operator, all data type shouhld be the same
SharedPtr<DataType> arguments_type = nullptr;

for (SizeT idx = 0; idx < argument_count; ++idx) {
if (expr.arguments_->at(idx)->type_ != ParsedExprType::kConstant) {
Status status = Status::SyntaxError("In expression now only supports constant list!");
RecoverableError(status);
}
auto bound_argument_expr = BuildExpression(*expr.arguments_->at(idx), bind_context_ptr, depth, false);

if(arguments_type != nullptr && bound_argument_expr->Type() != *arguments_type) {
Status status = Status::SyntaxError("Expressions in In expression must be of the same data type!");
RecoverableError(status);
} else if(arguments_type == nullptr){
arguments_type = MakeShared<DataType>(bound_argument_expr->Type());
}

arguments.emplace_back(bound_argument_expr);
}

InType in_type{InType::kIn};
if (expr.not_in_) {
in_type = InType::kNotIn;
} else {
if (expr.in_) {
in_type = InType::kIn;
} else {
in_type = InType::kNotIn;
}

SharedPtr<InExpression> in_expression_ptr = MakeShared<InExpression>(in_type, bound_left_expr, arguments);

//if match
if(arguments_type->type() == bound_left_expr->Type().type()) {
for(SizeT idx = 0; idx < argument_count; idx++) {
ValueExpression* val_expr = static_cast<ValueExpression *>(arguments[idx].get());
Value val = val_expr->GetValue();
in_expression_ptr->TryPut(std::move(val));
}
} else if(CastExpression::CanCast(*arguments_type, bound_left_expr->Type())) {
//cast
BoundCastFunc cast = CastFunction::GetBoundFunc(*arguments_type, bound_left_expr->Type());

SharedPtr<ColumnVector> argument_column_vector = MakeShared<ColumnVector>(arguments_type);
argument_column_vector->Initialize(ColumnVectorType::kFlat, DEFAULT_VECTOR_SIZE);

for(SizeT idx = 0; idx < argument_count; idx++) {
ValueExpression* val_expr = static_cast<ValueExpression *>(arguments[idx].get());
argument_column_vector->AppendValue(val_expr->GetValue());
}

SharedPtr<ColumnVector> cast_column_vector = MakeShared<ColumnVector>(MakeShared<DataType>(bound_left_expr->Type()));
//will overflow when passing argument_count
cast_column_vector->Initialize(ColumnVectorType::kFlat, DEFAULT_VECTOR_SIZE);
CastParameters cast_parameters;
cast.function(argument_column_vector, cast_column_vector, argument_count, cast_parameters);

for(SizeT idx = 0; idx < argument_count; idx++) {
Value val = cast_column_vector->GetValue(idx);
in_expression_ptr->TryPut(std::move(val));
}
} else {
Status status = Status::NotSupportedTypeConversion(arguments_type->ToString(), bound_left_expr->Type().ToString());
RecoverableError(status);
}

return in_expression_ptr;
}

Expand Down Expand Up @@ -705,7 +759,7 @@ SharedPtr<BaseExpression> ExpressionBinder::BuildKnnExpr(const KnnExpr &parsed_k
// Create query embedding
EmbeddingT query_embedding((ptr_t)parsed_knn_expr.embedding_data_ptr_, false);

if(parsed_knn_expr.ignore_index_ && !parsed_knn_expr.index_name_.empty()) {
if (parsed_knn_expr.ignore_index_ && !parsed_knn_expr.index_name_.empty()) {
Status status = Status::SyntaxError(fmt::format("Force to use index {} conflicts with Ignore index flag.", parsed_knn_expr.index_name_));
RecoverableError(std::move(status));
}
Expand Down
Loading