some more fixes
anas-awadalla committed Mar 17, 2024
1 parent 0b1c926 commit 79ad152
Showing 17 changed files with 159 additions and 153 deletions.
8 changes: 4 additions & 4 deletions open_flamingo/eval/eval_datasets.py
@@ -9,9 +9,9 @@

 SUPPORTED_TASKS = [
     "coco",
-    "flickr",
+    "flickr30",
     "vqav2",
-    "ok_vqa",
+    "okvqa",
     "vizwiz",
     "textvqa",
     "hateful_memes",
@@ -87,15 +87,15 @@ def __init__(
        self.image_dir_path = image_dir_path
        self.is_train = is_train
        self.dataset_name = dataset_name
-        if self.dataset_name in {"vqav2", "ok_vqa"}:
+        if self.dataset_name in {"vqav2", "okvqa"}:
            self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
            assert self.img_coco_split in {"train2014", "val2014", "test2015"}

    def __len__(self):
        return len(self.questions)

    def get_img_path(self, question):
-        if self.dataset_name in {"vqav2", "ok_vqa"}:
+        if self.dataset_name in {"vqav2", "okvqa"}:
            return os.path.join(
                self.image_dir_path,
                f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg"
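Note: the COCO-style path built in get_img_path zero-pads the numeric image id to 12 digits. A tiny self-contained illustration of the same formatting (the directory, split, and id below are made-up examples):

import os

def coco_img_path(image_dir_path: str, split: str, image_id: int) -> str:
    # image_id 123 with split "train2014" -> "COCO_train2014_000000000123.jpg"
    return os.path.join(image_dir_path, f"COCO_{split}_{image_id:012d}.jpg")

print(coco_img_path("/data/coco/train2014", "train2014", 123))
# /data/coco/train2014/COCO_train2014_000000000123.jpg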
21 changes: 10 additions & 11 deletions open_flamingo/eval/eval_models/blip.py
@@ -12,16 +12,15 @@
class EvalModel(BaseEvalModel):
    """BLIP-2 model evaluation."""

-    def __init__(self, model_args, init_on_device=False):
-        super().__init__(model_args, init_on_device)
-        with self.init_ctx:
-            self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
-            self.model = Blip2ForConditionalGeneration.from_pretrained(
-                model_args["lm_path"]
-            )
-            self.tokenizer = self.processor.tokenizer
+    def __init__(self, model_args):
+        super().__init__(model_args)
+        self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
+        self.model = Blip2ForConditionalGeneration.from_pretrained(
+            model_args["lm_path"]
+        )
+        self.tokenizer = self.processor.tokenizer

        self._check_init()

    @property
    def required_args(self):
        return ["processor_path", "lm_path"]
@@ -100,7 +99,7 @@ def get_outputs(
    def get_vqav2_prompt(self, question, answer=None) -> str:
        return f"Question:{question} Short answer:{answer if answer is not None else ''}"

-    def get_ok_vqa_prompt(self, question, answer=None) -> str:
+    def get_okvqa_prompt(self, question, answer=None) -> str:
        return f"Question:{question} Short answer:{answer if answer is not None else ''}"

    def get_vizwiz_prompt(self, question, answer=None) -> str:
@@ -112,5 +111,5 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
    def get_coco_prompt(self, caption=None) -> str:
        return f"A photo of {caption if caption is not None else ''}"

-    def get_flickr_prompt(self, caption=None) -> str:
+    def get_flickr30_prompt(self, caption=None) -> str:
        return f"A photo of {caption if caption is not None else ''}"
21 changes: 10 additions & 11 deletions open_flamingo/eval/eval_models/idefics.py
@@ -17,16 +17,15 @@
class EvalModel(BaseEvalModel):
    """IDEFICS model evaluation."""

-    def __init__(self, model_args, init_on_device=False):
-        super().__init__(model_args, init_on_device)
-        with self.init_ctx:
-            self.model = IdeficsForVisionText2Text.from_pretrained(
-                model_args["lm_path"]
-            )
-            self.processor = AutoProcessor.from_pretrained(model_args["processor_path"])
-            self.tokenizer = self.processor.tokenizer
+    def __init__(self, model_args):
+        super().__init__(model_args)
+        self.model = IdeficsForVisionText2Text.from_pretrained(
+            model_args["lm_path"]
+        )
+        self.processor = AutoProcessor.from_pretrained(model_args["processor_path"])
+        self.tokenizer = self.processor.tokenizer

        self._check_init()

    @property
    def required_args(self):
        return ["lm_path", "processor_path"]
@@ -171,7 +170,7 @@ def get_vqav2_prompt(self, question, answer=None) -> str:
        # TODO: handle prefix prompts
        return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

-    def get_ok_vqa_prompt(self, question, answer=None) -> str:
+    def get_okvqa_prompt(self, question, answer=None) -> str:
        # TODO: handle prefix prompts
        return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

@@ -187,6 +186,6 @@ def get_coco_prompt(self, caption=None) -> str:
        # TODO: handle prefix prompts
        return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

-    def get_flickr_prompt(self, caption=None) -> str:
+    def get_flickr30_prompt(self, caption=None) -> str:
        # TODO: handle prefix prompts
        return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
57 changes: 33 additions & 24 deletions open_flamingo/eval/eval_models/open_flamingo.py
@@ -13,31 +13,40 @@
class EvalModel(BaseEvalModel):
    """OpenFlamingo model evaluation."""

-    def __init__(self, model_args, init_on_device=False):
-        super().__init__(model_args, init_on_device)
+    def __init__(self, model_args):
+        super().__init__(model_args)

+        if model_args["model_family"] == "openflamingo":
+            assert "cross_attn_every_n_layers" in model_args, "cross_attn_every_n_layers is required for Flamingo models"
+        else:
+            assert "cross_attn_every_n_layers" not in model_args, "cross_attn_every_n_layers is only for Flamingo models"
+
        # initialize the model
-        with self.init_ctx:
-            (
-                self.model,
-                self.image_processor,
-                self.tokenizer,
-            ) = create_model_and_transforms(
-                clip_vision_encoder_path=model_args["vision_encoder_path"],
-                clip_vision_encoder_pretrained=model_args["vision_encoder_pretrained"],
-                lang_model_path=model_args["lm_path"],
-                tokenizer_path=model_args["tokenizer_path"],
-                model_family=model_args["model_family"],
-                cross_attn_every_n_layers=int(
-                    model_args.get("cross_attn_every_n_layers", 1)
-                ),
-            )
+        additional_kwargs = (
+            {"cross_attn_every_n_layers": model_args.get("cross_attn_every_n_layers", 1)}
+            if model_args["model_family"] == "flamingo"
+            else {}
+        )
+        (
+            self.model,
+            self.image_processor,
+            self.tokenizer,
+        ) = create_model_and_transforms(
+            clip_vision_encoder_path=model_args["vision_encoder_path"],
+            clip_vision_encoder_pretrained=model_args["vision_encoder_pretrained"],
+            lang_model_path=model_args["lm_path"],
+            tokenizer_path=model_args["tokenizer_path"],
+            model_family=model_args["model_family"],
+            **additional_kwargs,
+        )

        # load the checkpoint
        checkpoint = torch.load(model_args["checkpoint_path"], map_location="cpu")
        if "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
            checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
        self.model.load_state_dict(checkpoint, strict=False)
+        self.model_family = model_args["model_family"]

        self._check_init()

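Aside on the checkpoint-loading lines above: the "module." prefix is what torch.nn.parallel.DistributedDataParallel prepends to parameter names when a wrapped model's state dict is saved, so it is stripped before loading into the unwrapped model. A minimal sketch of the same pattern, assuming an arbitrary nn.Module and checkpoint path of your own:

import torch

def load_maybe_ddp_checkpoint(model: torch.nn.Module, path: str) -> None:
    """Load a checkpoint that may have been saved from a DDP-wrapped model."""
    checkpoint = torch.load(path, map_location="cpu")
    if "model_state_dict" in checkpoint:
        checkpoint = checkpoint["model_state_dict"]
    # DDP saves parameters as "module.<name>"; strip the prefix so the keys
    # line up with the unwrapped module.
    checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
    # strict=False tolerates missing or unexpected keys (e.g. frozen submodules).
    model.load_state_dict(checkpoint, strict=False)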
@@ -46,11 +55,10 @@ def required_args(self):
        """Return list of required arguments to initialize model."""
        return [
            "vision_encoder_path",
-            "model_familyl",
+            "model_family",
            "lm_path",
            "checkpoint_path",
            "tokenizer_path",
-            "cross_attn_every_n_layers",
            "vision_encoder_pretrained",
        ]

@@ -170,8 +178,9 @@ def get_outputs(
            **decode_kwargs,
        )

-        # Extract only the new generated tokens
-        outputs = outputs[:, len(input_ids[0]) :]
+        if self.model_family == "flamingo":
+            # Extract only the new generated tokens
+            outputs = outputs[:, len(input_ids[0]) :]
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

def get_rank_classifications(
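The slicing above reflects how generate() behaves for decoder-only language models in transformers: the returned ids include the prompt, so the prompt length is cut off before decoding (the new check gates this on the flamingo family, presumably because the other supported families already return only new tokens). A standalone illustration using gpt2, which is just a stand-in and not one of the models evaluated here:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Question: What color is the sky? Short answer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_new_tokens=5)

# generate() returns prompt + continuation for causal LMs, so slice off the prompt
new_tokens = outputs[:, input_ids.shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))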
@@ -270,8 +279,8 @@ def get_rank_classifications(
    def get_vqav2_prompt(self, question, answer=None) -> str:
        return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

-    def get_ok_vqa_prompt(self, question, answer=None) -> str:
-        return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
+    def get_okvqa_prompt(self, question, answer=None) -> str:
+        return f"<image>Instruct: {question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

def get_vizwiz_prompt(self, question, answer=None) -> str:
return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
@@ -282,7 +291,7 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
    def get_coco_prompt(self, caption=None) -> str:
        return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

-    def get_flickr_prompt(self, caption=None) -> str:
+    def get_flickr30_prompt(self, caption=None) -> str:
        return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

    def get_imagenet_prompt(self, label=None) -> str:
33 changes: 18 additions & 15 deletions open_flamingo/eval/evaluate.py
@@ -122,7 +122,7 @@
    help="Whether to evaluate on VQAV2.",
)
parser.add_argument(
-    "--eval_ok_vqa",
+    "--eval_okvqa",
    action="store_true",
    default=False,
    help="Whether to evaluate on OK-VQA.",
@@ -408,10 +408,8 @@ def main():
    model_args["device"] = device_id

    # initialize model
-    eval_model = get_eval_model(args.model, model_args, init_on_device=False)
-    eval_model.init_distributed(
-        local_rank=args.local_rank,
-    )
+    eval_model = get_eval_model(args.model, model_args)
+    eval_model.init_distributed()

    # Validate args
    if args.model in ZERO_SHOT_ONLY_MODELS and args.shots != [0]:
@@ -504,7 +502,7 @@ def main():
                }
            )

-    if args.eval_ok_vqa:
+    if args.eval_okvqa:
        print("Evaluating on OK-VQA...")

        # load cached demonstration features for RICES
@@ -523,7 +521,7 @@
                eval_model=eval_model,
                num_shots=shot,
                seed=seed,
-                dataset_name="ok_vqa",
+                dataset_name="okvqa",
                cached_features=cached_features,
            )
            if args.rank == 0:
@@ -919,7 +917,7 @@ def evaluate_vqa(
    seed: int = 42,
    min_new_tokens: int = 0,
    max_new_tokens: int = 5,
-    num_beams: int = 3,
+    num_beams: int = 5,
    length_penalty: float = 0.0,
    num_shots: int = 8,
    dataset_name: str = "vqav2",
@@ -936,13 +934,13 @@
        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
        num_shots (int, optional): number of shots to use. Defaults to 8.
-        dataset_name (string): type of vqa dataset: currently supports vqav2, ok_vqa. Defaults to vqav2.
+        dataset_name (string): type of vqa dataset: currently supports vqav2, okvqa. Defaults to vqav2.
        cached_features (tensor, optional): cached demonstration features for RICES. Defaults to None.
    Returns:
        float: accuracy score
    """

-    if dataset_name == "ok_vqa":
+    if dataset_name == "okvqa":
        train_image_dir_path = args.ok_vqa_train_image_dir_path
        train_questions_json_path = args.ok_vqa_train_questions_json_path
        train_annotations_json_path = args.ok_vqa_train_annotations_json_path
@@ -989,7 +987,7 @@ def evaluate_vqa(
        dataset_name=dataset_name,
    )

-    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
+    effective_num_shots = num_shots #utils.compute_effective_num_shots(num_shots, args.model)

    np.random.seed(seed)
    test_dataloader = utils.prepare_eval_samples(
@@ -1012,6 +1010,11 @@

    utils.random_seed(seed, args.rank)
    predictions = []
+
+    get_vqa_prompt = getattr(
+        eval_model, f"get_{dataset_name}_prompt"
+    )
+
    for batch in tqdm(
        test_dataloader,
        desc=f"Running inference {dataset_name}",
@@ -1034,7 +1037,7 @@

        context_text = "".join(
            [
-                eval_model.get_vqa_prompt(
+                get_vqa_prompt(
                    question=x["question"], answer=x["answers"][0]
                )
                + "\n"
@@ -1047,9 +1050,9 @@
            context_text = context_text.replace("<image>", "")

        batch_text.append(
-            context_text + eval_model.get_vqa_prompt(question=batch["question"][i])
+            context_text + get_vqa_prompt(question=batch["question"][i])
        )

    outputs = eval_model.get_outputs(
        batch_images=batch_images,
        batch_text=batch_text,
@@ -1186,7 +1189,7 @@ def evaluate_classification(

    class_id_to_name = dict(zip(range(len(all_class_names)), all_class_names))

-    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
+    effective_num_shots = num_shots #utils.compute_effective_num_shots(num_shots, args.model)

    np.random.seed(seed)
    test_dataloader = utils.prepare_eval_samples(
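Note on the prompt handling in evaluate_vqa: instead of always calling eval_model.get_vqa_prompt, the prompt builder is now resolved from the dataset name, which is also why the task id had to become "okvqa" so that f"get_{dataset_name}_prompt" matches the existing get_okvqa_prompt method (the "flickr30" rename presumably serves the same purpose on the captioning side). A stripped-down sketch of the getattr dispatch; the toy class below is illustrative, not one of the real eval models:

class ToyEvalModel:
    def get_vqav2_prompt(self, question, answer=None):
        return f"Question:{question} Short answer:{answer or ''}"

    def get_okvqa_prompt(self, question, answer=None):
        return f"Question:{question} Short answer:{answer or ''}"

def build_prompt(eval_model, dataset_name, question):
    # Resolve get_<dataset_name>_prompt at runtime; an unsupported dataset name
    # fails fast with an AttributeError instead of silently using the wrong prompt.
    get_vqa_prompt = getattr(eval_model, f"get_{dataset_name}_prompt")
    return get_vqa_prompt(question=question)

print(build_prompt(ToyEvalModel(), "okvqa", "What is this?"))
# Question:What is this? Short answer: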
1 change: 1 addition & 0 deletions open_flamingo/eval/ok_vqa_utils.py
@@ -210,6 +210,7 @@ def stem(self, input_string):

def postprocess_ok_vqa_generation(predictions) -> str:
    prediction = re.split("Question|Answer|Short", predictions, 1)[0]
+    prediction = prediction.split(". ", 1)[0]
    prediction = re.split(", ", prediction, 1)[0]
    prediction_stem = stemmer.stem(prediction)
    return prediction_stem
1 change: 1 addition & 0 deletions open_flamingo/eval/vqa_metric.py
@@ -556,5 +556,6 @@ def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_p

def postprocess_vqa_generation(predictions):
    answer = re.split("Question|Answer|Short", predictions, 1)[0]
+    answer = answer.split(". ", 1)[0]
    answer = re.split(", ", answer, 1)[0]
    return answer
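Both postprocessing helpers now also cut the generation at the first sentence boundary before the existing comma split, so a verbose continuation is reduced to a short VQA-style answer before scoring. A quick self-contained illustration with made-up predictions:

import re

def postprocess(prediction: str) -> str:
    # stop at any follow-up question/answer the model may have generated
    prediction = re.split("Question|Answer|Short", prediction, 1)[0]
    # keep only the first sentence
    prediction = prediction.split(". ", 1)[0]
    # keep only the first comma-separated item
    prediction = re.split(", ", prediction, 1)[0]
    return prediction

print(postprocess("a red apple. Question: what else is there"))  # -> a red apple
print(postprocess("dog, cat, and bird"))                         # -> dog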
8 changes: 0 additions & 8 deletions open_flamingo/src/blip.py
@@ -50,14 +50,6 @@ def set_trainable(self):
        """
        self.requires_grad_(False)
        self.vision_tokenizer.requires_grad_(True)
-        self.lang_model.get_output_embeddings().set_requires_grad(
-            require_regular_grad=False,
-            require_additional_grad=True,
-        )
-        self.lang_model.get_input_embeddings().set_requires_grad(
-            require_regular_grad=False,
-            require_additional_grad=True,
-        )

    def _should_apply_weight_decay(self, parameter_name):
        """BLIP applies 0.05 weight decay to everything"""
11 changes: 8 additions & 3 deletions open_flamingo/src/factory.py
@@ -58,11 +58,12 @@ def create_model_and_transforms(
        clip_vision_encoder_path,
        pretrained=clip_vision_encoder_pretrained,
        cache_dir=cache_dir,
+        force_image_size=490,
    )
    vision_encoder.visual.output_tokens = True
    vision_encoder = vision_encoder.visual
    vision_encoder_config = open_clip.get_model_config(clip_vision_encoder_path)
-    if "SigLIP" in clip_vision_encoder_path: # SigLIP models have a different config format
+    if "SigLIP" in clip_vision_encoder_path or "EVA" in clip_vision_encoder_path: # SigLIP models have a different config format
        vis_hidden_dim = vision_encoder_config["embed_dim"]
    else:
        vis_hidden_dim = vision_encoder_config["vision_cfg"]["width"]
@@ -74,8 +75,9 @@
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
-    if text_tokenizer.pad_token is None:
-        text_tokenizer.pad_token_id = text_tokenizer.eos_token_id
+    if text_tokenizer.pad_token is None or text_tokenizer.pad_token == text_tokenizer.eos_token:
+        # add a pad token if it doesn't exist
+        text_tokenizer.add_special_tokens({"pad_token": "<pad>"})

    # load langauge model
    lang_model = AutoModelForCausalLM.from_pretrained(
@@ -150,6 +152,9 @@ def _infer_decoder_layers_attr_name(model):
    "gemma": "model.layers",
    "phi": "model.layers",
    "minicpm": "model.layers",
+    "stablelm": "model.layers",
+    "qwen": "model.layers",
+    "mistral": "model.layers"
}


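One caveat on the tokenizer change above: add_special_tokens enlarges the vocabulary, so the language model's embedding matrices must be resized to cover the new <pad> id (factory.py presumably handles this alongside the other special tokens it adds; that part is outside this diff). A minimal self-contained sketch of the pattern, using gpt2 purely as an example checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# GPT-2 has no pad token; reusing eos as pad makes padding indistinguishable
# from a genuine end-of-sequence token, so add a dedicated <pad> instead.
if tokenizer.pad_token is None or tokenizer.pad_token == tokenizer.eos_token:
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
    # grow the embedding table so the new token id has a row
    model.resize_token_embeddings(len(tokenizer))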
(Diffs for the remaining changed files are not shown.)