From 6db7430f555f5a2c468ec2eeed8ff9f73684b92a Mon Sep 17 00:00:00 2001 From: Rajas Bansal Date: Thu, 12 Oct 2023 17:11:21 -0700 Subject: [PATCH] Refuel LLM integration with autolabel (#595) * use refuel llm * fix test * add refuel model name and change classification * fix returns token probs --------- Co-authored-by: Rajas Bansal --- src/autolabel/confidence.py | 4 +- .../label_diversity_example_selector.py | 6 ++- src/autolabel/labeler.py | 2 +- src/autolabel/models/refuel.py | 41 ++++++++++--------- src/autolabel/tasks/base.py | 32 +++++++++++++-- src/autolabel/tasks/classification.py | 17 ++++++-- src/autolabel/tasks/question_answering.py | 12 +++++- tests/unit/llm_test.py | 7 ++-- 8 files changed, 85 insertions(+), 36 deletions(-) diff --git a/src/autolabel/confidence.py b/src/autolabel/confidence.py index 7152908f..6c753597 100644 --- a/src/autolabel/confidence.py +++ b/src/autolabel/confidence.py @@ -27,7 +27,7 @@ def __init__( ) -> None: self.score_type = score_type self.llm = llm - self.tokens_to_ignore = {""} + self.tokens_to_ignore = {"", "", "\\n"} self.SUPPORTED_CALCULATORS = { "logprob_average": self.logprob_average, "p_true": self.p_true, @@ -54,7 +54,7 @@ def logprob_average( logprob_cumulative, count = 0, 0 for token in logprobs: token_str = list(token.keys())[0] - if token_str not in self.tokens_to_ignore: + if token_str.strip() not in self.tokens_to_ignore: logprob_cumulative += ( token[token_str] if token[token_str] >= 0 diff --git a/src/autolabel/few_shot/label_diversity_example_selector.py b/src/autolabel/few_shot/label_diversity_example_selector.py index 43ddee35..2426404a 100644 --- a/src/autolabel/few_shot/label_diversity_example_selector.py +++ b/src/autolabel/few_shot/label_diversity_example_selector.py @@ -106,7 +106,9 @@ def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: """Select which examples to use based on label diversity and semantic similarity.""" # Get the docs with the highest similarity for each label. if self.input_keys: - input_variables = {key: input_variables[key] for key in self.input_keys} + input_variables = { + str(key): str(input_variables[key]) for key in self.input_keys + } query = " ".join(sorted_values(input_variables)) num_examples_per_label = math.ceil(self.k / self.num_labels) example_docs = self.vectorstore.label_diversity_similarity_search( @@ -146,7 +148,7 @@ def from_examples( """ if input_keys: string_examples = [ - " ".join(sorted_values({k: eg[k] for k in input_keys})) + " ".join(sorted_values({str(k): str(eg[k]) for k in input_keys})) for eg in examples ] else: diff --git a/src/autolabel/labeler.py b/src/autolabel/labeler.py index be8b3bb7..ad8a3360 100644 --- a/src/autolabel/labeler.py +++ b/src/autolabel/labeler.py @@ -444,7 +444,7 @@ def plan( table, show_header=False, console=self.console, styles=COST_TABLE_STYLES ) self.console.rule("Prompt Example") - self.console.print(f"{prompt_list[0]}") + self.console.print(f"{prompt_list[0]}", markup=False) self.console.rule() async def async_run_transform( diff --git a/src/autolabel/models/refuel.py b/src/autolabel/models/refuel.py index 4a69607b..2d3a4297 100644 --- a/src/autolabel/models/refuel.py +++ b/src/autolabel/models/refuel.py @@ -23,7 +23,6 @@ class RefuelLLM(BaseModel): DEFAULT_PARAMS = { "max_new_tokens": 128, - "temperature": 0.0, } def __init__( @@ -41,8 +40,7 @@ def __init__( self.model_params = {**self.DEFAULT_PARAMS, **model_params} # initialize runtime - self.BASE_API = "https://refuel-llm.refuel.ai/" - self.SEP_REPLACEMENT_TOKEN = "@@" + self.BASE_API = f"https://llm.refuel.ai/models/{self.model_name}/generate" self.REFUEL_API_ENV = "REFUEL_API_KEY" if self.REFUEL_API_ENV in os.environ and os.environ[self.REFUEL_API_ENV]: self.REFUEL_API_KEY = os.environ[self.REFUEL_API_ENV] @@ -60,8 +58,9 @@ def __init__( ) def _label_with_retry(self, prompt: str) -> requests.Response: payload = { - "data": {"model_input": prompt, "model_params": {**self.model_params}}, - "task": "generate", + "input": prompt, + "params": {**self.model_params}, + "confidence": self.config.confidence(), } headers = {"refuel_api_key": self.REFUEL_API_KEY} response = requests.post(self.BASE_API, json=payload, headers=headers) @@ -74,20 +73,20 @@ def _label(self, prompts: List[str]) -> RefuelLLMResult: errors = [] for prompt in prompts: try: - if self.SEP_REPLACEMENT_TOKEN in prompt: - logger.warning( - f"""Current prompt contains {self.SEP_REPLACEMENT_TOKEN} - which is currently used as a separator token by refuel - llm. It is highly recommended to avoid having any - occurences of this substring in the prompt. - """ - ) - separated_prompt = prompt.replace("\n", self.SEP_REPLACEMENT_TOKEN) - response = self._label_with_retry(separated_prompt) - response = json.loads(response.json()["body"]).replace( - self.SEP_REPLACEMENT_TOKEN, "\n" + response = self._label_with_retry(prompt) + response = json.loads(response.json()) + generations.append( + [ + Generation( + text=response["generated_text"], + generation_info={ + "logprobs": {"top_logprobs": response["logprobs"]} + } + if self.config.confidence() + else None, + ) + ] ) - generations.append([Generation(text=response)]) errors.append(None) except Exception as e: # This signifies an error in generating the response using RefuelLLm @@ -96,7 +95,9 @@ def _label(self, prompts: List[str]) -> RefuelLLMResult: ) generations.append([Generation(text="")]) errors.append( - LabelingError(error_type=ErrorType.LLM_PROVIDER_ERROR, error=e) + LabelingError( + error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e) + ) ) return RefuelLLMResult(generations=generations, errors=errors) @@ -104,4 +105,4 @@ def get_cost(self, prompt: str, label: Optional[str] = "") -> float: return 0 def returns_token_probs(self) -> bool: - return False + return True diff --git a/src/autolabel/tasks/base.py b/src/autolabel/tasks/base.py index 4a355b43..ba359177 100644 --- a/src/autolabel/tasks/base.py +++ b/src/autolabel/tasks/base.py @@ -16,6 +16,7 @@ TaskType, LabelingError, ErrorType, + ModelProvider, ) from autolabel.utils import ( get_format_variables, @@ -30,6 +31,17 @@ class BaseTask(ABC): ZERO_SHOT_TEMPLATE = "{task_guidelines}\n\n{output_guidelines}\n\nNow I want you to label the following example:\n{current_example}" FEW_SHOT_TEMPLATE = "{task_guidelines}\n\n{output_guidelines}\n\nSome examples with their output answers are provided below:\n\n{seed_examples}\n\nNow I want you to label the following example:\n{current_example}" + ZERO_SHOT_TEMPLATE_REFUEL_LLM = """ + [INST] <> + {task_guidelines}{output_guidelines} + <> + {current_example}[/INST]\n""" + FEW_SHOT_TEMPLATE_REFUEL_LLM = """ + [INST] <> + {task_guidelines}{output_guidelines}\n{seed_examples} + <> + {current_example}[/INST]\n""" + # Downstream classes should override these NULL_LABEL_TOKEN = "NO_LABEL" DEFAULT_TASK_GUIDELINES = "" @@ -39,6 +51,8 @@ class BaseTask(ABC): def __init__(self, config: AutolabelConfig) -> None: self.config = config + is_refuel_llm = self.config.provider() == ModelProvider.REFUEL + # Update the default prompt template with the prompt template from the config self.task_guidelines = ( self.config.task_guidelines() or self.DEFAULT_TASK_GUIDELINES @@ -48,14 +62,24 @@ def __init__(self, config: AutolabelConfig) -> None: ) if self._is_few_shot_mode(): + few_shot_template = ( + self.FEW_SHOT_TEMPLATE_REFUEL_LLM + if is_refuel_llm + else self.FEW_SHOT_TEMPLATE + ) self.prompt_template = PromptTemplate( - input_variables=get_format_variables(self.FEW_SHOT_TEMPLATE), - template=self.FEW_SHOT_TEMPLATE, + input_variables=get_format_variables(few_shot_template), + template=few_shot_template, ) else: + zero_shot_template = ( + self.ZERO_SHOT_TEMPLATE_REFUEL_LLM + if is_refuel_llm + else self.ZERO_SHOT_TEMPLATE + ) self.prompt_template = PromptTemplate( - input_variables=get_format_variables(self.ZERO_SHOT_TEMPLATE), - template=self.ZERO_SHOT_TEMPLATE, + input_variables=get_format_variables(zero_shot_template), + template=zero_shot_template, ) self.dataset_generation_guidelines = ( diff --git a/src/autolabel/tasks/classification.py b/src/autolabel/tasks/classification.py index 37ee4fc4..45666bc0 100644 --- a/src/autolabel/tasks/classification.py +++ b/src/autolabel/tasks/classification.py @@ -7,7 +7,7 @@ from autolabel.confidence import ConfidenceCalculator from autolabel.configs import AutolabelConfig -from autolabel.schema import LLMAnnotation, MetricType, MetricResult +from autolabel.schema import LLMAnnotation, MetricType, MetricResult, ModelProvider from autolabel.tasks import BaseTask from autolabel.utils import get_format_variables from autolabel.tasks.utils import filter_unlabeled_examples @@ -69,8 +69,19 @@ def construct_prompt( ) num_labels = len(labels_list) + is_refuel_llm = self.config.provider() == ModelProvider.REFUEL + + if is_refuel_llm: + labels = ( + ", ".join([f'\\"{i}\\"' for i in labels_list[:-1]]) + + " or " + + f'\\"{labels_list[-1]}\\"' + ) + else: + labels = "\n".join(labels_list) + fmt_task_guidelines = self.task_guidelines.format( - num_labels=num_labels, labels="\n".join(labels_list) + num_labels=num_labels, labels=labels ) # prepare seed examples @@ -100,7 +111,7 @@ def construct_prompt( return self.prompt_template.format( task_guidelines=fmt_task_guidelines, output_guidelines=self.output_guidelines, - seed_examples="\n\n".join(fmt_examples), + seed_examples="\n".join(fmt_examples), current_example=current_example, ) else: diff --git a/src/autolabel/tasks/question_answering.py b/src/autolabel/tasks/question_answering.py index bfbada13..ed242378 100644 --- a/src/autolabel/tasks/question_answering.py +++ b/src/autolabel/tasks/question_answering.py @@ -7,7 +7,13 @@ from autolabel.confidence import ConfidenceCalculator from autolabel.configs import AutolabelConfig -from autolabel.schema import LLMAnnotation, MetricType, MetricResult, F1Type +from autolabel.schema import ( + LLMAnnotation, + MetricType, + MetricResult, + F1Type, + ModelProvider, +) from autolabel.tasks import BaseTask from autolabel.tasks.utils import normalize_text from autolabel.utils import get_format_variables @@ -32,6 +38,10 @@ class QuestionAnsweringTask(BaseTask): GENERATE_EXPLANATION_PROMPT = "You are an expert at providing a well reasoned explanation for the output of a given task. \n\nBEGIN TASK DESCRIPTION\n{task_guidelines}\nEND TASK DESCRIPTION\nYou will be given an input example and the corresponding output. You will be given a question and an answer. Your job is to provide an explanation for why the answer is correct for the task above.\nThink step by step and generate an explanation. The last line of the explanation should be - So, the answer is