Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(search): implement vector query for sql/redisearch parser & transformer #2450

Merged
merged 12 commits into from
Aug 2, 2024
20 changes: 20 additions & 0 deletions src/search/common_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,26 @@ struct TreeTransformer {

return result;
}

// Decodes a raw byte buffer into a vector of `T`.
//
// Each element is copied out of the buffer with memcpy, so the buffer does not
// have to be aligned for `T`, and no byte-order conversion is applied.
//
// Returns Status::NotOK when the buffer length is not a multiple of sizeof(T);
// an empty buffer yields an empty vector.
template <typename T = double>
static StatusOr<std::vector<T>> Binary2Vector(std::string_view str) {
  constexpr size_t elem_size = sizeof(T);
  if (str.size() % elem_size != 0) {
    return {Status::NotOK, "data size is not a multiple of the target type size"};
  }

  const size_t elem_count = str.size() / elem_size;
  std::vector<T> values;
  values.reserve(elem_count);

  for (size_t i = 0; i < elem_count; ++i) {
    T value;
    memcpy(&value, str.data() + i * elem_size, elem_size);
    values.push_back(value);
  }

  return values;
}
};

} // namespace kqir
71 changes: 70 additions & 1 deletion src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,63 @@ struct NumericCompareExpr : BoolAtomExpr {
}
};

// Literal node that owns the components of a vector as doubles.
struct VectorLiteral : Literal {
  std::vector<double> values;

  explicit VectorLiteral(std::vector<double> &&values) : values(std::move(values)) {}

  std::string_view Name() const override { return "VectorLiteral"; }

  // Renders the components inside square brackets; every component is
  // formatted with std::to_string and the pieces are joined by util::StringJoin.
  std::string Dump() const override {
    auto component_to_text = [](auto v) { return std::to_string(v); };
    const std::string joined = util::StringJoin(values, component_to_text);
    return fmt::format("[{}]", joined);
  }

  // The content of a vector literal is the same text as its dump.
  std::string Content() const override { return Dump(); }

  std::unique_ptr<Node> Clone() const override { return std::make_unique<VectorLiteral>(*this); }
};

// Boolean atom holding a field, a query vector and a numeric bound;
// Dump() prints it as `field <-> vector < range`.
struct VectorRangeExpr : BoolAtomExpr {
  std::unique_ptr<FieldRef> field;
  std::unique_ptr<NumericLiteral> range;
  std::unique_ptr<VectorLiteral> vector;

  VectorRangeExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&range,
                  std::unique_ptr<VectorLiteral> &&vector)
      : field(std::move(field)), range(std::move(range)), vector(std::move(vector)) {}

  std::string_view Name() const override { return "VectorRangeExpr"; }

  std::string Dump() const override {
    const std::string field_text = field->Dump();
    const std::string vector_text = vector->Dump();
    const std::string range_text = range->Dump();
    return fmt::format("{} <-> {} < {}", field_text, vector_text, range_text);
  }

  // Clones every child node and casts each copy back to its concrete type
  // before building the new expression.
  std::unique_ptr<Node> Clone() const override {
    auto field_copy = Node::MustAs<FieldRef>(field->Clone());
    auto range_copy = Node::MustAs<NumericLiteral>(range->Clone());
    auto vector_copy = Node::MustAs<VectorLiteral>(vector->Clone());
    return std::make_unique<VectorRangeExpr>(std::move(field_copy), std::move(range_copy),
                                             std::move(vector_copy));
  }
};

// Boolean atom holding a field, the number of neighbours `k` and a query
// vector; Dump() prints it as `KNN k=<k>, field <-> vector`.
struct VectorKnnExpr : BoolAtomExpr {
  // TODO: Support pre-filter for hybrid query
  std::unique_ptr<FieldRef> field;
  std::unique_ptr<NumericLiteral> k;
  std::unique_ptr<VectorLiteral> vector;

  VectorKnnExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&k,
                std::unique_ptr<VectorLiteral> &&vector)
      : field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}

  std::string_view Name() const override { return "VectorKnnExpr"; }

  std::string Dump() const override {
    const std::string k_text = k->Dump();
    const std::string field_text = field->Dump();
    const std::string vector_text = vector->Dump();
    return fmt::format("KNN k={}, {} <-> {}", k_text, field_text, vector_text);
  }

  // Clones every child node and casts each copy back to its concrete type
  // before building the new expression.
  std::unique_ptr<Node> Clone() const override {
    auto field_copy = Node::MustAs<FieldRef>(field->Clone());
    auto k_copy = Node::MustAs<NumericLiteral>(k->Clone());
    auto vector_copy = Node::MustAs<VectorLiteral>(vector->Clone());
    return std::make_unique<VectorKnnExpr>(std::move(field_copy), std::move(k_copy), std::move(vector_copy));
  }
};

struct BoolLiteral : BoolAtomExpr, Literal {
bool val;

Expand Down Expand Up @@ -336,18 +393,30 @@ struct LimitClause : Node {
// Formats the clause content as "<offset>, <count>".
std::string Content() const override { return fmt::format("{}, {}", offset, count); }

// Returns a new LimitClause copy-constructed from this one.
std::unique_ptr<Node> Clone() const override { return std::make_unique<LimitClause>(*this); }
// Returns the `offset` value stored in this clause.
size_t Offset() const { return offset; }

// Returns the `count` value stored in this clause.
size_t Count() const { return count; }
};

struct SortByClause : Node {
enum Order { ASC, DESC } order = ASC;
std::unique_ptr<FieldRef> field;
std::unique_ptr<VectorLiteral> vector = nullptr;

// Sort by a field with an explicit order; `vector` keeps its default (nullptr).
SortByClause(Order order, std::unique_ptr<FieldRef> &&field) : order(order), field(std::move(field)) {}
// Sort by a field together with a vector literal; `order` keeps its default (ASC).
SortByClause(std::unique_ptr<FieldRef> &&field, std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), vector(std::move(vector)) {}

// Maps a sort order to its textual form: "asc" for ASC, "desc" for any other value.
static constexpr const char *OrderToString(Order order) {
  if (order == ASC) {
    return "asc";
  }
  return "desc";
}
// True when this clause carries a vector literal.
bool IsVectorField() const {
  return static_cast<bool>(vector);
}

// Node type name used to identify this clause.
std::string_view Name() const override { return "SortByClause"; }
std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); }
// Renders the clause as text. With a vector literal attached it prints
// `sortby <field> <-> <vector>`; otherwise it prints the field followed by
// the sort order.
std::string Dump() const override {
  const std::string field_text = field->Dump();
  if (IsVectorField()) {
    return fmt::format("sortby {} <-> {}", field_text, vector->Dump());
  }
  return fmt::format("sortby {}, {}", field_text, OrderToString(order));
}
// The clause content is the textual sort order ("asc" or "desc").
std::string Content() const override { return OrderToString(order); }

