Skip to content

Commit

Permalink
update openai wrapper to work with tiktoken interface and newest open…
Browse files Browse the repository at this point in the history
…ai version (#794)

* update openai wrapper to work with tiktoken interface

* update openai wrapper to work with tiktoken interface

* add deprecation note

* fix completion endpoint

* update to newest openai version

* monkey patch api key

* fix type

* fix issues

* fix issues

* edit

* fix typing

* openai

---------

Co-authored-by: Daniel King <[email protected]>
Co-authored-by: Max Marion <[email protected]>
  • Loading branch information
3 people committed Dec 14, 2023
1 parent 5388dc0 commit 5cc4dd4
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 182 deletions.
7 changes: 3 additions & 4 deletions llmfoundry/models/inference_api_wrapper/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):

def get_metrics(self, is_train: bool = False):
if is_train:
raise NotImplementedError(
'You cannot use inference wrappers for training')
metrics = None
else:
metrics = self.eval_metrics

Expand All @@ -55,6 +54,7 @@ def rebatch(self, batch: Batch):
return batch

def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
padding_tok = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
# If the batch mode is generate, we will generate a requested number of tokens using the underlying
# model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
# be returned from eval_forward
Expand All @@ -80,8 +80,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
[output_logits,
next_logit_tensor.reshape(1, -1)])
padding = torch.nn.functional.one_hot(
torch.full((seqlen - output_logits.shape[0],),
self.tokenizer.pad_token_id),
torch.full((seqlen - output_logits.shape[0],), padding_tok),
num_classes=self.tokenizer.vocab_size)
output_logits = torch.cat([output_logits, padding])
output_logits_batch.append(output_logits)
Expand Down
88 changes: 49 additions & 39 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import logging
import os
import random
from time import sleep
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
from composer.core.types import Batch
Expand All @@ -22,6 +23,9 @@
'OpenAICausalLMEvalWrapper',
'OpenAIChatAPIEvalWrapper',
]
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from openai.types.completion_choice import Logprobs

MAX_RETRIES = 10

Expand All @@ -30,20 +34,23 @@ class OpenAIEvalInterface(InferenceAPIEvalWrapper):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
assert os.getenv(
'OPENAI_API_KEY'
) is not None, 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.'
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e
openai.api_key = os.getenv('OPENAI_API_KEY')
self.client = openai.OpenAI()
self.model_name = model_cfg['version']

def generate_completion(self, prompt: str, num_tokens: int):
raise NotImplementedError()

def process_result(self, completion: Optional[dict]):
def process_result(self, completion): # pyright: ignore
raise NotImplementedError()

def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):
Expand All @@ -52,45 +59,49 @@ def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):

def try_generate_completion(self, prompt: str, num_tokens: int):
try:
from openai.error import RateLimitError
from openai import APITimeoutError, RateLimitError
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e
tries = 0
completion = None
delay = 1
while tries < MAX_RETRIES:
tries += 1
try:

completion = self.generate_completion(prompt, num_tokens)
break
except RateLimitError as e:
if 'You exceeded your current quota' in str(e._message):
if 'You exceeded your current quota' in str(
e._message): # pyright: ignore
raise e
sleep(60)
delay *= 2 * (1 + random.random())
sleep(delay)
continue
except Exception:
except APITimeoutError as e:
delay *= 2 * (1 + random.random())
sleep(delay)
continue

return completion


class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e

