Commit 2092905: Finalize AI2D

siddk committed Mar 27, 2024
1 parent 118921e commit 2092905

Showing 5 changed files with 24 additions and 23 deletions.
7 changes: 4 additions & 3 deletions scripts/datasets/prepare.py
@@ -21,7 +21,7 @@
 @dataclass
 class DatasetPreparationConfig:
     # fmt: off
-    dataset_family: str = "tally-qa"                 # Dataset family to prepare
+    dataset_family: str = "ai2d"                     # Dataset family to prepare

     # Processing Parameters
     create_slim_dataset: bool = True                 # Whether to create "slim" (minified) dataset(s)
@@ -40,11 +40,12 @@ class DatasetPreparationConfig:

     # Path Parameters
     root_dir: Path = Path(                           # Path to root directory for storing datasets
-        "/home/ubuntu/datasets/vlm-evaluation"
+        # "datasets/vlm-evaluation"
+        "/mnt/fsx/skaramcheti/datasets/vlm-evaluation"
     )

     # HF Hub Credentials (for LLaMa-2)
-    hf_token: Union[str, Path] = Path(".hf_token")   # Env Variable or Path to HF Token (for Winoground)
+    hf_token: Union[str, Path] = Path(".hf_token")   # Env Variable or Path to HF Token

     # Randomness
     seed: int = 21                                   # Random Seed (for slim datasets, augmentations)
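
For reference, the new defaults in miniature: a self-contained sketch that re-declares a trimmed copy of the dataclass (not the repo's actual class) with the updated "ai2d" default; the local root_dir override at the bottom is hypothetical.

    # Trimmed, self-contained copy of the config pattern above -- illustrative only.
    from dataclasses import dataclass
    from pathlib import Path

    @dataclass
    class DatasetPreparationConfig:
        dataset_family: str = "ai2d"          # Dataset family to prepare
        create_slim_dataset: bool = True      # Whether to create "slim" (minified) dataset(s)
        root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/vlm-evaluation")

    cfg = DatasetPreparationConfig(root_dir=Path("datasets/vlm-evaluation"))  # hypothetical local override
    print(cfg.dataset_family, cfg.root_dir)   # -> ai2d datasets/vlm-evaluation
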
16 changes: 9 additions & 7 deletions scripts/evaluate.py
@@ -34,18 +34,20 @@ class EvaluationConfig:

     # DatasetConfig from `vlm_eval/conf/datasets.py`; override with --dataset.type `DatasetRegistry.<DATASET>.dataset_id`
     dataset: DatasetConfig = field(
-        default_factory=DatasetConfig.get_choice_class(DatasetRegistry.TALLYQA_SUBSAMPLED.dataset_id)
+        default_factory=DatasetConfig.get_choice_class(DatasetRegistry.AI2D_FULL.dataset_id)
     )

     # === Model Parameters =>> Prismatic ===
-    model_family: str = "prismatic"                  # Model family to load from in < `prismatic` | `llava-v15` | ... >
-    model_id: Optional[str] = None                   # Model ID to load and run (instance of `model_family`)
-    model_dir: Optional[Path] = None                 # Path to model checkpoint to load --> should be self-contained
+    model_family: str = "prismatic"                  # Model family to load from in < `prismatic` | `llava-v15` | ... >
+    model_id: Optional[str] = (                      # Model ID to load and run (instance of `model_family`)
+        "prism-clip+7b"
+    )
+    model_dir: Optional[Path] = None                 # Path to model checkpoint to load --> should be self-contained

     # === Model Parameters =>> Official LLaVa ===
     # model_family: str = "llava-v15"
-    # model_id: str = "llava-v1.5-13b"
-    # model_dir: Path = "liuhaotian/llava-v1.5-13b"
+    # model_id: str = "llava-v1.5-7b"
+    # model_dir: Path = "liuhaotian/llava-v1.5-7b"

     # === Model Parameters =>> Official InstructBLIP ===
     # model_family: str = "instruct-blip"
@@ -58,7 +60,7 @@ class EvaluationConfig:

     # Artifact Parameters
     results_dir: Path = Path(                        # Path to results directory (writing predicted output, metrics)
-        "/home/ubuntu/prismatic-vlms/results"
+        "results"
     )

     # HF Hub Credentials (for LLaMa-2)
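
The dataset default is now resolved through the registry rather than hard-coded per benchmark, so switching datasets is a one-line change -- the same effect as the CLI override --dataset.type `DatasetRegistry.<DATASET>.dataset_id` noted in the comment above. A minimal sketch of the get_choice_class pattern, assuming the vlm-evaluation repo is importable; the MyEvalConfig name is made up for illustration.

    # Sketch: pull the registered choice class for AI2D (full split) as the default.
    from dataclasses import dataclass, field

    from vlm_eval.conf.datasets import DatasetConfig, DatasetRegistry

    @dataclass
    class MyEvalConfig:
        dataset: DatasetConfig = field(
            default_factory=DatasetConfig.get_choice_class(DatasetRegistry.AI2D_FULL.dataset_id)
        )

    cfg = MyEvalConfig()
    print(type(cfg.dataset).__name__)  # presumably AI2DFullDatasetConfig
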
8 changes: 4 additions & 4 deletions scripts/score.py
@@ -34,14 +34,14 @@ class ScoreConfig:

     # DatasetConfig from `vlm_eval/conf/datasets.py`; override with --dataset.type `DatasetRegistry.<DATASET>.dataset_id`
     dataset: DatasetConfig = field(
-        default_factory=DatasetConfig.get_choice_class(DatasetRegistry.TEXTVQA_SLIM.dataset_id)
+        default_factory=DatasetConfig.get_choice_class(DatasetRegistry.AI2D_FULL.dataset_id)
     )

     # === Model Parameters =>> Prismatic ===
-    model_id: str = "resize-naive-clip-vit-l-14-336px-no-align-llama2pure+7b+stage-finetune+x7"  # Model ID to load and run (instance of `model_family`)
+    model_id: str = "prism-clip+7b"                  # Model ID to load and run (instance of `model_family`)

     # === Model Parameters =>> Official LLaVa ===
-    # model_id: str = "llava-v1.5-13b"
+    # model_id: str = "llava-v1.5-7b"

     # === Model Parameters =>> Official InstructBLIP ===
     # model_id: str = "instructblip-vicuna-7b"
@@ -50,7 +50,7 @@ class ScoreConfig:

     # Artifact Parameters
     results_dir: Path = Path(                        # Path to results directory (writing predicted output, metrics)
-        "/home/ubuntu/prismatic-vlms/results"
+        "results"
     )

     # fmt: on
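
One side effect worth noting: results_dir in both evaluate.py and score.py is now a relative path, so outputs resolve against the current working directory rather than a machine-specific absolute path. A one-liner illustrating the resolution:

    from pathlib import Path

    print(Path("results").resolve())  # e.g. <current working directory>/results
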
3 changes: 1 addition & 2 deletions vlm_eval/conf/datasets.py
@@ -348,7 +348,7 @@ class AI2DFullDatasetConfig(DatasetConfig):
     expected_examples: int = 15501

     root_dir: Path = Path("../../datasets/vlm-evaluation")
-    index_file: Path = Path("datasets/ai2d/metadata.json")
+    index_file: Path = Path("datasets/ai2d/metadata-full.json")
     annotations_file: Path = Path("datasets/ai2d/metadata-full.json")


@@ -365,7 +365,6 @@ class AI2DSlimDatasetConfig(DatasetConfig):
     annotations_file: Path = Path("datasets/ai2d/metadata-slim-1024.json")


-
 # === Define a Dataset Registry Enum for Reference / Validation =>> all *new* datasets must be added here! ===
 @unique
 class DatasetRegistry(Enum):
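
A hypothetical illustration of how these fields likely compose: the root_dir / index_file join is an assumption about how the evaluation harness resolves paths; only the path values themselves come from AI2DFullDatasetConfig above.

    from pathlib import Path

    root_dir = Path("../../datasets/vlm-evaluation")
    index_file = Path("datasets/ai2d/metadata-full.json")

    print(root_dir / index_file)  # ../../datasets/vlm-evaluation/datasets/ai2d/metadata-full.json
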
13 changes: 6 additions & 7 deletions vlm_eval/models/llava.py
@@ -308,14 +308,10 @@ def get_mc_prompt_fn(self) -> Callable[[str], str]:
         # Conversation manager `self.conv` is not stateless! Need to reset on each construction!
         self.conv = conv_templates[self.conv_mode].copy()

-        # Create Choice String
-        assert len(choices) <= 26, "Too many answer choices vs. possible letters in the alphabet!"
-        choice_str = "\n".join([f"{chr(ord('A') + idx)}. {choice}" for idx, choice in enumerate(choices)])
-
         # Different LLaVa Models handle <IMAGE> token insertion differently; we support both LLaVa v1 and v1.5!
         #   => Ref (v1): https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa_science.py#L53
         #   => Ref (v1.5): https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md#evaluate-on-custom-datasets
-        q_prompt = DEFAULT_IMAGE_TOKEN + "\n" + "{question}\n" + choice_str
+        q_prompt = DEFAULT_IMAGE_TOKEN + "\n" + "{question}\n{choice_str}"
         if self.model_id.startswith("llava-v1.5"):
             q_prompt += "\nAnswer with the option's letter from the given choices directly."

@@ -326,8 +322,11 @@ def get_mc_prompt_fn(self) -> Callable[[str], str]:
         # Get full chat prompt template function --> insert question with `template.format(question=<QUESTION>)`
         prompt_template = self.conv.get_prompt()

-        def llava_mc_prompt_fn(question: str) -> str:
-            return prompt_template.format(question=question)
+        def llava_mc_prompt_fn(question: str, choices: List[str]) -> str:
+            assert len(choices) <= 26, "Too many answer choices vs. possible letters in the alphabet!"
+            choice_str = "\n".join([f"{chr(ord('A') + idx)}. {choice}" for idx, choice in enumerate(choices)])
+
+            return prompt_template.format(question=question, choice_str=choice_str)

         return llava_mc_prompt_fn
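
Since choice formatting now happens per call, the prompt function's contract changes from (question) -> str to (question, choices) -> str. A standalone sketch of the new behavior; the prompt_template string here is a stand-in, since the real one comes from self.conv.get_prompt().

    from typing import List

    prompt_template = "USER: {question}\n{choice_str} ASSISTANT:"  # stand-in template

    def llava_mc_prompt_fn(question: str, choices: List[str]) -> str:
        assert len(choices) <= 26, "Too many answer choices vs. possible letters in the alphabet!"
        choice_str = "\n".join(f"{chr(ord('A') + idx)}. {choice}" for idx, choice in enumerate(choices))
        return prompt_template.format(question=question, choice_str=choice_str)

    print(llava_mc_prompt_fn("Which organ pumps blood?", ["Lung", "Heart", "Liver"]))
    # USER: Which organ pumps blood?
    # A. Lung
    # B. Heart
    # C. Liver ASSISTANT:
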
