Skip to content

Commit

Permalink
Refactor prompt_parser (#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritugala authored Sep 20, 2023
1 parent b01a7f8 commit 72d97cc
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 69 deletions.
75 changes: 7 additions & 68 deletions prompt2model/prompt_parser/instr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,15 @@

from __future__ import annotations # noqa FI58

import json
import os

import openai

from prompt2model.prompt_parser.base import PromptSpec, TaskType

from prompt2model.prompt_parser.instr_parser_prompt import ( # isort: split
construct_prompt_for_instruction_parsing,
)

from prompt2model.utils import api_tools, get_formatted_logger
from prompt2model.utils.api_tools import API_ERRORS, handle_api_error

logger = get_formatted_logger("PromptParser")
from prompt2model.utils.parse_json_responses import parse_prompt_to_fields

os.environ["TOKENIZERS_PARALLELISM"] = "false"

Expand All @@ -38,40 +32,7 @@ def __init__(self, task_type: TaskType, max_api_calls: int = None):
self.task_type = task_type
self._instruction: str | None = None
self._examples: str | None = None
if max_api_calls and max_api_calls <= 0:
raise ValueError("max_api_calls must be > 0.")
self.max_api_calls = max_api_calls
self.api_call_counter = 0

def extract_response(self, response: openai.Completion) -> tuple[str, str] | None:
    """Parse structured fields from the API response.

    Args:
        response: API response. Only
            ``response.choices[0]["message"]["content"]`` is read, and it is
            expected to contain a JSON document.

    Returns:
        If the API response is a valid JSON object and contains the
        required keys, returns a tuple consisting of:
            1) Instruction: The instruction parsed from the API response.
            2) Demonstrations: (Optional) demonstrations parsed from the
               API response.
        String values are stripped of surrounding whitespace; any other JSON
        value (for example ``null``) is returned unchanged.
        Else returns None (after logging a warning).
    """
    response_text = response.choices[0]["message"]["content"]
    try:
        # strict=False tolerates raw control characters (e.g. newlines)
        # inside JSON string values.
        response_json = json.loads(response_text, strict=False)
    except json.decoder.JSONDecodeError:
        logger.warning(f"API response was not a valid JSON: {response_text}")
        return None

    required_keys = ["Instruction", "Demonstrations"]
    # Valid JSON is not necessarily an object: a bare number would make the
    # membership test raise, and a bare string would turn it into a substring
    # check. Treat anything that is not a dict as missing the required keys.
    if not isinstance(response_json, dict) or any(
        key not in response_json for key in required_keys
    ):
        logger.warning(f'API response must contain {", ".join(required_keys)} keys')
        return None

    def _clean(value):
        # Only strings can be stripped; pass other JSON values through
        # instead of crashing with AttributeError.
        return value.strip() if isinstance(value, str) else value

    instruction_string = _clean(response_json["Instruction"])
    demonstration_string = _clean(response_json["Demonstrations"])
    return instruction_string, demonstration_string

def parse_from_prompt(self, prompt: str) -> None:
"""Parse prompt into specific fields, stored as class member variables.
Expand All @@ -84,31 +45,9 @@ def parse_from_prompt(self, prompt: str) -> None:
"instruction" and "demonstrations".
"""
parsing_prompt_for_chatgpt = construct_prompt_for_instruction_parsing(prompt)

chat_api = api_tools.default_api_agent
last_error = None
while True:
self.api_call_counter += 1
try:
response: openai.ChatCompletion | Exception = (
chat_api.generate_one_completion(
parsing_prompt_for_chatgpt,
temperature=0.01,
presence_penalty=0,
frequency_penalty=0,
)
)
extraction = self.extract_response(response)
if extraction is not None:
self._instruction, self._examples = extraction
return
except API_ERRORS as e:
last_error = e
handle_api_error(e)

if self.max_api_calls and self.api_call_counter >= self.max_api_calls:
# In case we reach maximum number of API calls, we raise an error.
logger.error("Maximum number of API calls reached.")
raise RuntimeError(
"Maximum number of API calls reached."
) from last_error
required_keys = ["Instruction", "Demonstrations"]
extraction = parse_prompt_to_fields(
parsing_prompt_for_chatgpt, required_keys, max_api_calls=self.max_api_calls
)
self._instruction = extraction["Instruction"]
self._examples = extraction["Demonstrations"]
96 changes: 96 additions & 0 deletions prompt2model/utils/parse_json_responses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Utility file for parsing OpenAI json responses."""
from __future__ import annotations

import json

import openai

from prompt2model.utils import api_tools, get_formatted_logger
from prompt2model.utils.api_tools import API_ERRORS, handle_api_error

logger = get_formatted_logger("ParseJsonResponses")


def extract_response(
    response: openai.Completion, required_keys: list, optional_keys: list
) -> dict | None:
    """Parse structured fields from the API response.

    Args:
        response: API response. Only
            ``response.choices[0]["message"]["content"]`` is read, and it is
            expected to contain a JSON document.
        required_keys: Keys that must be present in the response.
        optional_keys: Keys that are copied into the result only when the
            response contains them.

    Returns:
        If the API response is a valid JSON object and contains all the
        required keys, returns a dictionary holding every required key plus
        whichever optional keys were present. String values are stripped of
        surrounding whitespace; any other JSON value (for example ``null``
        or a list) is kept unchanged.
        Else returns None (after logging a warning).
    """
    response_text = response.choices[0]["message"]["content"]
    try:
        # strict=False tolerates raw control characters (e.g. newlines)
        # inside JSON string values.
        response_json = json.loads(response_text, strict=False)
    except json.decoder.JSONDecodeError:
        logger.warning(f"API response was not a valid JSON: {response_text}")
        return None

    # Valid JSON is not necessarily an object: a bare number would make the
    # membership test raise, and a bare string would turn it into a substring
    # check. Treat anything that is not a dict as missing the required keys.
    if not isinstance(response_json, dict) or any(
        key not in response_json for key in required_keys
    ):
        logger.warning(f'API response must contain {", ".join(required_keys)} keys')
        return None

    final_response = {}
    # Required keys first, then optional ones, so an optional entry wins if a
    # key is listed in both (same precedence as a dict.update of optionals).
    for key in [*required_keys, *(optional_keys or [])]:
        if key not in response_json:
            # Only optional keys can be absent at this point; leave them out.
            continue
        value = response_json[key]
        # Only strings can be stripped; keep other JSON values as-is instead
        # of crashing with AttributeError.
        final_response[key] = value.strip() if isinstance(value, str) else value
    return final_response


def parse_prompt_to_fields(
    prompt: str,
    required_keys: list,
    optional_keys: list | None = None,
    max_api_calls: int | None = None,
) -> dict:
    """Parse prompt into specific fields, and return to the calling function.

    This function calls the default API agent, retries until a response can
    be parsed, passes each response to the parsing function, and returns the
    parsed fields or raises an error once the call budget is exhausted.

    Args:
        prompt: The prompt sent to the API, whose JSON response is parsed
            into specific fields.
        required_keys: Fields that need to be present in the response.
        optional_keys: Fields that may or may not be present in the response.
            Defaults to no optional fields.
        max_api_calls: Maximum number of API calls to attempt. None (the
            default) means retry without limit.

    Returns:
        Dictionary with the parsed fields of the first usable response.

    Raises:
        ValueError: If max_api_calls is given but is not greater than zero.
        RuntimeError: If max_api_calls attempts were made without obtaining a
            usable response; chained from the last API error, if any.
    """
    # Validate before touching the API. Compare against None explicitly: a
    # plain truthiness test lets 0 through, which would both skip this check
    # and disable the retry cap below, looping forever.
    if max_api_calls is not None and max_api_calls <= 0:
        raise ValueError("max_api_calls must be > 0.")
    if optional_keys is None:
        optional_keys = []

    chat_api = api_tools.default_api_agent
    api_call_counter = 0
    last_error = None
    while True:
        api_call_counter += 1
        try:
            response: openai.ChatCompletion | Exception = (
                chat_api.generate_one_completion(
                    prompt,
                    temperature=0.01,
                    presence_penalty=0,
                    frequency_penalty=0,
                )
            )
            extraction = extract_response(response, required_keys, optional_keys)
            if extraction is not None:
                return extraction
        except API_ERRORS as e:
            # Remember the failure so the final RuntimeError can chain from it.
            last_error = e
            handle_api_error(e)

        if max_api_calls is not None and api_call_counter >= max_api_calls:
            # In case we reach maximum number of API calls, we raise an error.
            logger.error("Maximum number of API calls reached.")
            raise RuntimeError("Maximum number of API calls reached.") from last_error
2 changes: 1 addition & 1 deletion tests/prompt_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from test_helpers.mock_api import MockAPIAgent
from test_helpers.test_utils import temp_setattr

logger = logging.getLogger("PromptParser")
logger = logging.getLogger("ParseJsonResponses")
GPT3_RESPONSE_WITH_DEMONSTRATIONS = MockCompletion(
'{"Instruction": "Convert each date from an informal description into a'
' MM/DD/YYYY format.", "Demonstrations": "Fifth of November 2024 ->'
Expand Down

0 comments on commit 72d97cc

Please sign in to comment.