self.generate_completion = lambda prompt, num_tokens: openai.ChatCompletion.create(
self.model_name,
self.generate_completion = lambda prompt, num_tokens: self.client.chat.completions.create(
model=self.model_name,
messages=[{
'role':
'system',
'content':
model_cfg.get('sytsem_role_prompt',
'Please complete the following text: ')
}, {
'role': 'user',
'content': prompt
}],
Expand Down Expand Up @@ -162,6 +173,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
# than what the continuation would expect.
# Get around this issue by retokenizing the batch to remove spacing from the continuation as well as
# decoding the whole continuation at once.
padding_tok = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
output_logits_batch = []
batch = self.rebatch(batch)
for tokens, cont_idxs in zip(batch['input_ids'],
Expand All @@ -182,20 +194,21 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
if next_logit_tensor is not None:
output_logits = torch.cat([output_logits, next_logit_tensor])
padding = torch.nn.functional.one_hot(
torch.full((seqlen - output_logits.shape[0],),
self.tokenizer.pad_token_id),
torch.full((seqlen - output_logits.shape[0],), padding_tok),
num_classes=self.tokenizer.vocab_size)
output_logits = torch.cat([output_logits, padding])
output_logits_batch.append(output_logits)

return torch.stack(output_logits_batch).to(batch['input_ids'].device)

def process_result(self, completion: Optional[dict]):
assert isinstance(completion, dict)
if len(completion['choices']) > 0:
def process_result(self, completion: Optional[ChatCompletion]):
if completion is None:
raise ValueError("Couldn't generate model output")

if len(completion.choices) > 0:
tensors = []
for t in self.tokenizer(completion['choices'][0]['message']
['content'])['input_ids']:
for t in self.tokenizer(
completion.choices[0].message.content)['input_ids']:
tensors.append(
self.tokenizer.construct_logit_tensor(
{self.tokenizer.decode([t]): 0.0}))
Expand All @@ -213,29 +226,26 @@ class OpenAICausalLMEvalWrapper(OpenAIEvalInterface):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e

self.generate_completion = lambda prompt, num_tokens: openai.Completion.create(
engine=self.model_name,
# TODO: this will be deprecated
self.generate_completion = lambda prompt, num_tokens: self.client.completions.create(
model=self.model_name,
prompt=prompt,
max_tokens=1,
max_tokens=num_tokens,
logprobs=5,
temperature=0.0)

def process_result(self, completion: Optional[dict]):
def process_result(self, completion: Optional[Completion]):
if completion is None:
raise ValueError("Couldn't generate model output")

assert isinstance(completion, dict)
if len(completion['choices'][0]['logprobs']['top_logprobs']) > 0:
if TYPE_CHECKING:
assert isinstance(completion, Completion)
assert isinstance(completion.choices[0].logprobs, Logprobs)
assert isinstance(completion.choices[0].logprobs.top_logprobs, list)

if len(completion.choices[0].logprobs.top_logprobs[0]) > 0:
tensor = self.tokenizer.construct_logit_tensor(
dict(completion['choices'][0]['logprobs']['top_logprobs'][0]))
dict(completion.choices[0].logprobs.top_logprobs[0]))
return tensor
else:
# the model sometimes stops early even though we are still requesting tokens!
Expand Down
43 changes: 12 additions & 31 deletions mcli/mcli-openai-eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ command: |
# Mosaic Cloud will use run_name (with a unique suffix) to populate the env var $RUN_NAME
run_name: openai-eval
# gpu_num: #
# gpu_type: #
gpu_num: #
gpu_type: #
cluster: # replace with your cluster here!

image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest
Expand All @@ -25,41 +25,22 @@ parameters:
device_eval_batch_size: 4
models:
-
model_name: openai/davinci
model:
name: openai_causal_lm
version: davinci
tokenizer:
name: openai
kwargs:
name: davinci
-
model_name: openai/ada
model:
name: openai_causal_lm
version: ada
tokenizer:
name: openai
kwargs:
name: ada
-
model_name: openai/gpt-4
model_name: openai/gpt-3.5-turbo
model:
name: openai_chat
version: gpt-4
version: gpt-3.5-turbo
tokenizer:
name: openai
name: tiktoken
kwargs:
name: gpt-4
model_name: gpt-3.5-turbo
-
model_name: openai/gpt-3.5-turbo
model_name: openai/davinci
model:
name: openai_chat
version: gpt-3.5-turbo
name: openai_causal_lm
version: davinci
tokenizer:
name: openai
name: tiktoken
kwargs:
name: gpt-3.5-turbo
model_name: davinci

icl_tasks: 'eval/yamls/lm_tasks.yaml'
eval_gauntlet: 'eval/yamls/eval_gauntlet.yaml'
icl_tasks: 'eval/yamls/lm_tasks_v0.2.yaml'
Original file line number Diff line number Diff line change
@@ -1,31 +1,26 @@
icl_tasks:
-
label: jeopardy
dataset_uri: eval/local_data/world_knowledge/jeopardy_all.jsonl # ADD YOUR OWN DATASET URI
num_fewshot: [10]
dataset_uri: eval/local_data/world_knowledge/jeopardy_all.jsonl
num_fewshot: [3]
icl_task_type: language_modeling
continuation_delimiter: "\nAnswer: " # this separates questions from answers
has_categories: true
-
label: bigbench_qa_wikidata
dataset_uri: eval/local_data/world_knowledge/bigbench_qa_wikidata.jsonl # ADD YOUR OWN DATASET URI
num_fewshot: [10]
dataset_uri: eval/local_data/world_knowledge/bigbench_qa_wikidata.jsonl
num_fewshot: [3]
icl_task_type: language_modeling
-
label: lambada_openai
dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl
num_fewshot: [0]
label: bigbench_dyck_languages
dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_dyck_languages.jsonl
num_fewshot: [5]
icl_task_type: language_modeling
-
label: bigbench_conlang_translation
dataset_uri: eval/local_data/language_understanding/bigbench_conlang_translation.jsonl
label: lambada_openai
dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl
num_fewshot: [0]
icl_task_type: language_modeling
-
label: bigbench_dyck_languages
dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_dyck_languages.jsonl
num_fewshot: [10]
icl_task_type: language_modeling
-
label: bigbench_cs_algorithms
dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_cs_algorithms.jsonl
Expand All @@ -34,35 +29,30 @@ icl_tasks:
-
label: bigbench_operators
dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_operators.jsonl
num_fewshot: [10]
icl_task_type: language_modeling
-
label: bigbench_repeat_copy_logic
dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_repeat_copy_logic.jsonl
num_fewshot: [10]
num_fewshot: [3]
icl_task_type: language_modeling
-
label: simple_arithmetic_nospaces
dataset_uri: eval/local_data/symbolic_problem_solving/simple_arithmetic_nospaces.jsonl
num_fewshot: [10]
num_fewshot: [5]
icl_task_type: language_modeling
-
label: simple_arithmetic_withspaces
dataset_uri: eval/local_data/symbolic_problem_solving/simple_arithmetic_withspaces.jsonl
num_fewshot: [10]
num_fewshot: [5]
icl_task_type: language_modeling
-
label: pubmed_qa_labeled
dataset_uri: eval/local_data/reading_comprehension/pubmed_qa_labeled.jsonl # ADD YOUR OWN DATASET URI
dataset_uri: eval/local_data/reading_comprehension/pubmed_qa_labeled.jsonl
num_fewshot: [10]
icl_task_type: language_modeling
-
label: squad
dataset_uri: eval/local_data/reading_comprehension/squad.jsonl # ADD YOUR OWN DATASET URI
num_fewshot: [10]
dataset_uri: eval/local_data/reading_comprehension/squad.jsonl
num_fewshot: [3]
icl_task_type: language_modeling
-
label: coqa
dataset_uri: eval/local_data/reading_comprehension/coqa.jsonl # ADD YOUR OWN DATASET URI
dataset_uri: eval/local_data/reading_comprehension/coqa.jsonl
num_fewshot: [0]
icl_task_type: language_modeling
Loading

0 comments on commit 5cc4dd4

Please sign in to comment.