Skip to content

Commit

Permalink
Implement rudimentary ngram-based search with item updates and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Nov 10, 2024
1 parent 6531370 commit dbf241f
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 28 deletions.
1 change: 1 addition & 0 deletions pydatalab/src/pydatalab/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def create_app(
extension.init_app(app)

pydatalab.mongo.create_default_indices()
pydatalab.mongo.create_ngram_item_index()

if CONFIG.FILE_DIRECTORY is not None:
pathlib.Path(CONFIG.FILE_DIRECTORY).mkdir(parents=False, exist_ok=True)
Expand Down
34 changes: 18 additions & 16 deletions pydatalab/src/pydatalab/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"flask_mongo",
"check_mongo_connection",
"create_default_indices",
"create_ngram_item_index",
"_get_active_mongo_client",
"insert_pydantic_model_fork_safe",
"ITEMS_FTS_FIELDS",
Expand Down Expand Up @@ -204,27 +205,20 @@ def create_ngram_item_index(
):
from bson import ObjectId

from pydatalab.models import ITEM_MODELS

if client is None:
client = _get_active_mongo_client()
db = client.get_database()

item_fts_fields = set()
for model in ITEM_MODELS:
schema = ITEM_MODELS[model].schema()
for f in schema["properties"]:
if schema["properties"][f].get("type") == "string":
item_fts_fields.add(f)

# construct manual ngram index
ngram_index: dict[ObjectId, set[str]] = {}
type_index: dict[ObjectId, str] = {}
item_count: int = 0
global_ngram_count: dict[str, int] = collections.defaultdict(int)
for item in db.items.find({}):
item_count += 1
ngrams: dict[str, int] = _generate_item_ngrams(item, item_fts_fields)
ngrams: dict[str, int] = _generate_item_ngrams(item, ITEM_FTS_FIELDS)
ngram_index[item["_id"]] = set(ngrams)
type_index[item["_id"]] = item["type"]
for g in ngrams:
global_ngram_count[g] += ngrams[g]

Expand All @@ -235,8 +229,12 @@ def create_ngram_item_index(
# for item in ngram_index:
# ngram_index[item].pop(ngram)

for _id, item in ngram_index.items():
db.items_fts.update_one({"_id": _id}, {"$set": {"_fts_ngrams": item}})
for _id, _ngrams in ngram_index.items():
db.items_fts.update_one(
{"_id": _id},
{"$set": {"type": type_index[_id], "_fts_ngrams": list(_ngrams)}},
upsert=True,
)

try:
result = db.items_fts.create_index(
Expand All @@ -260,7 +258,7 @@ def _generate_ngrams(value: str, n: int = 3) -> dict[str, int]:

ngrams: dict[str, int] = collections.defaultdict(int)

if len(value) < n:
if not value or len(value) < n:
return ngrams

# first, tokenize by whitespace and punctuation (a la normal mongodb fts)
Expand All @@ -279,8 +277,12 @@ def _generate_ngrams(value: str, n: int = 3) -> dict[str, int]:
def _generate_item_ngrams(item: dict, fts_fields: set[str], n: int = 3):
ngrams: dict[str, int] = collections.defaultdict(int)
for field in fts_fields:
field_ngrams = _generate_ngrams(item.get(field, None))
for k in field_ngrams:
ngrams[k] += field_ngrams[k]
value = item.get(field, None)
if value:
if field == "refcode" and ":" in value:
value = value.split(":")[1]
field_ngrams = _generate_ngrams(value)
for k in field_ngrams:
ngrams[k] += field_ngrams[k]

return ngrams
53 changes: 42 additions & 11 deletions pydatalab/src/pydatalab/routes/v0_1/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pydatalab.models.items import Item
from pydatalab.models.relationships import RelationshipType
from pydatalab.models.utils import generate_unique_refcode
from pydatalab.mongo import flask_mongo
from pydatalab.mongo import ITEM_FTS_FIELDS, _generate_item_ngrams, flask_mongo
from pydatalab.permissions import PUBLIC_USER_ID, active_users_or_get_only, get_default_permissions

ITEMS = Blueprint("items", __name__)
Expand Down Expand Up @@ -306,23 +306,33 @@ def search_items_ngram():
types = types.split(",")

# split search string into trigrams
query = query.lower()
if len(query) < 3:
trigrams = [query]
trigrams = [query[i:i+3] for i in range(len(query)-2)]
trigrams = [query[i : i + 3] for i in range(len(query) - 2)]

match_obj = {
"_fts_trigrams": {"$in": trigrams},
"_fts_ngrams": {"$in": trigrams},
**get_default_permissions(user_only=False),
}

if types is not None:
match_obj["type"] = {"$in": types}

cursor = flask_mongo.db.items.aggregate(
cursor = flask_mongo.db.items_fts.aggregate(
[
{"$match": match_obj},
{"$sort": {"score": {"$meta": "textScore"}}},
{"$limit": nresults},
{
"$lookup": {
"from": "items",
"localField": "_id",
"foreignField": "_id",
"as": "items",
}
},
{"$unwind": "$items"},
{"$replaceRoot": {"newRoot": {"$mergeObjects": ["$items"]}}},
{
"$project": {
"_id": 0,
Expand Down Expand Up @@ -567,6 +577,16 @@ def _create_sample(
400,
)

# Update ngram index, if configured
ngrams = _generate_item_ngrams(
flask_mongo.db.items.find_one(result.inserted_id), ITEM_FTS_FIELDS
)
flask_mongo.db.items_fts.update_one(
{"_id": result.inserted_id},
{"$set": {"type": data_model.type, "_fts_ngrams": list(ngrams)}},
upsert=True,
)

sample_list_entry = {
"refcode": data_model.refcode,
"item_id": data_model.item_id,
Expand Down Expand Up @@ -664,11 +684,11 @@ def delete_sample():
request_json = request.get_json() # noqa: F821 pylint: disable=undefined-variable
item_id = request_json["item_id"]

result = flask_mongo.db.items.delete_one(
{"item_id": item_id, **get_default_permissions(user_only=True)}
deleted_doc = flask_mongo.db.items.find_one_and_delete(
{"item_id": item_id, **get_default_permissions(user_only=True)}, projection={"_id": 1}
)

if result.deleted_count != 1:
if deleted_doc is None:
return (
jsonify(
{
Expand All @@ -678,6 +698,10 @@ def delete_sample():
),
401,
)

# Update ngram index, if configured
flask_mongo.db.items_fts.delete_one({"_id": deleted_doc["_id"]})

return (
jsonify(
{
Expand Down Expand Up @@ -926,21 +950,28 @@ def save_item():
item.pop("collections")
item.pop("creators")

result = flask_mongo.db.items.update_one(
updated_doc = flask_mongo.db.items.find_one_and_update(
{"item_id": item_id},
{"$set": item},
)

if result.matched_count != 1:
if updated_doc is None:
return (
jsonify(
status="error",
message=f"{item_id} item update failed. no subdocument matched",
output=result.raw_result,
),
400,
)

# Update ngram index, if configured
ngrams = _generate_item_ngrams(updated_doc, ITEM_FTS_FIELDS)
flask_mongo.db.items_fts.update_one(
{"_id": updated_doc["_id"]},
{"$set": {"type": updated_doc["type"], "_fts_ngrams": list(ngrams)}},
upsert=True,
)

return jsonify(status="success", last_modified=updated_data["last_modified"]), 200


Expand Down
56 changes: 55 additions & 1 deletion pydatalab/tests/server/test_ngram_fts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from pydatalab.mongo import _generate_item_ngrams, _generate_ngrams
from pydatalab.mongo import _generate_item_ngrams, _generate_ngrams, create_ngram_item_index


def test_ngram_single_field():
Expand Down Expand Up @@ -49,3 +49,57 @@ def test_ngram_single_field():
def test_ngram_item():
item = {"refcode": "ABCDEF"}
assert _generate_item_ngrams(item, {"refcode"}, n=3) == {"abc": 1, "bcd": 1, "cde": 1, "def": 1}


def test_ngram_fts_route(client, default_sample_dict, real_mongo_client, database):
default_sample_dict["item_id"] = "ABCDEF"
response = client.post("/new-sample/", json=default_sample_dict)
assert response.status_code == 201

# Check that creating the ngram index with existing items works
create_ngram_item_index(real_mongo_client, background=False, filter_top_ngrams=None)

doc = database.items_fts.find_one({})
ngrams = set(doc["_fts_ngrams"])
for ng in ["abc", "bcd", "cde", "def", "sam", "ple"]:
assert ng in ngrams
assert doc["type"] == "samples"

query_strings = ("ABC", "ABCDEF", "abcd", "cdef")

for q in query_strings:
response = client.get(f"/search-items-ngram/?query={q}&types=samples")
assert response.status_code == 200
assert response.json["status"] == "success"
assert len(response.json["items"]) == 1
assert response.json["items"][0]["item_id"] == "ABCDEF"

# Check that new items are added to the ngram index
default_sample_dict["item_id"] = "ABCDEF2"
response = client.post("/new-sample/", json=default_sample_dict)
assert response.status_code == 201

for q in query_strings:
response = client.get(f"/search-items-ngram/?query={q}&types=samples")
assert response.status_code == 200
assert response.json["status"] == "success"
assert len(response.json["items"]) == 2
assert response.json["items"][0]["item_id"] == "ABCDEF"
assert response.json["items"][1]["item_id"] == "ABCDEF2"

# Check that updates are reflected in the ngram index
# This test also makes sure that the string 'test' is not picked up from the refcode,
# which has an explicit carve out
default_sample_dict["description"] = "test string with punctuation"
update_req = {"item_id": "ABCDEF2", "data": default_sample_dict}
response = client.post("/save-item/", json=update_req)
assert response.status_code == 200

query_strings = ("test", "punctuation")

for q in query_strings:
response = client.get(f"/search-items-ngram/?query={q}&types=samples")
assert response.status_code == 200
assert response.json["status"] == "success"
assert len(response.json["items"]) == 1
assert response.json["items"][0]["item_id"] == "ABCDEF2"

0 comments on commit dbf241f

Please sign in to comment.