// Starts child iteration at the field reference.
NodeIterator ChildBegin() override {
  return NodeIterator(field.get());
}
Expand Down
63 changes: 63 additions & 0 deletions src/search/ir_sema_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ struct SemaChecker {
GET_OR_RET(Check(v->query_expr.get()));
if (v->limit) GET_OR_RET(Check(v->limit.get()));
if (v->sort_by) GET_OR_RET(Check(v->sort_by.get()));
if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) {
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
return {Status::NotOK, "expect a LIMIT clause for vector field to construct a KNN search"};
}
} else {
return {Status::NotOK, fmt::format("index `{}` not found", index_name)};
}
Expand All @@ -60,8 +63,25 @@ struct SemaChecker {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.IsSortable()) {
return {Status::NotOK, fmt::format("field `{}` is not sortable", v->field->name)};
} else if (auto is_vector = iter->second.MetadataAs<redis::HnswVectorFieldMetadata>() != nullptr;
is_vector != v->IsVectorField()) {
std::string not_str = is_vector ? "" : "not ";
return {Status::NotOK,
fmt::format("field `{}` is {}a vector field according to metadata and does {}expect a vector parameter",
v->field->name, not_str, not_str)};
} else {
v->field->info = &iter->second;
if (v->IsVectorField()) {
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (!v->field->info->HasIndex()) {
return {Status::NotOK,
fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)};
}
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
}
} else if (auto v = dynamic_cast<AndExpr *>(node)) {
for (const auto &n : v->inners) {
Expand Down Expand Up @@ -97,6 +117,49 @@ struct SemaChecker {
} else {
v->field->info = &iter->second;
}
} else if (auto v = dynamic_cast<VectorKnnExpr *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.MetadataAs<redis::HnswVectorFieldMetadata>()) {
return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)};
} else {
v->field->info = &iter->second;

if (!v->field->info->HasIndex()) {
return {Status::NotOK,
fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)};
}
if (v->k->val <= 0) {
return {Status::NotOK, fmt::format("KNN search parameter `k` must be greater than 0")};
}
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
} else if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.MetadataAs<redis::HnswVectorFieldMetadata>()) {
return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)};
} else {
v->field->info = &iter->second;

auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (meta->distance_metric == redis::DistanceMetric::L2 && v->range->val < 0) {
return {Status::NotOK, "range cannot be a negative number for l2 distance metric"};
}

if (meta->distance_metric == redis::DistanceMetric::COSINE && (v->range->val < 0 || v->range->val > 2)) {
return {Status::NotOK, "range has to be between 0 and 2 for cosine distance metric"};
}

if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
} else if (auto v = dynamic_cast<SelectClause *>(node)) {
for (const auto &n : v->fields) {
if (auto iter = current_index->fields.find(n->name); iter == current_index->fields.end()) {
Expand Down
16 changes: 13 additions & 3 deletions src/search/redis_query_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ namespace redis_query {

using namespace peg;

// Keyword and operator tokens used by the vector query syntax.
// Matches the literal keyword `VECTOR_RANGE`.
struct VectorRangeToken : string<'V', 'E', 'C', 'T', 'O', 'R', '_', 'R', 'A', 'N', 'G', 'E'> {};
// Matches the literal keyword `KNN`.
struct KnnToken : string<'K', 'N', 'N'> {};
// Matches the two-character operator `=>`.
struct ArrowOp : string<'=', '>'> {};
// Matches a single `*` character.
struct Wildcard : one<'*'> {};

// A field reference: `@` immediately followed by an Identifier.
struct Field : seq<one<'@'>, Identifier> {};

// A parameter reference: `$` immediately followed by an Identifier.
struct Param : seq<one<'$'>, Identifier> {};
Expand All @@ -44,9 +49,10 @@ struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
// One bound of a numeric range; alternatives are tried in order: Inf,
// ExclusiveNumber, NumberOrParam.
struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
// Two range bounds enclosed in `[` and `]`, each wrapped in WSPad
// (WSPad is defined elsewhere; presumably whitespace padding — TODO confirm).
struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, WSPad<NumericRangePart>, one<']'>> {};

// NOTE(review): FieldQuery is defined twice in this span, and Wildcard is
// defined here as well as next to the token rules earlier in this listing.
// This looks like the before/after sides of a change shown together; confirm
// which definitions are meant to remain.
// A Field, `:`, then either a TagList or a NumericRange.
struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<TagList, NumericRange>>> {};
// `[`, the KNN keyword, a number or parameter, a Field, a Param, `]`.
struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<NumberOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
// `[`, the VECTOR_RANGE keyword, a number or parameter, a Param, `]`.
struct VectorRange : seq<one<'['>, WSPad<VectorRangeToken>, WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};

// Matches a single `*` character.
struct Wildcard : one<'*'> {};
// A Field, `:`, then one of VectorRange, TagList or NumericRange (tried in that order).
struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange, TagList, NumericRange>>> {};

struct QueryExpr;

Expand All @@ -64,7 +70,11 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
// An AndExprP followed by one or more `|`-prefixed AndExprP operands.
struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
// Tries a full OrExpr first, then falls back to a lone AndExprP.
struct OrExprP : sor<OrExpr, AndExprP> {};

// NOTE(review): QueryExpr is defined twice in this span (once over OrExprP,
// once over QueryP). This looks like the before/after sides of a change shown
// together; confirm which definition is meant to remain.
// Top-level query rule wrapping OrExprP.
struct QueryExpr : seq<OrExprP> {};
// A BooleanExpr, the `=>` operator, then a KnnSearch.
struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};

// Tries PrefilterExpr first, then falls back to OrExprP.
struct QueryP : sor<PrefilterExpr, OrExprP> {};

// Top-level query rule wrapping QueryP.
struct QueryExpr : seq<QueryP> {};

} // namespace redis_query

Expand Down
64 changes: 45 additions & 19 deletions src/search/redis_query_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ namespace redis_query {
namespace ir = kqir;

// Parse-tree node selector for the query grammar: rules listed under
// store_content keep their matched text in the resulting node, rules listed
// under remove_content produce a node without content.
// NOTE(review): the alias body appears twice below (a shorter rule list, then
// one that also names the vector-related rules) after a single template
// header. This looks like the before/after sides of a change shown together;
// confirm which one is meant to remain.
template <typename Rule>
using TreeSelector =
parse_tree::selector<Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, ExclusiveNumber, FieldQuery, NotExpr,
AndExpr, OrExpr, Wildcard>>;
using TreeSelector = parse_tree::selector<
Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, VectorRange, ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
OrExpr, PrefilterExpr, KnnSearch, Wildcard, VectorRangeToken, KnnToken, ArrowOp>>;

template <typename Input>
StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
Expand All @@ -53,7 +53,31 @@ StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
struct Transformer : ir::TreeTransformer {
// Forwards the parameter map to the TreeTransformer base.
explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) {}

// Looks up the parameter referenced by `node`, decodes its bytes as a packed
// array of doubles via Binary2Vector and wraps the result in a VectorLiteral.
// Errors from the parameter lookup or the decoding are propagated by
// GET_OR_RET; a vector that decodes to zero elements is rejected.
StatusOr<std::unique_ptr<VectorLiteral>> Transform2Vector(const TreeNode& node) {
  std::string raw_bytes = GET_OR_RET(GetParam(node));

  std::vector<double> decoded = GET_OR_RET(Binary2Vector<double>(raw_bytes));
  if (decoded.empty()) {
    return {Status::NotOK, "empty vector is invalid"};
  }

  return std::make_unique<ir::VectorLiteral>(std::move(decoded));
}

auto Transform(const TreeNode& node) -> StatusOr<std::unique_ptr<Node>> {
auto number_or_param = [this](const TreeNode& node) -> StatusOr<std::unique_ptr<NumericLiteral>> {
if (Is<Number>(node)) {
return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
} else if (Is<Param>(node)) {
auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
.Prefixed(fmt::format("parameter {} is not a number", node->string_view())));

return std::make_unique<ir::NumericLiteral>(val);
} else {
return {Status::NotOK,
fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)};
}
};

if (Is<Number>(node)) {
return Node::Create<ir::NumericLiteral>(*ParseFloat(node->string()));
} else if (Is<Wildcard>(node)) {
Expand Down Expand Up @@ -88,26 +112,12 @@ struct Transformer : ir::TreeTransformer {
} else {
return std::make_unique<ir::OrExpr>(std::move(exprs));
}
} else { // NumericRange
} else if (Is<NumericRange>(query)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

const auto& lhs = query->children[0];
const auto& rhs = query->children[1];

auto number_or_param = [this](const TreeNode& node) -> StatusOr<std::unique_ptr<NumericLiteral>> {
if (Is<Number>(node)) {
return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
} else if (Is<Param>(node)) {
auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
.Prefixed(fmt::format("parameter {} is not a number", node->string_view())));

return std::make_unique<ir::NumericLiteral>(val);
} else {
return {Status::NotOK,
fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)};
}
};

