Skip to content

Commit

Permalink
fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
bmosaicml committed Dec 13, 2023
1 parent 4beaa05 commit 5824bc5
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
16 changes: 10 additions & 6 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import random
from time import sleep
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
from composer.core.types import Batch
Expand Down Expand Up @@ -59,7 +59,7 @@ def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):

def try_generate_completion(self, prompt: str, num_tokens: int):
try:
from openai import RateLimitError, APITimeoutError
from openai import APITimeoutError, RateLimitError
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
Expand All @@ -84,7 +84,7 @@ def try_generate_completion(self, prompt: str, num_tokens: int):
delay *= 2 * (1 + random.random())
sleep(delay)
continue

return completion


Expand Down Expand Up @@ -203,6 +203,9 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
return torch.stack(output_logits_batch).to(batch['input_ids'].device)

def process_result(self, completion: Optional[ChatCompletion]):
if completion is None:
raise ValueError("Couldn't generate model output")

if len(completion.choices) > 0:
tensors = []
for t in self.tokenizer(
Expand Down Expand Up @@ -236,9 +239,10 @@ def process_result(self, completion: Optional[Completion]):
if completion is None:
raise ValueError("Couldn't generate model output")

assert isinstance(completion, Completion)
assert isinstance(completion.choices[0].logprobs, Logprobs)
assert isinstance(completion.choices[0].logprobs.top_logprobs, list)
if TYPE_CHECKING:
assert isinstance(completion, Completion)
assert isinstance(completion.choices[0].logprobs, Logprobs)
assert isinstance(completion.choices[0].logprobs.top_logprobs, list)

if len(completion.choices[0].logprobs.top_logprobs[0]) > 0:
tensor = self.tokenizer.construct_logit_tensor(
Expand Down
2 changes: 1 addition & 1 deletion scripts/eval/yamls/lm_tasks_v0.2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ icl_tasks:
label: coqa
dataset_uri: eval/local_data/reading_comprehension/coqa.jsonl
num_fewshot: [0]
icl_task_type: language_modeling
icl_task_type: language_modeling
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,19 @@ def load_icl_config():
class MockTopLogProb:
    """Minimal stand-in for an OpenAI ``TopLogprobs``-style payload.

    Exposes a single ``top_logprobs`` attribute: a one-element list whose
    dict maps the expected token to a log-probability of 0 (i.e. the mock
    model is certain about ``expected_token``).
    """

    def __init__(self, expected_token: str) -> None:
        # Direct attribute assignment — the original pair of lines set the
        # same attribute twice (setattr with a constant name, then a plain
        # assignment); one plain assignment is equivalent and clearer.
        self.top_logprobs = [{expected_token: 0}]


class MockLogprob:
    """Minimal stand-in for an OpenAI completion choice.

    Exposes a ``logprobs`` attribute holding a :class:`MockTopLogProb`
    built for ``expected_token``, mirroring the shape the inference-API
    wrapper reads via ``completion.choices[0].logprobs``.
    """

    def __init__(self, expected_token: str) -> None:
        # Direct attribute assignment — the original pair of lines set the
        # same attribute twice (setattr, then a plain assignment); keep the
        # single plain assignment.
        self.logprobs = MockTopLogProb(expected_token)


class MockCompletion:
    """Minimal stand-in for an OpenAI ``Completion`` response object.

    Exposes a ``choices`` attribute: a one-element list containing a
    :class:`MockLogprob` for ``expected_token``, matching the access
    pattern ``completion.choices[0].logprobs.top_logprobs[0]`` used by
    the code under test.
    """

    def __init__(self, expected_token: str) -> None:
        # Direct attribute assignment — the original pair of lines set the
        # same attribute twice (setattr, then a plain assignment); keep the
        # single plain assignment.
        self.choices = [MockLogprob(expected_token)]


class MockContent:
Expand Down

0 comments on commit 5824bc5

Please sign in to comment.