Skip to content

Commit

Permalink
Merge pull request #5 from ashwin-balakrishna96/ashwin/add_AI2D
Browse files Browse the repository at this point in the history
Add in AI2D Eval
  • Loading branch information
siddk authored Mar 27, 2024
2 parents 098224f + 11a1ba3 commit 118921e
Show file tree
Hide file tree
Showing 8 changed files with 440 additions and 11 deletions.
33 changes: 32 additions & 1 deletion vlm_eval/conf/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class VizWizFullDatasetConfig(DatasetConfig):
root_dir: Path = Path("../../datasets/vlm-evaluation")
index_file: Path = Path("datasets/vizwiz/metadata.json")
annotations_file: Path = Path("datasets/vizwiz/annotations-vizwiz-full.json")
questions_file: Path = Path("datasets/vqa-v2/questions-vizwiz-full.json")
questions_file: Path = Path("datasets/vizwiz/questions-vizwiz-full.json")


@dataclass
Expand Down Expand Up @@ -338,6 +338,33 @@ class PopeSlimDatasetConfig(DatasetConfig):
expected_examples: int = 3072


# === AI2D Datasets =>> Note: "Slim" defaults to k = 1024 examples ===
@dataclass
class AI2DFullDatasetConfig(DatasetConfig):
    """Dataset config for the full AI2D (diagram multiple-choice QA) evaluation split."""

    dataset_family: str = "ai2d"      # Family key used for harness / prompt-fn dispatch
    dataset_id: str = "ai2d-full"     # Unique identifier for this dataset variant
    split: str = "eval"

    # Expected number of examples in the full eval split
    # (presumably validated when the index is built — TODO confirm against harness).
    expected_examples: int = 15501

    root_dir: Path = Path("../../datasets/vlm-evaluation")
    index_file: Path = Path("datasets/ai2d/metadata.json")
    annotations_file: Path = Path("datasets/ai2d/metadata-full.json")


@dataclass
class AI2DSlimDatasetConfig(DatasetConfig):
    """Dataset config for the subsampled ("slim") AI2D evaluation split."""

    dataset_family: str = "ai2d"      # Family key used for harness / prompt-fn dispatch
    dataset_id: str = "ai2d-slim"     # Unique identifier for this dataset variant
    split: str = "eval"

    # NOTE(review): the section comment says "Slim" defaults to k = 1024 and the
    # metadata file below is `metadata-slim-1024.json`, yet 2048 examples are
    # expected — confirm whether 2048 is intentional (e.g., multiple questions
    # per sampled image) or a typo for 1024.
    expected_examples: int = 2048

    root_dir: Path = Path("../../datasets/vlm-evaluation")
    # NOTE(review): index and annotations point at the same file here, unlike the
    # full config (metadata.json vs. metadata-full.json) — verify this is intended.
    index_file: Path = Path("datasets/ai2d/metadata-slim-1024.json")
    annotations_file: Path = Path("datasets/ai2d/metadata-slim-1024.json")



# === Define a Dataset Registry Enum for Reference / Validation =>> all *new* datasets must be added here! ===
@unique
Expand Down Expand Up @@ -384,6 +411,10 @@ class DatasetRegistry(Enum):
TALLYQA_SUBSAMPLED = TallyQASubsampledDatasetConfig
TALLYQA_SLIM = TallyQASlimDatasetConfig

# AI2D
AI2D_FULL = AI2DFullDatasetConfig
AI2D_SLIM = AI2DSlimDatasetConfig

@property
def dataset_id(self) -> str:
    """Return the `dataset_id` declared on this registry member's config class."""
    # `self.value` is the dataset config class bound to the enum member
    # (e.g., AI2D_FULL = AI2DFullDatasetConfig), so this reads its field default.
    return self.value.dataset_id
Expand Down
6 changes: 4 additions & 2 deletions vlm_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def get_prompt_fn(self, dataset_family: str = "vqa-v2") -> Callable[[str], str]:
bbox_refer_prompt_fn = self.get_bbox_refer_chat_prompt_fn()
text_vqa_prompt_fn = self.get_vqa_chat_prompt_fn(uncertainty_aware=False)
captioning_prompt_fn = self.get_captioning_prompt_fn()
tally_qa_prompt_fn = self.get_mc_prompt_fn(choices=[str(i) for i in range(16)])
tally_qa_prompt_fn = self.get_mc_prompt_fn()
ai2d_prompt_fn = self.get_mc_prompt_fn()

