Commit: Support MMVP

czczup committed Feb 4, 2024
1 parent 21baf4c · commit 8fff25b
Showing 5 changed files with 446 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -2,6 +2,7 @@

## News🚀🚀🚀

- `2024/02/04`: [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) achieves 44.67% on [MMVP](https://github.com/tsb0601/MMVP), higher than GPT-4V!
- `2024/01/27`: We release the 448-resolution model, achieving 76.6 on the MMBench dev set, see [here](https://github.com/OpenGVLab/InternVL/tree/main/internvl_chat#-evaluation-chinese-models).
- `2024/01/24`: InternVL-Chat-V1.1 is released; it supports Chinese and has stronger OCR capability, see [here](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) or try our [demo](https://internvl.opengvlab.com/).
- `2024/01/16`: We release our [customized mmcv/mmsegmentation/mmdetection code](https://github.com/OpenGVLab/InternVL-MMDetSeg), integrated with DeepSpeed, which can be used for training large-scale object detection and semantic segmentation models.
129 changes: 129 additions & 0 deletions clip_benchmark/evaluate_vlm_mmvp.py
@@ -0,0 +1,129 @@
import argparse
import csv
import os

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor


def benchmark_model(model_name, benchmark_dir, device='cuda'):
    """Evaluate an InternVL retrieval model ('InternVL-C' or 'InternVL-G')
    on the MMVP-VLM benchmark of paired images and statements."""
    # model_path = '/mnt/petrelfs/share_data/wangwenhai/llm/internvl_14b_224px'
    model_path = 'OpenGVLab/InternVL-14B-224px'
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).to(device).eval()
    preprocess = CLIPImageProcessor.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=False, add_eos_token=True)
    tokenizer.pad_token_id = 0  # set pad_token_id to 0
    image_dir = os.path.join(benchmark_dir, 'MLLM_VLM Images')
    csv_file = os.path.join(benchmark_dir, 'Questions.csv')

csv_outfile = open('output.csv', 'w', newline='')
csv_writer = csv.writer(csv_outfile)
csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score']) # header

categories = [
'Orientation and Direction', 'Presence of Specific Features',
'State and Condition', 'Quantity and Count',
'Positional and Relational Context', 'Color and Appearance',
'Structural Characteristics', 'Texts',
'Viewpoint and Perspective'
]

pair_accuracies = {category: 0 for category in categories}
num_pairs = 0

    with open(csv_file, 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        for row in tqdm(reader):
            qid1, qtype1, statement1 = row

            # Questions come in consecutive pairs; read the second row of the pair
            row = next(reader, None)
            if not row:
                break
            qid2, qtype2, statement2 = row

            qid1, qid2 = int(qid1), int(qid2)

            # Both images of a pair live under the first question's category folder
            img1 = Image.open(os.path.join(image_dir, qtype1, f'{qid1}.jpg'))
            img1 = img1.resize((224, 224))
            img2 = Image.open(os.path.join(image_dir, qtype1, f'{qid2}.jpg'))
            img2 = img2.resize((224, 224))

            prefix = 'summarize:'  # prompt prefix used for InternVL text retrieval
# text1 = prefix + 'a photo of ' + statement1
# text2 = prefix + 'a photo of ' + statement2
text1 = prefix + statement1
text2 = prefix + statement2

            text1 = tokenizer(text1, return_tensors='pt', max_length=80,
                              truncation=True, padding='max_length').input_ids.to(device)
            text2 = tokenizer(text2, return_tensors='pt', max_length=80,
                              truncation=True, padding='max_length').input_ids.to(device)

            img1 = preprocess(images=img1, return_tensors='pt').pixel_values.to(torch.float16).to(device)
            img2 = preprocess(images=img2, return_tensors='pt').pixel_values.to(torch.float16).to(device)
imgs = torch.cat((img1, img2), dim=0)

            with torch.no_grad():
                # Score each statement against both images; `mode` selects the
                # InternVL-C or InternVL-G matching head
                logits_per_image1, logits_per_text1 = model(image=imgs, text=text1, mode=model_name)
                logits_per_image2, logits_per_text2 = model(image=imgs, text=text2, mode=model_name)

            # Softmax over the two images turns each statement's logits into
            # match probabilities
            probs1 = logits_per_text1.float().softmax(dim=-1).cpu().numpy()
            probs2 = logits_per_text2.float().softmax(dim=-1).cpu().numpy()

            img1_score1 = probs1[0][0]  # P(statement1 matches img1)
            img1_score2 = probs2[0][0]  # P(statement2 matches img1)

            pred1 = 'img1' if img1_score1 > 0.5 else 'img2'
            pred2 = 'img1' if img1_score2 > 0.5 else 'img2'

            # Ground truth: odd question ids belong to img1, even ids to img2
            gt1 = 'img1' if qid1 % 2 == 1 else 'img2'
            gt2 = 'img1' if qid2 % 2 == 1 else 'img2'

csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2])

            # Each of the 9 categories spans 15 consecutive pairs; a pair counts
            # as correct only if both of its questions are predicted correctly
            current_category = categories[num_pairs // 15]
            if pred1 == gt1 and pred2 == gt2:
                pair_accuracies[current_category] += 1
num_pairs += 1

csv_outfile.close()

    # Convert per-category correct-pair counts into percentage accuracies
    for category in pair_accuracies:
        pair_accuracies[category] = (pair_accuracies[category] / (num_pairs // len(categories))) * 100

return pair_accuracies


parser = argparse.ArgumentParser(description='Evaluate InternVL on the MMVP-VLM benchmark.')

parser.add_argument('--directory', type=str, required=True,
                    help='Path to the MMVP_VLM benchmark directory')

args = parser.parse_args()

# InternVL models
models = ['InternVL-C', 'InternVL-G']

results = {model: benchmark_model(model, args.directory) for model in models}

print(results)

# Convert results to format suitable for star plot
categories = results[list(results.keys())[0]].keys()
print(f'categories: {categories}')
data = {'Categories': list(categories)}
print(f'data: {data}')
for model in list(results.keys()):
data[model] = [results[model][category] for category in categories]
print(f'data: {data}')
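
A usage sketch for the new script, assuming the MMVP_VLM data was cloned into `./data` as in the preparation step further below (the exact directory is an assumption, not fixed by this commit):

```bash
# Point the script at a local clone of the MMVP_VLM benchmark
python clip_benchmark/evaluate_vlm_mmvp.py --directory ./data/MMVP_VLM
```

The script writes per-question predictions to `output.csv` and prints per-category pair accuracies for both InternVL-C and InternVL-G.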
37 changes: 34 additions & 3 deletions internvl_chat/README.md
@@ -110,9 +110,9 @@ Coming Soon

**MultiModal Benchmark**

| model | MME | MMB<sub>dev/test</sub> | MMB-CN<sub>dev/test</sub> | POPE | MMVP |
| --------------------------------------------------------------------------------- | -------------- | ---------------------- | ------------------------- | ---- | ---- |
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 1672.3 / 341.1 | 76.6 / 75.4 | 71.5 / 70.1 | 87.2 | 44.7 |

| model | MMMU<sub>val/test</sub> | CMMMU<sub>val/test</sub> | Tiny<sub>LVLM</sub> | LLaVA<sub>bench</sub> | MM-Vet |
| --------------------------------------------------------------------------------- | ----------------------- | ------------------------ | ------------------- | --------------------- | ------ |
@@ -284,6 +284,13 @@ data
│ └── Sociology
├── mm-vet
│ └── images/
├── MMVP
│ ├── MMVP Images/
│ ├── Questions.csv
│ └── Questions.xlsx
├── MMVP_VLM
│ ├── MLLM_VLM Images/
│ └── Questions.csv
```
</details>
@@ -974,3 +981,27 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh evaluate.sh <checkpoint> mmvet
```

</details>

#### [MMVP](https://github.com/tsb0601/MMVP)

<details>
<summary>Data Preparation</summary>

```bash
cd data
git lfs install
git clone https://huggingface.co/datasets/MMVP/MMVP
git clone https://huggingface.co/datasets/MMVP/MMVP_VLM
cd ..
```

</details>

<details>
<summary>Evaluation</summary>

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh evaluate.sh <checkpoint> mmvp
```

</details>