Skip to content

Commit

Permalink
implement cre suggestions from OpenCRE import sheet
Browse files Browse the repository at this point in the history
  • Loading branch information
northdpole committed Aug 17, 2024
1 parent 73f5ea4 commit ba4d2e5
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 3 deletions.
4 changes: 3 additions & 1 deletion application/prompt_client/prompt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@


def is_valid_url(url):
    """Return True if *url* is a non-empty string beginning with http:// or https://.

    Falsy values (None, "") are rejected explicitly so callers can pass
    optional hyperlink fields straight through.
    """
    if not url:
        return False
    # str.startswith accepts a tuple of prefixes -- one call instead of two.
    return url.startswith(("http://", "https://"))


Expand Down Expand Up @@ -154,7 +156,7 @@ def generate_embeddings(
logger.info(f"generating {len(missing_embeddings)} embeddings")
for id in missing_embeddings:
cre = database.get_cre_by_db_id(id)
node = database.get_nodes(db_id=id)
node = database.get_nodes(db_id=id)[0]
content = ""
if node:
if is_valid_url(node.hyperlink):
Expand Down
41 changes: 41 additions & 0 deletions application/tests/spreadsheet_parsers_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
from pprint import pprint
import unittest
from application.database import db
from application.tests.utils import data_gen
from application.defs import cre_defs as defs
from application import create_app, sqla # type: ignore
from application.utils.spreadsheet_parsers import (
parse_export_format,
parse_hierarchical_export_format,
suggest_from_export_format,
)


Expand Down Expand Up @@ -37,6 +40,44 @@ def test_parse_hierarchical_export_format(self) -> None:
for element in v:
self.assertIn(element, output[k])

def test_suggest_from_export_format(self) -> None:
    """suggest_from_export_format should fill in CRE mappings for rows that lack them.

    Strips the CRE columns from every other generated input row, runs the
    suggester against a collection pre-populated with the expected CREs, and
    asserts that fewer than half of the stripped rows remain unmapped.
    """
    self.app = create_app(mode="test")
    self.app_context = self.app.app_context()
    self.app_context.push()
    sqla.create_all()
    collection = db.Node_collection()

    input_data, expected_output = data_gen.export_format_data()
    for cre in expected_output[defs.Credoctypes.CRE.value]:
        collection.add_cre(cre=cre)

    # Remove the CRE columns from every other row so the suggester has
    # something to fill in.
    input_data_no_cres = []
    for index, line in enumerate(input_data):
        no_cre_line = line.copy()
        if index % 2 == 0:
            # Plain loop, not a comprehension: we call pop() for its side effect.
            for key in line.keys():
                if key.startswith("CRE"):
                    no_cre_line.pop(key)
        input_data_no_cres.append(no_cre_line)

    output = suggest_from_export_format(
        lfile=input_data_no_cres, database=collection
    )
    self.maxDiff = None

    empty_lines = 0
    for line in output:
        cres_in_line = [
            line[c] for c in line.keys() if c.startswith("CRE") and line[c]
        ]
        if not cres_in_line:
            empty_lines += 1

    # Assert that at least some of the stripped rows received a suggestion.
    self.assertGreater(len(input_data) / 2, empty_lines)

    sqla.session.remove()
    sqla.drop_all()
    self.app_context.pop()

# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
    unittest.main()
54 changes: 54 additions & 0 deletions application/tests/web_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,3 +931,57 @@ def test_get_cre_csv(self) -> None:
data.getvalue(),
response.data.decode(),
)
def test_suggest_from_cre_csv(self) -> None:
    """POST a CRE-import-format csv with missing CRE columns to
    /rest/v1/cre_csv/suggest and verify the endpoint fills in suggestions
    for most of the stripped rows.
    """
    collection = db.Node_collection()

    input_data, expected_output = data_gen.export_format_data()
    for cre in expected_output[defs.Credoctypes.CRE.value]:
        collection.add_cre(cre=cre)

    # Strip the CRE columns from every other row; collect the union of all
    # column names so the csv header covers every field seen in any row.
    input_data_no_cres = []
    keys = {}
    for index, line in enumerate(input_data):
        keys.update(line)
        no_cre_line = line.copy()
        if index % 2 == 0:
            # Plain loop, not a comprehension: we call pop() for its side effect.
            for key in line.keys():
                if key.startswith("CRE"):
                    no_cre_line.pop(key)
        input_data_no_cres.append(no_cre_line)

    workspace = tempfile.mkdtemp()
    csv_path = os.path.join(workspace, "cre.csv")
    with open(csv_path, "w") as f:
        cdw = csv.DictWriter(f, fieldnames=keys.keys())
        cdw.writeheader()
        cdw.writerows(input_data_no_cres)

    data = {"cre_csv": open(csv_path, "rb")}

    with self.app.test_client() as client:
        response = client.post(
            "/rest/v1/cre_csv/suggest",
            data=data,
            buffered=True,
            content_type="multipart/form-data",
        )
        self.assertEqual(200, response.status_code)

        # NOTE(review): the endpoint responds with a csv attachment; parsing
        # the body with json.loads may need to become csv parsing -- confirm.
        empty_lines = 0
        for line in json.loads(response.data.decode()):
            cres_in_line = [
                line[c] for c in line.keys() if c.startswith("CRE") and line[c]
            ]
            if not cres_in_line:
                empty_lines += 1
        # More than half of the rows should have received a CRE suggestion.
        self.assertGreater(len(input_data_no_cres) / 2, empty_lines)