return {
"vqa-v2": vqa_prompt_fn,
Expand All @@ -191,6 +192,7 @@ def get_prompt_fn(self, dataset_family: str = "vqa-v2") -> Callable[[str], str]:
"tally-qa": tally_qa_prompt_fn,
"refcoco": bbox_refer_prompt_fn,
"ocid-ref": bbox_refer_prompt_fn,
"ai2d": ai2d_prompt_fn,
# Generic for GUI
"captioning": captioning_prompt_fn,
"bbox_pred": bbox_refer_prompt_fn,
Expand Down Expand Up @@ -300,7 +302,7 @@ def llava_contrast_caption_prompt_fn(caption: str) -> str:

return llava_contrast_caption_prompt_fn

def get_mc_prompt_fn(self, choices: List[str]) -> Callable[[str], str]:
def get_mc_prompt_fn(self) -> Callable[[str], str]:
"""Generates the full reference prompt for a multiple-choice question-answer task."""

# Conversation manager `self.conv` is not stateless! Need to reset on each construction!
Expand Down
14 changes: 8 additions & 6 deletions vlm_eval/models/prismatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def get_prompt_fn(self, dataset_family: str = "vqa-v2") -> Callable[[str], str]:
bbox_refer_prompt_fn = self.get_bbox_refer_chat_prompt_fn()
text_vqa_prompt_fn = self.get_vqa_chat_prompt_fn(uncertainty_aware=False)
captioning_prompt_fn = self.get_captioning_prompt_fn()
tally_qa_prompt_fn = self.get_mc_prompt_fn(choices=[str(i) for i in range(16)])
tally_qa_prompt_fn = self.get_mc_prompt_fn()
ai2d_prompt_fn = self.get_mc_prompt_fn()

return {
"vqa-v2": vqa_prompt_fn,
Expand All @@ -89,6 +90,7 @@ def get_prompt_fn(self, dataset_family: str = "vqa-v2") -> Callable[[str], str]:
"tally-qa": tally_qa_prompt_fn,
"refcoco": bbox_refer_prompt_fn,
"ocid-ref": bbox_refer_prompt_fn,
"ai2d": ai2d_prompt_fn,
# Generic for GUI
"captioning": captioning_prompt_fn,
"bbox_pred": bbox_refer_prompt_fn,
Expand Down Expand Up @@ -171,15 +173,15 @@ def contrast_caption_prompt_fn(caption: str) -> str:

return contrast_caption_prompt_fn

def get_mc_prompt_fn(self, choices: List[str]) -> Callable[[str], str]:
def get_mc_prompt_fn(self) -> Callable[[str], str]:
"""Generates the full reference prompt for a multiple choice question-answering task."""
prompt_builder_fn = self.model.get_prompt_builder

# Create Choice String
assert len(choices) <= 26, "Too many answer choices vs. possible letters in the alphabet!"
choice_str = "\n".join([f"{chr(ord('A') + idx)}. {choice}" for idx, choice in enumerate(choices)])
def mc_prompt_fn(question: str, choices: List[str]) -> str:
# Create Choice String
assert len(choices) <= 26, "Too many answer choices vs. possible letters in the alphabet!"
choice_str = "\n".join([f"{chr(ord('A') + idx)}. {choice}" for idx, choice in enumerate(choices)])

def mc_prompt_fn(question: str) -> str:
# Use Default Prompt (same as LLaVa-v1.5)
prompt_builder = prompt_builder_fn()
q_prompt = f"\n{question}\n{choice_str}"
Expand Down
2 changes: 2 additions & 0 deletions vlm_eval/tasks/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from vlm_eval.tasks.harnesses.vizwiz import VizWizIndexDataset, build_vizwiz_indices
from vlm_eval.tasks.harnesses.vqav2 import VQAv2IndexDataset, build_vqav2_indices
from vlm_eval.tasks.harnesses.vsr import VSRIndexDataset, build_vsr_indices
from vlm_eval.tasks.harnesses.ai2d import AI2DIndexDataset, build_ai2d_indices

# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)
Expand All @@ -37,6 +38,7 @@
"refcoco": {"build_indices": build_refcoco_indices, "get_index_datasets": RefCOCOIndexDataset},
"ocid-ref": {"build_indices": build_ocidref_indices, "get_index_datasets": OCIDRefIndexDataset},
"tally-qa": {"build_indices": build_tallyqa_indices, "get_index_datasets": TallyQAIndexDataset},
"ai2d": {"build_indices": build_ai2d_indices, "get_index_datasets": AI2DIndexDataset},

# fmt: on
}
Expand Down
3 changes: 3 additions & 0 deletions vlm_eval/tasks/harnesses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .vizwiz import VizWizScorer, VizWizTaskRunner
from .vqav2 import VQAv2Scorer, VQAv2TaskRunner
from .vsr import VSRScorer, VSRTaskRunner
from .ai2d import AI2DScorer, AI2DTaskRunner


# === Protocol Definitions ===
Expand All @@ -36,6 +37,7 @@ def score(self, model_id: str) -> Dict[str, float]:
"tally-qa": TallyQATaskRunner,
"refcoco": RefCOCOTaskRunner,
"ocid-ref": OCIDRefTaskRunner,
"ai2d": AI2DTaskRunner,
}

# === Score Function Dispatch by Dataset Family ===
Expand All @@ -49,6 +51,7 @@ def score(self, model_id: str) -> Dict[str, float]:
"tally-qa": TallyQAScorer,
"refcoco": RefCOCOScorer,
"ocid-ref": OCIDRefScorer,
"ai2d": AI2DScorer,
}


Expand Down
Loading

0 comments on commit 118921e

Please sign in to comment.