Skip to content

Commit

Permalink
remove length normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
i-gao committed Aug 25, 2023
1 parent fef4df2 commit 1d84adf
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
30 changes: 13 additions & 17 deletions open_flamingo/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,18 +689,17 @@ def evaluate_captioning(
args.batch_size,
)

# subset of the training set to sample context images from
query_set = utils.get_query_set(train_dataset, args.query_set_size)
if args.rices:
rices_dataset = RICES(
train_dataset,
query_set,
eval_model.device,
args.batch_size,
cached_features=cached_features,
cached_features=cached_features[query_set.indices],
vision_encoder_path=args.rices_vision_encoder_path,
vision_encoder_pretrained=args.rices_vision_encoder_pretrained,
)
else:
# subset of the training set to sample context images from
query_set = utils.get_query_set(train_dataset, args.query_set_size)

utils.random_seed(seed, args.rank)
predictions = defaultdict()
Expand Down Expand Up @@ -882,17 +881,16 @@ def evaluate_vqa(
args.batch_size,
)

query_set = utils.get_query_set(train_dataset, args.query_set_size)
if args.rices:
rices_dataset = RICES(
train_dataset,
query_set,
eval_model.device,
args.batch_size,
cached_features=cached_features,
cached_features=cached_features[query_set.indices],
vision_encoder_path=args.rices_vision_encoder_path,
vision_encoder_pretrained=args.rices_vision_encoder_pretrained,
)
else:
query_set = utils.get_query_set(train_dataset, args.query_set_size)

utils.random_seed(seed, args.rank)
predictions = []
Expand Down Expand Up @@ -1119,19 +1117,17 @@ def evaluate_classification(
args.batch_size,
)

query_set = utils.get_query_set(train_dataset, args.query_set_size)
assert hasattr(query_set, 'class_id_array')
if args.rices:
rices_dataset = RICES(
train_dataset,
query_set,
eval_model.device,
args.batch_size,
cached_features=cached_features,
cached_features=cached_features[query_set.indices],
vision_encoder_path=args.rices_vision_encoder_path,
vision_encoder_pretrained=args.rices_vision_encoder_pretrained,
)
else:
# subset of the training set to sample context images from
query_set = utils.get_query_set(train_dataset, args.query_set_size)
assert hasattr(query_set, 'class_id_array')

utils.random_seed(seed, args.rank)
predictions = []
Expand Down Expand Up @@ -1195,7 +1191,7 @@ def evaluate_classification(
batch_images,
all_class_names,
use_cache=(not no_kv_caching),
normalize_length=True,
normalize_length=False,
)
)

Expand All @@ -1215,7 +1211,7 @@ def evaluate_classification(

# dev: print some results
if batch_idx == 0:
print(list(zip(batch_text, predicted_classnames[:1]))[:5])
print("Context:", batch_text[0], "\n", "Generated:", predicted_classnames[0][0], "\n", "True:", batch["class_name"][0])

# compute accuracy
for i, topk in enumerate(predicted_classnames):
Expand Down
6 changes: 5 additions & 1 deletion open_flamingo/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,18 @@ def sample_examples_from_class(dataset, y, num_samples, replace_if_insufficient=
def get_query_set(train_dataset, query_set_size):
"""
Get a subset of the training dataset to use as the query set. Returns a torch Subset.
Adds the "indices" attribute containing the indices of each example in the original set.
"""
if query_set_size == -1: return train_dataset
if query_set_size == -1:
train_dataset.indices = np.arange(len(train_dataset))
return train_dataset
query_set_indices = np.random.choice(len(train_dataset), query_set_size, replace=False)
query_set = Subset(train_dataset, query_set_indices)
if hasattr(train_dataset, "class_id_array"):
query_set.class_id_array = train_dataset.class_id_array[query_set_indices]
if len(np.unique(query_set.class_id_array)) != len(np.unique(train_dataset.class_id_array)):
print(f"Warning: query set does not contain examples from all classes; {len(np.unique(query_set.class_id_array))} remaining classes.")
query_set.indices = query_set_indices
return query_set


Expand Down

0 comments on commit 1d84adf

Please sign in to comment.