if (Is<ExclusiveNumber>(lhs)) {
exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GT,
std::make_unique<FieldRef>(field),
Expand Down Expand Up @@ -141,11 +151,27 @@ struct Transformer : ir::TreeTransformer {
} else {
return std::make_unique<ir::AndExpr>(std::move(exprs));
}
} else if (Is<VectorRange>(query)) {
return std::make_unique<VectorRangeExpr>(std::make_unique<FieldRef>(field),
GET_OR_RET(number_or_param(query->children[1])),
GET_OR_RET(Transform2Vector(query->children[2])));
}
} else if (Is<NotExpr>(node)) {
CHECK(node->children.size() == 1);

return Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
} else if (Is<PrefilterExpr>(node)) {
CHECK(node->children.size() == 3);

// TODO(Beihao): Support Hybrid Query
// const auto& prefilter = node->children[0];
const auto& knn_search = node->children[2];
CHECK(knn_search->children.size() == 4);

return std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(number_or_param(knn_search->children[1])),
GET_OR_RET(Transform2Vector(knn_search->children[3])));

} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

Expand Down
2 changes: 2 additions & 0 deletions src/search/search_encoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata {

// Initializes the base metadata with the VECTOR field type.
HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {}

// This metadata type always reports itself as sortable.
bool IsSortable() const override { return true; }

void Encode(std::string *dst) const override {
IndexFieldMetadata::Encode(dst);
PutFixed8(dst, uint8_t(vector_type));
Expand Down
11 changes: 9 additions & 2 deletions src/search/sql_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ struct NumericAtomExpr : WSPad<sor<NumberOrParam, Identifier>> {};
// Comparison operators. The two-character forms (`!=`, `<=`, `>=`) are listed
// before the single-character ones (`=`, `<`, `>`), so they are tried first.
struct NumericCompareOp : sor<string<'!', '='>, string<'<', '='>, string<'>', '='>, one<'=', '<', '>'>> {};
// Two numeric atoms separated by a comparison operator.
struct NumericCompareExpr : seq<NumericAtomExpr, NumericCompareOp, NumericAtomExpr> {};

// NOTE(review): BooleanAtomExpr is defined twice in this span (without and
// with VectorRangeExpr). This looks like the before/after sides of a change
// shown together; confirm which definition is meant to remain.
struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, WSPad<Boolean>> {};
// The three-character operator `<->`.
struct VectorCompareOp : string<'<', '-', '>'> {};
// `[`, one Number, then zero or more `,`-prefixed Numbers, then `]`.
struct VectorLiteral : seq<WSPad<one<'['>>, Number, star<seq<WSPad<one<','>>>, Number>, WSPad<one<']'>>> {};
// An Identifier, the `<->` operator, then a VectorLiteral.
struct VectorCompareExpr : seq<WSPad<Identifier>, VectorCompareOp, WSPad<VectorLiteral>> {};
// A VectorCompareExpr followed by `<` and a number or parameter.
struct VectorRangeExpr : seq<VectorCompareExpr, one<'<'>, WSPad<NumberOrParam>> {};

// Alternatives are tried in the order listed.
struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, VectorRangeExpr, WSPad<Boolean>> {};

struct QueryExpr;

Expand Down Expand Up @@ -84,7 +89,9 @@ struct Limit : string<'l', 'i', 'm', 'i', 't'> {};

// The Where keyword rule followed by a QueryExpr.
struct WhereClause : seq<Where, QueryExpr> {};
// Either the Asc or the Desc keyword rule.
struct AscOrDesc : sor<Asc, Desc> {};
// NOTE(review): OrderByClause is defined twice in this span (a direct
// identifier form and a form that delegates to OrderByExpr). This looks like
// the before/after sides of a change shown together; confirm which definition
// is meant to remain.
struct OrderByClause : seq<OrderBy, WSPad<Identifier>, opt<WSPad<AscOrDesc>>> {};
// An Identifier with an optional sort direction.
struct SortableFieldExpr : seq<WSPad<Identifier>, opt<AscOrDesc>> {};
// Tries a vector comparison first, then a plain sortable field.
struct OrderByExpr : sor<WSPad<VectorCompareExpr>, WSPad<SortableFieldExpr>> {};
struct OrderByClause : seq<OrderBy, OrderByExpr> {};
// The Limit keyword, an optional `<unsigned integer> ,` prefix, then an unsigned integer.
struct LimitClause : seq<Limit, opt<seq<WSPad<UnsignedInteger>, one<','>>>, WSPad<UnsignedInteger>> {};

struct SearchStmt
Expand Down
Loading
Loading