Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multimodal prototyping #2243

Merged
merged 69 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
5b62a52
add WIP hf vlm class
haileyschoelkopf Jul 2, 2024
34a079e
add doc_to_image
lintangsutawika Jul 2, 2024
8bce8cf
add mmmu tasks
lintangsutawika Jul 2, 2024
6cc6e9c
Merge branch 'hailey-multimodal-prototyping' into multimodal-prototyping
haileyschoelkopf Jul 9, 2024
e4db76c
Merge branch 'main' into multimodal-prototyping
haileyschoelkopf Jul 9, 2024
1c94a54
fix merge conflicts
haileyschoelkopf Jul 9, 2024
9692aa0
add lintang's changes to hf_vlms.py
haileyschoelkopf Jul 9, 2024
90ba03a
fix doc_to_image
lintangsutawika Jul 10, 2024
aa6c50e
added yaml_path for config-loading
lintangsutawika Jul 10, 2024
7c76574
revert
lintangsutawika Jul 10, 2024
1b9deaa
add line to process str type v
lintangsutawika Jul 10, 2024
8db0a47
update
lintangsutawika Jul 11, 2024
9b9ca7b
modeling cleanup
haileyschoelkopf Jul 11, 2024
df7fee6
merge with lintang-multimodal-prototyping
haileyschoelkopf Aug 2, 2024
8d92a68
add aggregation for mmmu
haileyschoelkopf Aug 2, 2024
f410d35
rewrite MMMU processing code based on only MMMU authors' repo (doc_to…
haileyschoelkopf Aug 7, 2024
ebf54d8
implemented doc_to_image
lintangsutawika Aug 8, 2024
941b502
update doc_to_image to accept list of features
lintangsutawika Aug 8, 2024
8e4c1d6
update functions
lintangsutawika Aug 8, 2024
63bcbc5
readd image processed
lintangsutawika Aug 10, 2024
15dda35
update args process
lintangsutawika Aug 11, 2024
d811a3a
bugfix for repeated images fed to model
haileyschoelkopf Aug 23, 2024
2242ed3
push WIP loglikelihood code
haileyschoelkopf Sep 3, 2024
be14ac1
commit most recent code (generative ; qwen2-vl testing)
haileyschoelkopf Sep 9, 2024
7516b88
preliminary image_token_id handling
haileyschoelkopf Sep 9, 2024
5a65d10
small mmmu update: some qs have >4 mcqa options
haileyschoelkopf Sep 9, 2024
54d317d
push updated modeling code
haileyschoelkopf Sep 10, 2024
9789f83
merge with most recent main
haileyschoelkopf Sep 10, 2024
40a48c2
use processor.apply_chat_template
baberabb Sep 11, 2024
8148fe4
add mathvista draft
baberabb Sep 12, 2024
38b6fe3
nit
baberabb Sep 12, 2024
d348106
nit
baberabb Sep 12, 2024
295a825
ensure no footguns in text<>multimodal LM<>task incompatibility
haileyschoelkopf Sep 12, 2024
2c9fd79
add notification to readme regarding launch of prototype!
haileyschoelkopf Sep 12, 2024
80e1711
fix compatibility check
haileyschoelkopf Sep 12, 2024
d84b9fc
reorganize mmmu configs
haileyschoelkopf Sep 12, 2024
207778f
chat_template=None
baberabb Sep 12, 2024
45d0f3c
add interleave chat_template
baberabb Sep 12, 2024
5d55e6c
add condition
baberabb Sep 12, 2024
af8382b
add max_images; interleave=true
baberabb Sep 13, 2024
f1185a2
nit
baberabb Sep 13, 2024
c67d810
testmini_mcq
baberabb Sep 13, 2024
5848698
nit
baberabb Sep 13, 2024
294dc01
pass image string; convert img
baberabb Sep 13, 2024
b45d295
add vllm
baberabb Sep 13, 2024
db554e0
add init
baberabb Sep 13, 2024
90b4601
vlm add multi attr
baberabb Sep 13, 2024
94731a7
fixup
baberabb Sep 13, 2024
d5cbc48
pass max images to vllm model init
baberabb Sep 13, 2024
3b4ce83
nit
baberabb Sep 13, 2024
357cf64
encoding to device
baberabb Sep 13, 2024
f9cf90e
fix HFMultimodalLM.chat_template ?
haileyschoelkopf Sep 13, 2024
f04be6c
Merge branch 'multimodal-prototyping' of https://github.com/EleutherA…
haileyschoelkopf Sep 13, 2024
8349e12
add mmmu readme
baberabb Sep 13, 2024
805a115
Merge branch 'main' into multimodal-prototyping
baberabb Sep 13, 2024
05f0dd6
remove erroneous prints
haileyschoelkopf Sep 13, 2024
4623768
Merge branch 'multimodal-prototyping' of https://github.com/EleutherA…
haileyschoelkopf Sep 13, 2024
9ddb2ec
use HFMultimodalLM.chat_template ; restore tasks/__init__.py
haileyschoelkopf Sep 13, 2024
d1aadbc
add docstring for replace_placeholders in utils
haileyschoelkopf Sep 13, 2024
19d6874
fix `replace_placeholders`; set image_string=None
baberabb Sep 13, 2024
5665bc6
fix typo
haileyschoelkopf Sep 13, 2024
31e5ab0
cleanup + fix merge conflicts
haileyschoelkopf Sep 13, 2024
2d0c1b0
Merge branch 'multimodal-prototyping' of https://github.com/EleutherA…
haileyschoelkopf Sep 13, 2024
c0b585d
update MMMU readme
haileyschoelkopf Sep 13, 2024
5c0bd54
del mathvista
baberabb Sep 13, 2024
a3bb2f1
add some sample scores
haileyschoelkopf Sep 13, 2024
5f76efd
Update README.md
haileyschoelkopf Sep 13, 2024
b3e87ae
add log msg for image_string value
haileyschoelkopf Sep 13, 2024
d85c3b6
Merge branch 'multimodal-prototyping' of https://github.com/EleutherA…
haileyschoelkopf Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