60 changes: 59 additions & 1 deletion application/utils/spreadsheet_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from copy import copy
from typing import Any, Dict, List, Optional
from dataclasses import dataclass

from application.prompt_client import prompt_client
from application.defs import cre_defs as defs
from application.database import db

# collection of methods to parse different versions of spreadsheet standards
# each method returns a list of cre_defs documents
Expand Down Expand Up @@ -567,3 +568,60 @@ def parse_standards(
)
)
return links


def suggest_from_export_format(
    lfile: List[Dict[str, Any]], database: db.Node_collection
) -> List[Dict[str, Any]]:
    """Fill in CRE suggestions for spreadsheet rows that lack a CRE mapping.

    For every row without a populated "CRE *" column, build a Standard
    document from the row's populated columns, find the semantically closest
    CRE via embeddings, and record it in the row under "CRE 0".  Rows that
    already contain a CRE mapping, or that yield no usable standard, are
    passed through unchanged.

    Args:
        lfile: parsed spreadsheet rows (column name -> cell value).
        database: node collection used for embeddings and CRE lookups.

    Returns:
        The rows, each augmented with a "CRE 0" suggestion where one was found.
    """
    # The prompt handler is loop-invariant -- construct it once, not per row.
    ph = prompt_client.PromptHandler(database=database, load_all_embeddings=False)
    output: List[Dict[str, Any]] = []
    for line in lfile:
        if any(
            entry.startswith("CRE ")
            for entry, value in line.items()
            if not is_empty(value)
        ):
            # Row already has a mapping -- flush it to the buffer unchanged.
            # `continue`, not `break`: the remaining rows still need processing.
            output.append(line)
            continue

        standard: Optional[defs.Node] = None
        for entry, value in line.items():
            if entry.startswith("CRE "):
                continue  # established above there are no CRE entries in this row

            if not is_empty(value):
                standard_name = entry.split("|")[0]
                standard = defs.Standard(
                    name=standard_name,
                    sectionID=line.get(
                        f"{standard_name}{defs.ExportFormat.separator}{defs.ExportFormat.id}"
                    ),
                    section=line.get(
                        f"{standard_name}{defs.ExportFormat.separator}{defs.ExportFormat.section}"
                    ),
                    hyperlink=line.get(
                        f"{standard_name}{defs.ExportFormat.separator}{defs.ExportFormat.hyperlink}"
                    ),
                    description=line.get(
                        f"{standard_name}{defs.ExportFormat.separator}{defs.ExportFormat.description}"
                    ),
                )

        if standard is None:
            # Every non-CRE cell was empty; nothing to base a suggestion on.
            output.append(line)
            continue

        # Find the nearest CRE for the standard extracted from this row.
        most_similar_id, _ = ph.get_id_of_most_similar_cre_paginated(
            item_embedding=ph.generate_embeddings_for_document(standard)
        )
        if not most_similar_id:
            logger.warning(f"Could not find a CRE for {standard.id}")
            output.append(line)
            continue

        cre = database.get_cre_by_db_id(most_similar_id)
        if not cre:
            logger.warning(f"Could not find a CRE for {standard.id}")
            output.append(line)
            continue

        # Attach the suggestion to the row.
        line["CRE 0"] = f"{cre.id}{defs.ExportFormat.separator}{cre.name}"
        output.append(line)
    return output
30 changes: 30 additions & 0 deletions application/web/web_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def get_cre_csv() -> Any:


@app.route("/rest/v1/cre_csv_import", methods=["POST"])
@app.route("/rest/v1/cre_csv/import", methods=["POST"])
def import_from_cre_csv() -> Any:
if not os.environ.get("CRE_ALLOW_IMPORT"):
abort(
Expand Down Expand Up @@ -759,6 +760,35 @@ def import_from_cre_csv() -> Any:
}
)

@app.route("/rest/v1/cre_csv/suggest", methods=["POST"])
def suggest_from_cre_csv() -> Any:
    """Given a csv file that follows the CRE import format but has missing
    fields, return a csv file with the missing fields filled in with
    suggestions.

    Returns:
        Any: the csv file with the missing fields filled in with suggestions
    """
    database = db.Node_collection()
    uploaded = request.files.get("cre_csv")
    if uploaded is None:
        abort(400, "No file provided")

    rows = list(csv.DictReader(uploaded.read().decode("utf-8").splitlines()))
    suggested = spreadsheet_parsers.suggest_from_export_format(
        rows, database=database
    )
    csv_bytes = write_csv(docs=suggested).getvalue().encode("utf-8")

    # send_file needs a binary stream, so wrap the encoded csv in BytesIO.
    mem = io.BytesIO(csv_bytes)
    mem.seek(0)

    return send_file(
        mem,
        as_attachment=True,
        download_name="CRE-Catalogue.csv",
        mimetype="text/csv",
    )


# /End Importing Handlers

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,5 @@ urllib3
vertexai
xmltodict
google-cloud-trace
alive-progress
alive-progress
spacy

0 comments on commit ba4d2e5

Please sign in to comment.