diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py index 402cefb7..4935d3e2 100644 --- a/open_flamingo/eval/evaluate.py +++ b/open_flamingo/eval/evaluate.py @@ -153,7 +153,7 @@ ) ## VQAV2, OK-VQA, VizWiz, TextVQA, GQA Datasets -for task in ['vqav2', 'ok_vqa', 'vizwiz', 'textvqa', 'gqa']: +for task in ['vqav2', 'okvqa', 'vizwiz', 'textvqa', 'gqa']: parser.add_argument( f"--{task}_image_dir_path" if task=='gqa' or task=='textvqa' else f"--{task}_train_image_dir_path", type=str, @@ -322,7 +322,7 @@ def main(): # load cached demonstration features for RICES if args.cached_demonstration_features is not None: cached_features = torch.load( - f"{args.cached_demonstration_features}/{'ok_vqa' if vqa_task=='okvqa' else vqa_task}.pkl", map_location="cpu" + f"{args.cached_demonstration_features}/{vqa_task}.pkl", map_location="cpu" ) else: cached_features = None @@ -603,7 +603,7 @@ def evaluate_vqa( var_args = vars(args) for task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]: if dataset_name == task: - task = task if task!="okvqa" else "ok_vqa" + task = task train_image_dir_path = var_args[f"{task}_train_image_dir_path" if task!="textvqa" and task!="gqa" else f"{task}_image_dir_path"] train_questions_json_path = var_args[f"{task}_train_questions_json_path"] train_annotations_json_path = var_args[f"{task}_train_annotations_json_path"] @@ -706,7 +706,7 @@ def evaluate_vqa( process_function = ( postprocess_ok_vqa_generation - if dataset_name == "ok_vqa" + if dataset_name == "okvqa" else postprocess_vqa_generation )