Refuel LLM integration with autolabel (#595)
* use refuel llm

* fix test

* add refuel model name and change classification

* fix returns token probs

---------

Co-authored-by: Rajas Bansal <[email protected]>
rajasbansal committed Oct 13, 2023
1 parent a374853 commit 6db7430
Showing 8 changed files with 85 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/autolabel/confidence.py
@@ -27,7 +27,7 @@ def __init__(
) -> None:
self.score_type = score_type
self.llm = llm
self.tokens_to_ignore = {"<unk>"}
self.tokens_to_ignore = {"<unk>", "", "\\n"}
self.SUPPORTED_CALCULATORS = {
"logprob_average": self.logprob_average,
"p_true": self.p_true,
@@ -54,7 +54,7 @@ def logprob_average(
logprob_cumulative, count = 0, 0
for token in logprobs:
token_str = list(token.keys())[0]
if token_str not in self.tokens_to_ignore:
if token_str.strip() not in self.tokens_to_ignore:
logprob_cumulative += (
token[token_str]
if token[token_str] >= 0
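A minimal sketch of the averaging logic this hunk touches, not the library's exact implementation: logprobs is assumed to be a list of single-entry {token: logprob} dicts, and the change means empty and whitespace-only tokens are now skipped alongside "<unk>".

import math

TOKENS_TO_IGNORE = {"<unk>", "", "\\n"}

def logprob_average(logprobs):
    # e.g. logprobs = [{"Yes": -0.02}, {" ": -0.9}, {"<unk>": -3.1}]
    total, count = 0.0, 0
    for token in logprobs:
        token_str = list(token.keys())[0]
        if token_str.strip() not in TOKENS_TO_IGNORE:
            total += token[token_str]
            count += 1
    # exponentiate the mean log-probability to get a confidence in (0, 1]
    return math.exp(total / count) if count else 0.0

print(logprob_average([{"Yes": -0.02}, {" ": -0.9}, {"<unk>": -3.1}]))  # ~0.98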
6 changes: 4 additions & 2 deletions src/autolabel/few_shot/label_diversity_example_selector.py
@@ -106,7 +106,9 @@ def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on label diversity and semantic similarity."""
# Get the docs with the highest similarity for each label.
if self.input_keys:
input_variables = {key: input_variables[key] for key in self.input_keys}
input_variables = {
str(key): str(input_variables[key]) for key in self.input_keys
}
query = " ".join(sorted_values(input_variables))
num_examples_per_label = math.ceil(self.k / self.num_labels)
example_docs = self.vectorstore.label_diversity_similarity_search(
@@ -146,7 +148,7 @@ def from_examples(
"""
if input_keys:
string_examples = [
" ".join(sorted_values({k: eg[k] for k in input_keys}))
" ".join(sorted_values({str(k): str(eg[k]) for k in input_keys}))
for eg in examples
]
else:
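The str() coercion matters when a dataset column holds non-string values such as integer labels, which would otherwise make the join fail. A small sketch, with sorted_values standing in for the autolabel.utils helper (assumed here to return the dict's values ordered by key):

def sorted_values(d):
    # assumed behavior of autolabel.utils.sorted_values
    return [d[k] for k in sorted(d)]

example = {"question": "Is 7 prime?", "label": 1}
input_keys = ["question", "label"]

query = " ".join(sorted_values({str(k): str(example[k]) for k in input_keys}))
print(query)  # "1 Is 7 prime?"
# without the str() coercion, " ".join() raises TypeError on the integer label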
2 changes: 1 addition & 1 deletion src/autolabel/labeler.py
@@ -444,7 +444,7 @@ def plan(
table, show_header=False, console=self.console, styles=COST_TABLE_STYLES
)
self.console.rule("Prompt Example")
self.console.print(f"{prompt_list[0]}")
self.console.print(f"{prompt_list[0]}", markup=False)
self.console.rule()

async def async_run_transform(
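Passing markup=False matters because the Refuel prompt templates added in this commit contain square-bracket tokens such as [INST], which Rich would otherwise try to interpret as console markup. A minimal illustration (the prompt text is made up):

from rich.console import Console

console = Console()
prompt = "<s>[INST] <<SYS>>\nClassify the query.\n<<SYS>>\nQuery: I lost my card.[/INST]\n"

# printed verbatim; with the default markup=True, Rich would treat "[INST]"/"[/INST]"
# as style tags and either raise a markup error or strip them from the output
console.print(prompt, markup=False)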
41 changes: 21 additions & 20 deletions src/autolabel/models/refuel.py
@@ -23,7 +23,6 @@
class RefuelLLM(BaseModel):
DEFAULT_PARAMS = {
"max_new_tokens": 128,
"temperature": 0.0,
}

def __init__(
@@ -41,8 +40,7 @@ def __init__(
self.model_params = {**self.DEFAULT_PARAMS, **model_params}

# initialize runtime
self.BASE_API = "https://refuel-llm.refuel.ai/"
self.SEP_REPLACEMENT_TOKEN = "@@"
self.BASE_API = f"https://llm.refuel.ai/models/{self.model_name}/generate"
self.REFUEL_API_ENV = "REFUEL_API_KEY"
if self.REFUEL_API_ENV in os.environ and os.environ[self.REFUEL_API_ENV]:
self.REFUEL_API_KEY = os.environ[self.REFUEL_API_ENV]
@@ -60,8 +58,9 @@ def __init__(
)
def _label_with_retry(self, prompt: str) -> requests.Response:
payload = {
"data": {"model_input": prompt, "model_params": {**self.model_params}},
"task": "generate",
"input": prompt,
"params": {**self.model_params},
"confidence": self.config.confidence(),
}
headers = {"refuel_api_key": self.REFUEL_API_KEY}
response = requests.post(self.BASE_API, json=payload, headers=headers)
@@ -74,20 +73,20 @@ def _label(self, prompts: List[str]) -> RefuelLLMResult:
errors = []
for prompt in prompts:
try:
if self.SEP_REPLACEMENT_TOKEN in prompt:
logger.warning(
f"""Current prompt contains {self.SEP_REPLACEMENT_TOKEN}
which is currently used as a separator token by refuel
llm. It is highly recommended to avoid having any
occurences of this substring in the prompt.
"""
)
separated_prompt = prompt.replace("\n", self.SEP_REPLACEMENT_TOKEN)
response = self._label_with_retry(separated_prompt)
response = json.loads(response.json()["body"]).replace(
self.SEP_REPLACEMENT_TOKEN, "\n"
response = self._label_with_retry(prompt)
response = json.loads(response.json())
generations.append(
[
Generation(
text=response["generated_text"],
generation_info={
"logprobs": {"top_logprobs": response["logprobs"]}
}
if self.config.confidence()
else None,
)
]
)
generations.append([Generation(text=response)])
errors.append(None)
except Exception as e:
# This signifies an error in generating the response using RefuelLLm
@@ -96,12 +95,14 @@ def _label(self, prompts: List[str]) -> RefuelLLMResult:
)
generations.append([Generation(text="")])
errors.append(
LabelingError(error_type=ErrorType.LLM_PROVIDER_ERROR, error=e)
LabelingError(
error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e)
)
)
return RefuelLLMResult(generations=generations, errors=errors)

def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
return 0

def returns_token_probs(self) -> bool:
return False
return True
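Taken together, the new client POSTs the raw prompt to a per-model endpoint and parses a JSON-encoded body containing generated_text and, when confidence is requested, token logprobs. A sketch of the equivalent raw request; the model name is a placeholder and the response fields are inferred from this diff rather than from published API docs:

import json
import os
import requests

model_name = "refuel-llm"  # placeholder; the real value comes from the Autolabel config
url = f"https://llm.refuel.ai/models/{model_name}/generate"

payload = {
    "input": "Is this review positive or negative?\nReview: Great battery life.",
    "params": {"max_new_tokens": 128},
    "confidence": True,  # ask for token logprobs so confidence can be computed
}
headers = {"refuel_api_key": os.environ["REFUEL_API_KEY"]}

response = requests.post(url, json=payload, headers=headers)
response.raise_for_status()

body = json.loads(response.json())  # the endpoint returns a JSON-encoded string, per the diff
print(body["generated_text"])
print(body.get("logprobs"))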
32 changes: 28 additions & 4 deletions src/autolabel/tasks/base.py
@@ -16,6 +16,7 @@
TaskType,
LabelingError,
ErrorType,
ModelProvider,
)
from autolabel.utils import (
get_format_variables,
@@ -30,6 +31,17 @@ class BaseTask(ABC):
ZERO_SHOT_TEMPLATE = "{task_guidelines}\n\n{output_guidelines}\n\nNow I want you to label the following example:\n{current_example}"
FEW_SHOT_TEMPLATE = "{task_guidelines}\n\n{output_guidelines}\n\nSome examples with their output answers are provided below:\n\n{seed_examples}\n\nNow I want you to label the following example:\n{current_example}"

ZERO_SHOT_TEMPLATE_REFUEL_LLM = """
<s>[INST] <<SYS>>
{task_guidelines}{output_guidelines}
<<SYS>>
{current_example}[/INST]\n"""
FEW_SHOT_TEMPLATE_REFUEL_LLM = """
<s>[INST] <<SYS>>
{task_guidelines}{output_guidelines}\n{seed_examples}
<<SYS>>
{current_example}[/INST]\n"""

# Downstream classes should override these
NULL_LABEL_TOKEN = "NO_LABEL"
DEFAULT_TASK_GUIDELINES = ""
@@ -39,6 +51,8 @@ class BaseTask(ABC):
def __init__(self, config: AutolabelConfig) -> None:
self.config = config

is_refuel_llm = self.config.provider() == ModelProvider.REFUEL

# Update the default prompt template with the prompt template from the config
self.task_guidelines = (
self.config.task_guidelines() or self.DEFAULT_TASK_GUIDELINES
@@ -48,14 +62,24 @@ def __init__(self, config: AutolabelConfig) -> None:
)

if self._is_few_shot_mode():
few_shot_template = (
self.FEW_SHOT_TEMPLATE_REFUEL_LLM
if is_refuel_llm
else self.FEW_SHOT_TEMPLATE
)
self.prompt_template = PromptTemplate(
input_variables=get_format_variables(self.FEW_SHOT_TEMPLATE),
template=self.FEW_SHOT_TEMPLATE,
input_variables=get_format_variables(few_shot_template),
template=few_shot_template,
)
else:
zero_shot_template = (
self.ZERO_SHOT_TEMPLATE_REFUEL_LLM
if is_refuel_llm
else self.ZERO_SHOT_TEMPLATE
)
self.prompt_template = PromptTemplate(
input_variables=get_format_variables(self.ZERO_SHOT_TEMPLATE),
template=self.ZERO_SHOT_TEMPLATE,
input_variables=get_format_variables(zero_shot_template),
template=zero_shot_template,
)

self.dataset_generation_guidelines = (
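The new *_REFUEL_LLM templates wrap the guidelines in Llama-2-style [INST]/<<SYS>> markers. A short sketch of how the zero-shot variant renders through LangChain's PromptTemplate; the input variables are hardcoded here, whereas the library derives them with get_format_variables:

from langchain.prompts import PromptTemplate

ZERO_SHOT_TEMPLATE_REFUEL_LLM = """
<s>[INST] <<SYS>>
{task_guidelines}{output_guidelines}
<<SYS>>
{current_example}[/INST]\n"""

prompt_template = PromptTemplate(
    input_variables=["task_guidelines", "output_guidelines", "current_example"],
    template=ZERO_SHOT_TEMPLATE_REFUEL_LLM,
)

print(
    prompt_template.format(
        task_guidelines="You are an expert at classifying banking queries.\n",
        output_guidelines="Answer with exactly one of the provided labels.",
        current_example="Query: I lost my card.\nLabel:",
    )
)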
17 changes: 14 additions & 3 deletions src/autolabel/tasks/classification.py
@@ -7,7 +7,7 @@

from autolabel.confidence import ConfidenceCalculator
from autolabel.configs import AutolabelConfig
from autolabel.schema import LLMAnnotation, MetricType, MetricResult
from autolabel.schema import LLMAnnotation, MetricType, MetricResult, ModelProvider
from autolabel.tasks import BaseTask
from autolabel.utils import get_format_variables
from autolabel.tasks.utils import filter_unlabeled_examples
@@ -69,8 +69,19 @@ def construct_prompt(
)
num_labels = len(labels_list)

is_refuel_llm = self.config.provider() == ModelProvider.REFUEL

if is_refuel_llm:
labels = (
", ".join([f'\\"{i}\\"' for i in labels_list[:-1]])
+ " or "
+ f'\\"{labels_list[-1]}\\"'
)
else:
labels = "\n".join(labels_list)

fmt_task_guidelines = self.task_guidelines.format(
num_labels=num_labels, labels="\n".join(labels_list)
num_labels=num_labels, labels=labels
)

# prepare seed examples
@@ -100,7 +111,7 @@ def construct_prompt(
return self.prompt_template.format(
task_guidelines=fmt_task_guidelines,
output_guidelines=self.output_guidelines,
seed_examples="\n\n".join(fmt_examples),
seed_examples="\n".join(fmt_examples),
current_example=current_example,
)
else:
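For Refuel LLM the label set is rendered as a single quote-escaped inline list ending in "or", instead of the newline-separated list used for other providers. A quick illustration of the string it produces (assumes at least two labels; with a single label the slice before "or" would be empty):

labels_list = ["activate_my_card", "card_arrival", "lost_or_stolen_card"]

labels = (
    ", ".join([f'\\"{i}\\"' for i in labels_list[:-1]])
    + " or "
    + f'\\"{labels_list[-1]}\\"'
)
print(labels)
# \"activate_my_card\", \"card_arrival\" or \"lost_or_stolen_card\"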
12 changes: 11 additions & 1 deletion src/autolabel/tasks/question_answering.py
@@ -7,7 +7,13 @@

from autolabel.confidence import ConfidenceCalculator
from autolabel.configs import AutolabelConfig
from autolabel.schema import LLMAnnotation, MetricType, MetricResult, F1Type
from autolabel.schema import (
LLMAnnotation,
MetricType,
MetricResult,
F1Type,
ModelProvider,
)
from autolabel.tasks import BaseTask
from autolabel.tasks.utils import normalize_text
from autolabel.utils import get_format_variables
@@ -32,6 +38,10 @@ class QuestionAnsweringTask(BaseTask):
GENERATE_EXPLANATION_PROMPT = "You are an expert at providing a well reasoned explanation for the output of a given task. \n\nBEGIN TASK DESCRIPTION\n{task_guidelines}\nEND TASK DESCRIPTION\nYou will be given an input example and the corresponding output. You will be given a question and an answer. Your job is to provide an explanation for why the answer is correct for the task above.\nThink step by step and generate an explanation. The last line of the explanation should be - So, the answer is <label>.\n{labeled_example}\nExplanation: "

def __init__(self, config: AutolabelConfig) -> None:
is_refuel_llm = config.provider() == ModelProvider.REFUEL
if is_refuel_llm:
self.DEFAULT_OUTPUT_GUIDELINES = ""

super().__init__(config)
self.metrics = [
AccuracyMetric(),
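Assigning self.DEFAULT_OUTPUT_GUIDELINES before calling super().__init__() works because the base class reads the attribute through self, so the instance value shadows the class-level default. A stripped-down illustration with made-up class names, not the actual autolabel classes:

class Base:
    DEFAULT_OUTPUT_GUIDELINES = "You will return the answer as valid JSON."

    def __init__(self):
        # attribute lookup checks the instance first, then the class
        self.output_guidelines = self.DEFAULT_OUTPUT_GUIDELINES


class QATask(Base):
    def __init__(self, is_refuel_llm: bool):
        if is_refuel_llm:
            # shadow the class default before the base __init__ reads it
            self.DEFAULT_OUTPUT_GUIDELINES = ""
        super().__init__()


print(repr(QATask(is_refuel_llm=True).output_guidelines))   # ''
print(repr(QATask(is_refuel_llm=False).output_guidelines))  # 'You will return the answer as valid JSON.'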
7 changes: 4 additions & 3 deletions tests/unit/llm_test.py
@@ -1,3 +1,4 @@
import json
from autolabel.configs import AutolabelConfig
from autolabel.models.anthropic import AnthropicLLM
from autolabel.models.openai import OpenAILLM
@@ -211,7 +212,7 @@ def __init__(self, resp):
self.resp = resp

def json(self):
return {"body": self.resp}
return self.resp

def raise_for_status(self):
pass
@@ -222,7 +223,7 @@ def raise_for_status(self):
prompts = ["test1", "test2"]
mocker.patch(
"requests.post",
return_value=PostRequestMockResponse(resp='"Answers"'),
return_value=PostRequestMockResponse(resp='{"generated_text": "Answers"}'),
)
x = model.label(prompts)
assert [i[0].text for i in x.generations] == ["Answers", "Answers"]
@@ -242,7 +243,7 @@ def test_refuel_return_probs():
model = RefuelLLM(
config=AutolabelConfig(config="tests/assets/banking/config_banking_refuel.json")
)
assert model.returns_token_probs() is False
assert model.returns_token_probs() is True


################### REFUEL TESTS #######################
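The mock now returns the new response shape directly from .json(). A self-contained sketch of the same idea using unittest.mock rather than the pytest-mock fixture; the URL and model name are placeholders:

import json
import requests
from unittest.mock import patch


class FakeResponse:
    # mimics only what the Refuel client reads from requests.Response
    def __init__(self, body: str):
        self._body = body

    def json(self):
        return self._body  # a JSON-encoded string, matching the new endpoint

    def raise_for_status(self):
        pass


with patch("requests.post", return_value=FakeResponse('{"generated_text": "Answers"}')):
    resp = requests.post("https://llm.refuel.ai/models/refuel-llm/generate", json={})
    body = json.loads(resp.json())
    assert body["generated_text"] == "Answers"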
