Merge branch 'main' into prefill
baberabb committed Sep 16, 2024
2 parents dac8b53 + fb963f0 commit 4eecbab
Showing 466 changed files with 10,165 additions and 373 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -6,6 +6,7 @@

*Latest News 📣*

- [2024/09] We are prototyping support in LM Evaluation Harness for tasks with text+image multimodal inputs and text outputs, and have just added the `hf-multimodal` and `vllm-vlm` model types and the `mmmu` task as a prototype feature. We welcome users to try out this in-progress feature and stress-test it, and suggest they check out [`lmms-eval`](https://github.com/EvolvingLMMs-Lab/lmms-eval), a wonderful project originally forked from lm-evaluation-harness, for a broader range of multimodal tasks, models, and features.
- [2024/07] [API model](docs/API_guide.md) support has been updated and refactored, introducing support for batched and async requests, and making it significantly easier to customize and use for your own purposes. **To run Llama 405B, we recommend using vLLM's OpenAI-compatible API to host the model, and using the `local-completions` model type to evaluate it (see the example invocation after this diff).**
- [2024/07] New Open LLM Leaderboard tasks have been added! You can find them under the [leaderboard](lm_eval/tasks/leaderboard/README.md) task group.

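A minimal sketch of the recommended 405B setup from the news item above, using the Python entry point — the checkpoint name, endpoint URL, concurrency setting, and task choice are illustrative assumptions, not values prescribed by this commit:

```python
# Hedged sketch: evaluate a model served by vLLM's OpenAI-compatible endpoint
# through the `local-completions` model type. base_url, model, num_concurrent,
# and the task name are illustrative assumptions.
import lm_eval

results = lm_eval.simple_evaluate(
    model="local-completions",
    model_args=(
        "model=meta-llama/Meta-Llama-3.1-405B-Instruct,"
        "base_url=http://localhost:8000/v1/completions,"
        "num_concurrent=8"
    ),
    tasks=["leaderboard_mmlu_pro"],
)
print(results["results"])
```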
7 changes: 0 additions & 7 deletions lm_eval/__main__.py
@@ -299,13 +299,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
)

-if (
-args.num_fewshot is None or args.num_fewshot == 0
-) and args.fewshot_as_multiturn:
-raise ValueError(
-"If fewshot_as_multiturn is set, num_fewshot must be greater than 0."
-)

if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
114 changes: 105 additions & 9 deletions lm_eval/api/model.py
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import os
-from typing import Dict, List, Optional, Tuple, Type, TypeVar
+from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union

import transformers
from sqlitedict import SqliteDict
@@ -192,15 +192,13 @@ def tokenizer_name(self) -> str:
"To use this model with chat templates, please implement the 'tokenizer_name' property."
)

-@property
-def chat_template(self) -> str:
-"""Must be defined for LM subclasses that implement Chat Templating.
-Should return the structure of the chat template applied to user/assistant messages.
-This is used only to save in the experiment results for reproducibility.
+def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+"""Returns the chat template structure for user/assistant messages if a template is provided.
+This method is intended to be overridden in a subclass to define a specific chat template format.
+For models that do not support chat templates, this method returns None by default.
"""
-raise NotImplementedError(
-"To use this model with chat templates, please implement the 'chat_template' property."
-)
+return ""

def set_cache_hook(self, cache_hook) -> None:
self.cache_hook = cache_hook
@@ -316,6 +314,8 @@ class TemplateLM(LM):
and boilerplate often included in other LM subclasses.
"""

tokenizer = None

@property
@abc.abstractmethod
def eot_token_id(self):
@@ -386,3 +386,99 @@ def loglikelihood_rolling(
@abc.abstractmethod
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
pass

def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
"""
Set and get the appropriate chat template for the model.
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
The template selection logic is adapted from the Transformers library's `apply_chat_template`
method in the Tokenizer class. The original implementation can be found at:
https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687
This method ensures that the right template is chosen based on the following:
0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
1. If the model's tokenizer has multiple templates:
a. Use the specified template if it exists in the dictionary.
b. Use the default template from the list if no specific template is provided.
c. Raise an error if no default template exists and no specific template is provided.
2. If the model's tokenizer has a single template or no template:
a. Use the tokenizer's chat template if available.
b. Fall back to the default chat template if no tokenizer chat template exists.
Args:
chat_template (Union[bool, str]): Specifies the chat template to use.
- If False or None, no template is applied.
- If True, the default or only available template is used.
- If a string, the template with the matching name is used.
Returns:
Optional[str]: The selected chat template, or None if no template is applied.
"""
if self.tokenizer is None:
return ""

if chat_template is False or chat_template is None:
eval_logger.warning(
"model.chat_template was called with the chat_template set to False or None. "
"Therefore no chat template will be applied. Make sure this is an intended behavior."
)
return None

# Convert boolean chat_template to None to ensure compatibility with the adapted logic
if isinstance(chat_template, bool):
chat_template = None
using_default_template = False

# First, handle the cases when the model has a dict of multiple templates
template = self.tokenizer.chat_template or self.tokenizer.default_chat_template

if isinstance(template, dict):
using_default_dict = self.tokenizer.chat_template is None

if chat_template is not None:
if chat_template in template:
selected_template = template[chat_template]
if using_default_dict:
using_default_template = True
else:
raise ValueError(
f"The specified chat template '{chat_template}' is not available. "
f"Available template names are {sorted(template.keys())}."
)
else:
# If user didn't pass a chat template, use the default template from the dict
if "default" in template:
selected_template = template["default"]
using_default_template = True
else:
raise ValueError(
"This model has multiple chat templates with no default specified! Please either pass a chat "
"template or the name of the template you wish to use to the `chat_template` argument. Available "
f"template names are {sorted(template.keys())}."
)

# Cases when the model has a single template or no template
else:
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
if isinstance(chat_template, str):
eval_logger.warning(
"Chat template name provided, but the tokenizer's chat template is not a dictionary. "
"Using the tokenizer's chat template or the default template instead."
)
if self.tokenizer.chat_template is not None:
selected_template = self.tokenizer.chat_template
else:
selected_template = self.tokenizer.default_chat_template
using_default_template = True

if using_default_template:
eval_logger.warning(
"No chat template is set for this tokenizer, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)

return selected_template
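To make the selection rules implemented above concrete, here is a toy condensation of the priority order — `templates` stands in for a tokenizer whose `chat_template` is a dict of named templates, and the template strings are placeholders, not real Jinja:

```python
from typing import Optional, Union

# Stand-in for self.tokenizer.chat_template when it is a dict of named templates.
templates = {"default": "<default jinja>", "tool_use": "<tool-use jinja>"}

def select(requested: Union[bool, str, None]) -> Optional[str]:
    if requested is False or requested is None:
        return None                  # no chat template applied
    name = None if requested is True else requested
    if name is None:
        return templates["default"]  # True -> fall back to the default entry
    if name in templates:
        return templates[name]       # a named template, if it exists
    raise ValueError(f"Template {name!r} not found; available: {sorted(templates)}")

assert select(True) == "<default jinja>"
assert select("tool_use") == "<tool-use jinja>"
assert select(False) is None
```

The single-template and no-tokenizer branches collapse to returning `tokenizer.chat_template` (or the legacy class default) and the empty string, respectively, as in the method above.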
99 changes: 73 additions & 26 deletions lm_eval/api/task.py
@@ -75,6 +75,7 @@ class TaskConfig(dict):
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Optional[Union[Callable, str]] = None
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
@@ -378,6 +379,10 @@ def doc_to_text(self, doc):
def doc_to_target(self, doc):
pass

# not an abstractmethod because not every language-only task has to implement this
def doc_to_image(self, doc):
raise NotImplementedError

def build_all_requests(
self,
*,
@@ -736,6 +741,10 @@ def __init__(
)
self.OUTPUT_TYPE = self.config.output_type

if self.config.doc_to_image is not None:
# mark the task as requiring multimodality.
self.MULTIMODAL = True

if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path

@@ -1049,8 +1058,8 @@ def fewshot_context(
Whether to apply the chat template to the fewshot context.
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
-:param chat_template: Callable
-Chat template to be applied to the fewshot context.
+:param chat_template:
+callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
:returns: str
The fewshot context.
"""
@@ -1303,9 +1312,34 @@ def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
else:
raise TypeError

def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
doc_to_image = self.config.doc_to_image
else:
return None

if isinstance(doc_to_image, list):
image_feature = [
self.doc_to_image(doc, feature) for feature in doc_to_image
]
return [feature for feature in image_feature if feature is not None]
elif isinstance(doc_to_image, str):
if doc_to_image in self.features:
return doc[doc_to_image]
else:
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
elif callable(doc_to_image):
return doc_to_image(doc)
else:
return None

def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
aux_arguments = None

if self.OUTPUT_TYPE == "loglikelihood":
arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
@@ -1323,6 +1357,37 @@ def construct_requests(
# Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

# TODO: we should raise a warning telling users this can increase runtime by up to ~2x.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.

# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
aux_arguments = [("", f"{choice}") for choice in choices]

arguments.extend(aux_arguments)

elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))

multimodal_arg = {}
if (
self.config.doc_to_image
): # TODO: ensure that non-multimodal tasks aren't getting visual args
multimodal_arg = {
**multimodal_arg,
**{"visual": self.doc_to_image(doc)},
}

if bool(multimodal_arg):
if isinstance(arguments, list):
arguments = [arg + (multimodal_arg,) for arg in arguments]
else:
arguments = arguments + (multimodal_arg,)

if self.OUTPUT_TYPE == "multiple_choice":
request_list = [
Instance(
request_type="loglikelihood",
@@ -1333,33 +1398,15 @@ def construct_requests(
)
for i, arg in enumerate(arguments)
]
-# TODO: we should raise a warning telling users this will at most ~2x runtime.
-if "acc_mutual_info" in self._metric_fn_list.keys():
-# if we are calculating multiple choice accuracy
-# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
-
-# here mutual info refers to calculating
-# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
-# in other words normalizing by subtracting the unconditional logprob of each choice.
-request_list.extend(
-[
-Instance(
-request_type="loglikelihood",
-doc=doc,
-arguments=("", "{}".format(choice)),
-idx=i,
-**kwargs,
-)
-for i, choice in enumerate(choices)
-]
-)
return request_list

-elif self.OUTPUT_TYPE == "generate_until":
-arguments = (ctx, deepcopy(self.config.generation_kwargs))

return Instance(
-request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
+request_type=self.OUTPUT_TYPE,
+doc=doc,
+arguments=arguments,
+idx=0,
+**kwargs,
)

def process_results(self, doc, results):
@@ -1571,7 +1618,7 @@ def __repr__(self):
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
f"output_type={self.OUTPUT_TYPE},"
f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
f"num_samples={len(self.eval_docs)})"
f"num_samples={len(self.eval_docs)})",
)


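The `doc_to_image` dispatch added above accepts a dataset feature name, a template string, a list of either, or a callable. Here is a toy re-implementation under invented inputs — `features` and the sample `doc` are made up, and `.format()` stands in for the Jinja templating the real method performs via `utils.apply_template`:

```python
import ast
from typing import Any, Optional, Union

def resolve_image(doc: dict, spec: Any, features=("image", "question")) -> Optional[Union[int, str, list]]:
    """Toy re-implementation of ConfigurableTask.doc_to_image's dispatch."""
    if spec is None:
        return None
    if isinstance(spec, list):            # list: resolve each entry, drop Nones
        resolved = [resolve_image(doc, s, features) for s in spec]
        return [r for r in resolved if r is not None]
    if isinstance(spec, str):
        if spec in features:              # feature name -> return the raw column
            return doc[spec]
        return ast.literal_eval(spec.format(**doc))  # stand-in for Jinja templating
    if callable(spec):                    # callable -> apply it to the doc
        return spec(doc)
    return None

doc = {"image": "img_0.png", "question": "What is shown?"}
assert resolve_image(doc, "image") == "img_0.png"
assert resolve_image(doc, lambda d: [d["image"]]) == ["img_0.png"]
assert resolve_image(doc, ["image", "image"]) == ["img_0.png", "img_0.png"]
```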
28 changes: 21 additions & 7 deletions lm_eval/evaluator.py
@@ -289,18 +289,12 @@ def _adjust_config(task_dict):
if check_integrity:
run_task_tests(task_list=tasks)

-# hotfix: delete when chat_template fixed
-try:
-chat = lm.chat_template(apply_chat_template)
-except: # noqa: E722
-chat = None
-
if evaluation_tracker is not None:
evaluation_tracker.general_config_tracker.log_experiment_args(
model_source=model,
model_args=model_args,
system_instruction=system_instruction,
-chat_template=chat,
+chat_template=lm.chat_template(apply_chat_template),
fewshot_as_multiturn=fewshot_as_multiturn,
)

@@ -420,8 +414,28 @@ def evaluate(
for task_output in eval_tasks
):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")

# validation check: are we pairing a multimodal task with a non-multimodal model class, or vice versa?
incompatible_tasks = []
for task_output in eval_tasks:
task: Task = task_output.task

if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
incompatible_tasks.append(task_output.task_name)
if len(incompatible_tasks) > 0:
if not getattr(lm, "MULTIMODAL", False):
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
)
else:
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
)
# end multimodality validation check

for task_output in eval_tasks:
task: Task = task_output.task

limit = get_sample_size(task, limit)
task.build_all_requests(
limit=limit,
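The compatibility gate above reduces to comparing class-level `MULTIMODAL` flags on the model and each task; a minimal sketch of that mechanic with invented stand-in classes:

```python
# TextOnlyLM and MMMUTask are invented stand-ins, not classes from this commit.
class TextOnlyLM:
    pass  # no MULTIMODAL attribute -> getattr(..., False) treats it as text-only

class MMMUTask:
    MULTIMODAL = True
    task_name = "mmmu"

lm, task = TextOnlyLM(), MMMUTask()
incompatible = []
if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
    incompatible.append(task.task_name)
print(incompatible)  # -> ['mmmu'], which would trigger the first ValueError above
```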
2 changes: 2 additions & 0 deletions lm_eval/models/__init__.py
@@ -3,6 +3,7 @@
api_models,
dummy,
gguf,
hf_vlms,
huggingface,
mamba_lm,
nemo_lm,
@@ -12,6 +13,7 @@
optimum_lm,
textsynth,
vllm_causallms,
vllm_vlms,
)


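With the two new modules registered above, the prototype multimodal path can be exercised end to end. A hedged sketch — the checkpoint is an illustrative assumption, and the task name follows the README announcement in this same commit:

```python
import lm_eval

# Hedged sketch: the `hf-multimodal` model type registered by hf_vlms above.
# The checkpoint is an illustrative assumption; exact task names may differ.
results = lm_eval.simple_evaluate(
    model="hf-multimodal",
    model_args="pretrained=llava-hf/llava-1.5-7b-hf",
    tasks=["mmmu"],
)
print(results["results"])
```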