*Latest News 📣*

- [2024/09] We are prototyping allowing users of LM Evaluation Harness to create and evaluate on text+image multimodal input, text output tasks, and have just added the `hf-multimodal` and `vllm-vlm` model types and `mmmu` task as a prototype feature. We welcome users to try out this in-progress feature and stress-test it for themselves, and suggest they check out [`lmms-eval`](https://github.com/EvolvingLMMs-Lab/lmms-eval), a wonderful project originally forking off of the lm-evaluation-harness, for a broader range of multimodal tasks, models, and features.
- [2024/07] [API model](docs/API_guide.md) support has been updated and refactored, introducing support for batched and async requests, and making it significantly easier to customize and use for your own purposes. **To run Llama 405B, we recommend using VLLM's OpenAI-compliant API to host the model, and use the `local-completions` model type to evaluate the model.**
- [2024/07] New Open LLM Leaderboard tasks have been added ! You can find them under the [leaderboard](lm_eval/tasks/leaderboard/README.md) task group.

Expand Down
99 changes: 73 additions & 26 deletions lm_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class TaskConfig(dict):
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str] = None
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
Expand Down Expand Up @@ -377,6 +378,10 @@ def doc_to_text(self, doc):
def doc_to_target(self, doc):
pass

# not an abstractmethod because not every language-only task has to implement this
def doc_to_image(self, doc):
raise NotImplementedError

def build_all_requests(
self,
*,
Expand Down Expand Up @@ -735,6 +740,10 @@ def __init__(
)
self.OUTPUT_TYPE = self.config.output_type

if self.config.doc_to_image is not None:
# mark the task as requiring multimodality.
self.MULTIMODAL = True

if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path

Expand Down Expand Up @@ -1042,8 +1051,8 @@ def fewshot_context(
Whether to apply the chat template to the fewshot context.
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param chat_template: Callable
Chat template to be applied to the fewshot context.
:param chat_template:
callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
:returns: str
The fewshot context.
"""
Expand Down Expand Up @@ -1279,9 +1288,34 @@ def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
else:
raise TypeError

def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
doc_to_image = self.config.doc_to_image
else:
return None

if isinstance(doc_to_image, list):
image_feature = [
self.doc_to_image(doc, feature) for feature in doc_to_image
]
return [feature for feature in image_feature if feature is not None]
elif isinstance(doc_to_image, str):
if doc_to_image in self.features:
return doc[doc_to_image]
else:
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
elif callable(doc_to_image):
return doc_to_image(doc)
else:
return None

def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
aux_arguments = None

if self.OUTPUT_TYPE == "loglikelihood":
arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
Expand All @@ -1299,6 +1333,37 @@ def construct_requests(
# Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.

# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
aux_arguments = [("", f"{choice}") for choice in choices]

arguments.extend(aux_arguments)

elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))

multimodal_arg = {}
if (
self.config.doc_to_image
): # TODO: ensure that non-multimodal tasks aren't getting visual args
multimodal_arg = {
**multimodal_arg,
**{"visual": self.doc_to_image(doc)},
}

if bool(multimodal_arg):
if isinstance(arguments, list):
arguments = [arg + (multimodal_arg,) for arg in arguments]
else:
arguments = arguments + (multimodal_arg,)

if self.OUTPUT_TYPE == "multiple_choice":
request_list = [
Instance(
request_type="loglikelihood",
Expand All @@ -1309,33 +1374,15 @@ def construct_requests(
)
for i, arg in enumerate(arguments)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.

# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend(
[
Instance(
request_type="loglikelihood",
doc=doc,
arguments=("", "{}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(choices)
]
)
return request_list

elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))

return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=arguments,
idx=0,
**kwargs,
)

def process_results(self, doc, results):
Expand Down Expand Up @@ -1547,7 +1594,7 @@ def __repr__(self):
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
f"output_type={self.OUTPUT_TYPE},"
f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
f"num_samples={len(self.eval_docs)})"
f"num_samples={len(self.eval_docs)})",
)


Expand Down
20 changes: 20 additions & 0 deletions lm_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,28 @@ def evaluate(
for task_output in eval_tasks
):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")

# validation check: are we running multimodal task <-> non-multimodal model class, or vice-versa.
incompatible_tasks = []
for task_output in eval_tasks:
task: Task = task_output.task

if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
incompatible_tasks.append(task_output.task_name)
if len(incompatible_tasks) > 0:
if not getattr(lm, "MULTIMODAL", False):
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
)
else:
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
)
# end multimodality validation check

for task_output in eval_tasks:
task: Task = task_output.task

limit = get_sample_size(task, limit)
task.build_all_requests(
limit=limit,
Expand Down
2 changes: 2 additions & 0 deletions lm_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
api_models,
dummy,
gguf,
hf_vlms,
huggingface,
mamba_lm,
nemo_lm,
Expand All @@ -12,6 +13,7 @@
optimum_lm,
textsynth,
vllm_causallms,
vllm_vlms,
)


Expand Down
Loading
Loading