Skip to content

Commit

Permalink
Fixed hf unit test; removed pop attributes in OpenAi completion.
Browse files Browse the repository at this point in the history
  • Loading branch information
Am1n3e committed Feb 22, 2024
1 parent 5c4e0aa commit c682de1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
7 changes: 3 additions & 4 deletions lm_eval/models/openai_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,13 @@ def sameuntil_chunks(xs, size):
list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size))
):
inps = []
self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)
self._max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks)
for context, _ in chunk:
context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)

until = request_args.pop("until", ["<|endoftext|>"])
request_args.pop("do_sample", None)
until = request_args.get("until", ["<|endoftext|>"])
request_args["temperature"] = request_args.get("temperature", 0)

response = oa_completion(
Expand All @@ -278,7 +277,7 @@ def sameuntil_chunks(xs, size):
max_tokens=self.max_gen_toks,
stop=until,
seed=self.seed,
**request_args,
**{k: v for k, v in request_args.items() if k not in ["do_sample", "max_gen_toks"]},
)
for resp, (context, args_) in zip(response.choices, chunk):
s = getattr(resp, "text")
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class Test_HFLM:
multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
MULTIPLE_CH: list[Instance] = multiple_choice_task.instances
generate_until_task = task_list["gsm8k"] # type: ignore
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until: list[Instance] = generate_until_task.instances
rolling_task = task_list["wikitext"] # type: ignore
rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
Expand Down

0 comments on commit c682de1

Please sign in to comment.