Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
northdpole committed Aug 17, 2024
1 parent ba4d2e5 commit a287d1b
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 15 deletions.
22 changes: 15 additions & 7 deletions application/prompt_client/spacy_prompt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,26 @@ class SpacyPromptClient:

def __init__(self) -> None:
try:
self.nlp = spacy.load('en_core_web_sm')
self.nlp = spacy.load("en_core_web_sm")
except OSError:
logger.info('Downloading language model for the spaCy POS tagger\n' "(don't worry, this will only happen once)")
logger.info(
"Downloading language model for the spaCy POS tagger\n"
"(don't worry, this will only happen once)"
)
from spacy.cli import download
download('en_core_web_sm')
self.nlp = spacy.load('en_core_web_sm')

download("en_core_web_sm")
self.nlp = spacy.load("en_core_web_sm")

def get_text_embeddings(self, text: str):
return self.nlp(text).vector

def create_chat_completion(self, prompt, closest_object_str) -> str:
raise NotImplementedError("Spacy does not support chat completion you need to set up a different client if you need this functionality")

raise NotImplementedError(
"Spacy does not support chat completion you need to set up a different client if you need this functionality"
)

def query_llm(self, raw_question: str) -> str:
raise NotImplementedError("Spacy does not support chat completion you need to set up a different client if you need this functionality")
raise NotImplementedError(
"Spacy does not support chat completion you need to set up a different client if you need this functionality"
)
5 changes: 4 additions & 1 deletion application/tests/spreadsheet_parsers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,14 @@ def test_suggest_from_export_format(self) -> None:
if len(cres_in_line) == 0:
empty_lines += 1

self.assertGreater(len(input_data)/2,empty_lines) # assert that there was at least some suggestions
self.assertGreater(
len(input_data) / 2, empty_lines
) # assert that there was at least some suggestions

sqla.session.remove()
sqla.drop_all()
self.app_context.pop()


if __name__ == "__main__":
unittest.main()
8 changes: 4 additions & 4 deletions application/tests/web_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,7 @@ def test_get_cre_csv(self) -> None:
data.getvalue(),
response.data.decode(),
)

def test_suggest_from_cre_csv(self) -> None:
# empty string means temporary db
# self.app = create_app(mode="test")
Expand All @@ -954,7 +955,7 @@ def test_suggest_from_cre_csv(self) -> None:
[no_cre_line.pop(key) for key in line.keys() if key.startswith("CRE")]
index += 1
input_data_no_cres.append(no_cre_line)

workspace = tempfile.mkdtemp()
data = {}
with open(os.path.join(workspace, "cre.csv"), "w") as f:
Expand All @@ -963,7 +964,7 @@ def test_suggest_from_cre_csv(self) -> None:
cdw.writerows(input_data_no_cres)

data["cre_csv"] = open(os.path.join(workspace, "cre.csv"), "rb")

with self.app.test_client() as client:
response = client.post(
"/rest/v1/cre_csv/suggest",
Expand All @@ -983,5 +984,4 @@ def test_suggest_from_cre_csv(self) -> None:
]
if len(cres_in_line) == 0:
empty_lines += 1
self.assertGreater(len(input_data_no_cres)/2,empty_lines)

self.assertGreater(len(input_data_no_cres) / 2, empty_lines)
3 changes: 2 additions & 1 deletion application/utils/spreadsheet_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ def suggest_from_export_format(
if any(
[
entry.startswith("CRE ")
for entry,value in line.items() if not is_empty(value)
for entry, value in line.items()
if not is_empty(value)
]
): # we found a mapping in the line, no need to do anything, flush to buffer
output.append(line)
Expand Down
7 changes: 5 additions & 2 deletions application/web/web_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ def import_from_cre_csv() -> Any:
}
)


@app.route("/rest/v1/cre_csv/suggest", methods=["POST"])
def suggest_from_cre_csv() -> Any:
"""Given a csv file that follows the CRE import format but has missing fields, this function will return a csv file with the missing fields filled in with suggestions.
Expand All @@ -769,12 +770,14 @@ def suggest_from_cre_csv() -> Any:
"""
database = db.Node_collection()
file = request.files.get("cre_csv")

if file is None:
abort(400, "No file provided")
contents = file.read()
csv_read = csv.DictReader(contents.decode("utf-8").splitlines())
response = spreadsheet_parsers.suggest_from_export_format(list(csv_read),database=database)
response = spreadsheet_parsers.suggest_from_export_format(
list(csv_read), database=database
)
csvVal = write_csv(docs=response).getvalue().encode("utf-8")

# Creating the byteIO object from the StringIO Object
Expand Down

0 comments on commit a287d1b

Please sign in to comment.