From dbf241f7542a4a4db9b3a5ec2eeb754c539675c2 Mon Sep 17 00:00:00 2001 From: Matthew Evans Date: Thu, 7 Nov 2024 10:33:56 +0000 Subject: [PATCH] Implement rudimentary ngram-based search with item updates and add tests --- pydatalab/src/pydatalab/main.py | 1 + pydatalab/src/pydatalab/mongo.py | 34 ++++++------ pydatalab/src/pydatalab/routes/v0_1/items.py | 53 ++++++++++++++---- pydatalab/tests/server/test_ngram_fts.py | 56 +++++++++++++++++++- 4 files changed, 116 insertions(+), 28 deletions(-) diff --git a/pydatalab/src/pydatalab/main.py b/pydatalab/src/pydatalab/main.py index 0506953a3..17b83e2a2 100644 --- a/pydatalab/src/pydatalab/main.py +++ b/pydatalab/src/pydatalab/main.py @@ -206,6 +206,7 @@ def create_app( extension.init_app(app) pydatalab.mongo.create_default_indices() + pydatalab.mongo.create_ngram_item_index() if CONFIG.FILE_DIRECTORY is not None: pathlib.Path(CONFIG.FILE_DIRECTORY).mkdir(parents=False, exist_ok=True) diff --git a/pydatalab/src/pydatalab/mongo.py b/pydatalab/src/pydatalab/mongo.py index f7c58e87a..88b54c881 100644 --- a/pydatalab/src/pydatalab/mongo.py +++ b/pydatalab/src/pydatalab/mongo.py @@ -13,6 +13,7 @@ "flask_mongo", "check_mongo_connection", "create_default_indices", + "create_ngram_item_index", "_get_active_mongo_client", "insert_pydantic_model_fork_safe", "ITEMS_FTS_FIELDS", @@ -204,27 +205,20 @@ def create_ngram_item_index( ): from bson import ObjectId - from pydatalab.models import ITEM_MODELS - if client is None: client = _get_active_mongo_client() db = client.get_database() - item_fts_fields = set() - for model in ITEM_MODELS: - schema = ITEM_MODELS[model].schema() - for f in schema["properties"]: - if schema["properties"][f].get("type") == "string": - item_fts_fields.add(f) - # construct manual ngram index ngram_index: dict[ObjectId, set[str]] = {} + type_index: dict[ObjectId, str] = {} item_count: int = 0 global_ngram_count: dict[str, int] = collections.defaultdict(int) for item in db.items.find({}): item_count += 1 - ngrams: dict[str, int] = _generate_item_ngrams(item, item_fts_fields) + ngrams: dict[str, int] = _generate_item_ngrams(item, ITEM_FTS_FIELDS) ngram_index[item["_id"]] = set(ngrams) + type_index[item["_id"]] = item["type"] for g in ngrams: global_ngram_count[g] += ngrams[g] @@ -235,8 +229,12 @@ def create_ngram_item_index( # for item in ngram_index: # ngram_index[item].pop(ngram) - for _id, item in ngram_index.items(): - db.items_fts.update_one({"_id": _id}, {"$set": {"_fts_ngrams": item}}) + for _id, _ngrams in ngram_index.items(): + db.items_fts.update_one( + {"_id": _id}, + {"$set": {"type": type_index[_id], "_fts_ngrams": list(_ngrams)}}, + upsert=True, + ) try: result = db.items_fts.create_index( @@ -260,7 +258,7 @@ def _generate_ngrams(value: str, n: int = 3) -> dict[str, int]: ngrams: dict[str, int] = collections.defaultdict(int) - if len(value) < n: + if not value or len(value) < n: return ngrams # first, tokenize by whitespace and punctuation (a la normal mongodb fts) @@ -279,8 +277,12 @@ def _generate_ngrams(value: str, n: int = 3) -> dict[str, int]: def _generate_item_ngrams(item: dict, fts_fields: set[str], n: int = 3): ngrams: dict[str, int] = collections.defaultdict(int) for field in fts_fields: - field_ngrams = _generate_ngrams(item.get(field, None)) - for k in field_ngrams: - ngrams[k] += field_ngrams[k] + value = item.get(field, None) + if value: + if field == "refcode" and ":" in value: + value = value.split(":")[1] + field_ngrams = _generate_ngrams(value) + for k in field_ngrams: + ngrams[k] += field_ngrams[k] return ngrams diff --git a/pydatalab/src/pydatalab/routes/v0_1/items.py b/pydatalab/src/pydatalab/routes/v0_1/items.py index 74f0738fe..915de7b2c 100644 --- a/pydatalab/src/pydatalab/routes/v0_1/items.py +++ b/pydatalab/src/pydatalab/routes/v0_1/items.py @@ -15,7 +15,7 @@ from pydatalab.models.items import Item from pydatalab.models.relationships import RelationshipType from pydatalab.models.utils import generate_unique_refcode -from pydatalab.mongo import flask_mongo +from pydatalab.mongo import ITEM_FTS_FIELDS, _generate_item_ngrams, flask_mongo from pydatalab.permissions import PUBLIC_USER_ID, active_users_or_get_only, get_default_permissions ITEMS = Blueprint("items", __name__) @@ -306,23 +306,33 @@ def search_items_ngram(): types = types.split(",") # split search string into trigrams + query = query.lower() if len(query) < 3: trigrams = [query] - trigrams = [query[i:i+3] for i in range(len(query)-2)] + trigrams = [query[i : i + 3] for i in range(len(query) - 2)] match_obj = { - "_fts_trigrams": {"$in": trigrams}, + "_fts_ngrams": {"$in": trigrams}, **get_default_permissions(user_only=False), } if types is not None: match_obj["type"] = {"$in": types} - cursor = flask_mongo.db.items.aggregate( + cursor = flask_mongo.db.items_fts.aggregate( [ {"$match": match_obj}, - {"$sort": {"score": {"$meta": "textScore"}}}, {"$limit": nresults}, + { + "$lookup": { + "from": "items", + "localField": "_id", + "foreignField": "_id", + "as": "items", + } + }, + {"$unwind": "$items"}, + {"$replaceRoot": {"newRoot": {"$mergeObjects": ["$items"]}}}, { "$project": { "_id": 0, @@ -567,6 +577,16 @@ def _create_sample( 400, ) + # Update ngram index, if configured + ngrams = _generate_item_ngrams( + flask_mongo.db.items.find_one(result.inserted_id), ITEM_FTS_FIELDS + ) + flask_mongo.db.items_fts.update_one( + {"_id": result.inserted_id}, + {"$set": {"type": data_model.type, "_fts_ngrams": list(ngrams)}}, + upsert=True, + ) + sample_list_entry = { "refcode": data_model.refcode, "item_id": data_model.item_id, @@ -664,11 +684,11 @@ def delete_sample(): request_json = request.get_json() # noqa: F821 pylint: disable=undefined-variable item_id = request_json["item_id"] - result = flask_mongo.db.items.delete_one( - {"item_id": item_id, **get_default_permissions(user_only=True)} + deleted_doc = flask_mongo.db.items.find_one_and_delete( + {"item_id": item_id, **get_default_permissions(user_only=True)}, projection={"_id": 1} ) - if result.deleted_count != 1: + if deleted_doc is None: return ( jsonify( { @@ -678,6 +698,10 @@ def delete_sample(): ), 401, ) + + # Update ngram index, if configured + flask_mongo.db.items_fts.delete_one({"_id": deleted_doc["_id"]}) + return ( jsonify( { @@ -926,21 +950,28 @@ def save_item(): item.pop("collections") item.pop("creators") - result = flask_mongo.db.items.update_one( + updated_doc = flask_mongo.db.items.find_one_and_update( {"item_id": item_id}, {"$set": item}, ) - if result.matched_count != 1: + if updated_doc is None: return ( jsonify( status="error", message=f"{item_id} item update failed. no subdocument matched", - output=result.raw_result, ), 400, ) + # Update ngram index, if configured + ngrams = _generate_item_ngrams(updated_doc, ITEM_FTS_FIELDS) + flask_mongo.db.items_fts.update_one( + {"_id": updated_doc["_id"]}, + {"$set": {"type": updated_doc["type"], "_fts_ngrams": list(ngrams)}}, + upsert=True, + ) + return jsonify(status="success", last_modified=updated_data["last_modified"]), 200 diff --git a/pydatalab/tests/server/test_ngram_fts.py b/pydatalab/tests/server/test_ngram_fts.py index 65e47604f..f1a3e9c63 100644 --- a/pydatalab/tests/server/test_ngram_fts.py +++ b/pydatalab/tests/server/test_ngram_fts.py @@ -3,7 +3,7 @@ """ -from pydatalab.mongo import _generate_item_ngrams, _generate_ngrams +from pydatalab.mongo import _generate_item_ngrams, _generate_ngrams, create_ngram_item_index def test_ngram_single_field(): @@ -49,3 +49,57 @@ def test_ngram_single_field(): def test_ngram_item(): item = {"refcode": "ABCDEF"} assert _generate_item_ngrams(item, {"refcode"}, n=3) == {"abc": 1, "bcd": 1, "cde": 1, "def": 1} + + +def test_ngram_fts_route(client, default_sample_dict, real_mongo_client, database): + default_sample_dict["item_id"] = "ABCDEF" + response = client.post("/new-sample/", json=default_sample_dict) + assert response.status_code == 201 + + # Check that creating the ngram index with existing items works + create_ngram_item_index(real_mongo_client, background=False, filter_top_ngrams=None) + + doc = database.items_fts.find_one({}) + ngrams = set(doc["_fts_ngrams"]) + for ng in ["abc", "bcd", "cde", "def", "sam", "ple"]: + assert ng in ngrams + assert doc["type"] == "samples" + + query_strings = ("ABC", "ABCDEF", "abcd", "cdef") + + for q in query_strings: + response = client.get(f"/search-items-ngram/?query={q}&types=samples") + assert response.status_code == 200 + assert response.json["status"] == "success" + assert len(response.json["items"]) == 1 + assert response.json["items"][0]["item_id"] == "ABCDEF" + + # Check that new items are added to the ngram index + default_sample_dict["item_id"] = "ABCDEF2" + response = client.post("/new-sample/", json=default_sample_dict) + assert response.status_code == 201 + + for q in query_strings: + response = client.get(f"/search-items-ngram/?query={q}&types=samples") + assert response.status_code == 200 + assert response.json["status"] == "success" + assert len(response.json["items"]) == 2 + assert response.json["items"][0]["item_id"] == "ABCDEF" + assert response.json["items"][1]["item_id"] == "ABCDEF2" + + # Check that updates are reflected in the ngram index + # This test also makes sure that the string 'test' is not picked up from the refcode, + # which has an explicit carve out + default_sample_dict["description"] = "test string with punctuation" + update_req = {"item_id": "ABCDEF2", "data": default_sample_dict} + response = client.post("/save-item/", json=update_req) + assert response.status_code == 200 + + query_strings = ("test", "punctuation") + + for q in query_strings: + response = client.get(f"/search-items-ngram/?query={q}&types=samples") + assert response.status_code == 200 + assert response.json["status"] == "success" + assert len(response.json["items"]) == 1 + assert response.json["items"][0]["item_id"] == "ABCDEF2"