Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace all mentions of form for inspection #15

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion end_to_end_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def process_llm(ocr_engines: Dict[str, OCR], llm_models: Dict[str, GPT]) -> None
with open(ocr_output_path, "r") as file:
ocr_content = file.read()

llm_result = llm_model.generate_form(ocr_content)
llm_result = llm_model.create_inspection(ocr_content)
save_text_to_file(llm_result, llm_output_path)

logger.info(f"Processed LLM for {ocr_output_file} with {llm_name}")
Expand Down
18 changes: 9 additions & 9 deletions pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .label import LabelStorage # noqa: F401
from .ocr import OCR # noqa: F401
from .form import FertiliserForm # noqa: F401
from .inspection import FertilizerInspection # noqa: F401
from .gpt import GPT # noqa: F401

import os
Expand All @@ -21,7 +21,7 @@ def save_image_to_file(image_bytes: bytes, output_path: str): # pragma: no cover
with open(output_path, 'wb') as output_file:
output_file.write(image_bytes)

def analyze(label_storage: LabelStorage, ocr: OCR, gpt: GPT, log_dir_path: str = './logs') -> FertiliserForm:
def analyze(label_storage: LabelStorage, ocr: OCR, gpt: GPT, log_dir_path: str = './logs') -> FertilizerInspection:
"""
Analyze a fertiliser label using an OCR and an LLM.
It returns the data extracted from the label in a FertiliserForm.
Expand All @@ -37,18 +37,18 @@ def analyze(label_storage: LabelStorage, ocr: OCR, gpt: GPT, log_dir_path: str =
now = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
save_text_to_file(result.content, f"{log_dir_path}/{now}.md")

# Generate form from extracted text
prediction = gpt.generate_form(result.content)
# Generate inspection from extracted text
prediction = gpt.create_inspection(result.content)

# Logs the results from GPT
save_text_to_file(prediction.form, f"{log_dir_path}/{now}.json")
save_text_to_file(prediction.inspection, f"{log_dir_path}/{now}.json")
save_text_to_file(prediction.rationale, f"{log_dir_path}/{now}.txt")

# Load a JSON from the text
raw_json = json.loads(prediction.form)
raw_json = json.loads(prediction.inspection)

# Check the conformity of the JSON
form = FertiliserForm(**raw_json)
# Check the coninspectionity of the JSON
inspection = FertilizerInspection(**raw_json)

# Clear the label cache
label_storage.clear()
Expand All @@ -58,4 +58,4 @@ def analyze(label_storage: LabelStorage, ocr: OCR, gpt: GPT, log_dir_path: str =
os.remove(f"{log_dir_path}/{now}.txt")
os.remove(f"{log_dir_path}/{now}.json")

return form
return inspection
7 changes: 3 additions & 4 deletions pipeline/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from dspy import Prediction
import dspy.adapters
import dspy.utils
from openai.types.chat.completion_create_params import ResponseFormat

MODELS_WITH_RESPONSE_FORMAT = [
"ailab-llm",
Expand Down Expand Up @@ -53,12 +52,12 @@ class ProduceLabelForm(dspy.Signature):
"""
You are a fertilizer label inspector working for the Canadian Food Inspection Agency.
Your task is to classify all information present in the provided text using the specified keys.
Your response should be accurate, intelligible, formatted in JSON, and contain all the text from the provided text.
Your response should be accurate, intelligible, information in JSON, and contain all the text from the provided text.
"""

text = dspy.InputField(desc="The text of the fertilizer label extracted using OCR.")
specification = dspy.InputField(desc="The specification containing the fields to highlight and their requirements.")
form = dspy.OutputField(desc="Only a complete JSON.")
inspection = dspy.OutputField(desc="Only a complete JSON.")

class GPT:
def __init__(self, api_endpoint, api_key, deployment_id):
Expand Down Expand Up @@ -87,7 +86,7 @@ def __init__(self, api_endpoint, api_key, deployment_id):
response_format=response_format,
)

def generate_form(self, prompt) -> Prediction:
def create_inspection(self, prompt) -> Prediction:
with dspy.context(lm=self.dspy_client, experimental=True):
signature = dspy.ChainOfThought(ProduceLabelForm)
prediction = signature(specification=SPECIFICATION, text=prompt)
Expand Down
4 changes: 2 additions & 2 deletions pipeline/form.py → pipeline/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def convert_specification_values(cls, v):
return str(v)
return v

class FertiliserForm(BaseModel):
class FertilizerInspection(BaseModel):
company_name: Optional[str] = None
company_address: Optional[str] = None
company_website: Optional[str] = None
Expand Down Expand Up @@ -86,7 +86,7 @@ def validate_npk(cls, v):
if v is not None:
pattern = re.compile(r'^(\d+(\.\d+)?-\d+(\.\d+)?-\d+(\.\d+)?)?$')
if not pattern.match(v):
raise npkError('npk must be in the format "number-number-number"')
raise npkError('npk must be in the inspectionat "number-number-number"')
return v

@model_validator(mode='before')
Expand Down
18 changes: 9 additions & 9 deletions tests/test_form.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from pydantic import ValidationError
from pipeline import FertiliserForm
from pipeline import FertilizerInspection

class TestFertiliserForm(unittest.TestCase):
def test_valid_fertiliser_form(self):
Expand All @@ -16,7 +16,7 @@ def test_valid_fertiliser_form(self):
}

try:
form = FertiliserForm(**data)
form = FertilizerInspection(**data)
except ValidationError as e:
self.fail(f"Validation error: {e}")

Expand All @@ -29,16 +29,16 @@ def test_valid_fertiliser_form(self):

def test_invalid_npk_format(self):
with self.assertRaises(ValidationError):
FertiliserForm(npk="invalid-format")
FertilizerInspection(npk="invalid-format")

def test_valid_npk_format(self):
try:
FertiliserForm(npk="10.5-20-30")
FertiliserForm(npk="10.5-20.5-30")
FertiliserForm(npk="10.5-0.5-30.1")
FertiliserForm(npk="0-20.5-30.1")
FertiliserForm(npk="0-20.5-1")
FertiliserForm(npk="20.5-1-30.1")
FertilizerInspection(npk="10.5-20-30")
FertilizerInspection(npk="10.5-20.5-30")
FertilizerInspection(npk="10.5-0.5-30.1")
FertilizerInspection(npk="0-20.5-30.1")
FertilizerInspection(npk="0-20.5-1")
FertilizerInspection(npk="20.5-1-30.1")
except ValidationError as e:
self.fail(f"Validation error: {e}")

Expand Down
8 changes: 4 additions & 4 deletions tests/test_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json

from dotenv import load_dotenv
from pipeline.form import FertiliserForm
from pipeline.inspection import FertilizerInspection
from pipeline.gpt import GPT
from tests import levenshtein_similarity

Expand Down Expand Up @@ -97,15 +97,15 @@ def check_json(self, extracted_info):
assert key in extracted_info, f"Key '{key}' is missing in the extracted information"

# Check if the json matches the format
FertiliserForm(**expected_json)
FertilizerInspection(**expected_json)

# Check if values match
for key, expected_value in expected_json.items():
assert levenshtein_similarity(str(extracted_info[key]), str(expected_value)) > 0.9, f"Value for key '{key}' does not match. Expected '{expected_value}', got '{extracted_info[key]}'"

def test_generate_form_gpt(self):
prediction = self.gpt.generate_form(self.prompt)
result_json = json.loads(prediction.form)
prediction = self.gpt.create_inspection(self.prompt)
result_json = json.loads(prediction.inspection)
# print(json.dumps(result_json, indent=2))
self.check_json(result_json)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dotenv import load_dotenv
from datetime import datetime
from tests import levenshtein_similarity
from pipeline.form import FertiliserForm, Value
from pipeline.inspection import FertilizerInspection, Value
from pipeline import LabelStorage, OCR, GPT, analyze

class TestPipeline(unittest.TestCase):
Expand Down Expand Up @@ -51,7 +51,7 @@ def test_analyze(self):
form = analyze(self.label_storage, self.ocr, self.gpt, log_dir_path=self.log_dir_path)

# Perform assertions
self.assertIsInstance(form, FertiliserForm)
self.assertIsInstance(form, FertilizerInspection)
self.assertIn(Value(value='25', unit='kg'), form.weight)
self.assertGreater(levenshtein_similarity(form.company_name, "TerraLink"), 0.95)
self.assertGreater(levenshtein_similarity(form.npk, "10-52-0"), 0.90)
Expand Down
Loading