feat: ONNX Export #1

Open
wants to merge 2 commits into base: main
24 changes: 24 additions & 0 deletions README-ONNX.md
@@ -0,0 +1,24 @@
```bash
# 1. Clone the project
git clone https://github.com/nhanerc/GroundingDINO.git
cd GroundingDINO
git checkout main

# 2. Start the Docker container and install dependencies
docker run -it --rm --gpus all -v `pwd`:/workspace -w /workspace nvcr.io/nvidia/pytorch:24.07-py3
pip install -e .
pip uninstall opencv opencv-python opencv-python-headless -y
pip install opencv-python-headless==4.8.0.74 onnxruntime

# 3. Download the weights and run PyTorch inference; this generates `pred.jpg` in the `result` folder
mkdir -p weights && wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights
CUDA_VISIBLE_DEVICES=0 python demo/inference_on_a_image.py -c groundingdino/config/GroundingDINO_SwinT_OGC.py -p weights/groundingdino_swint_ogc.pth -i assets/demo1.jpg -o result -t "bear"

# 4. In another terminal (outside the docker container), check out the ONNX branch: git checkout onnx

# 5. Back inside the docker container, export to ONNX; this generates `dino.onnx` in the `weights` folder
python onnx/export.py -c groundingdino/config/GroundingDINO_SwinT_OGC.py -p weights/groundingdino_swint_ogc.pth -o weights

# 6. Run ONNX inference; this generates `output.jpg` in the `result` folder
python onnx/inference.py -p weights/dino.onnx -i assets/demo1.jpg -o result --box_threshold 0.35 -t "bear"
```
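A quick way to confirm the export worked is to open the model with `onnxruntime` and list its bindings (a minimal sketch; it assumes `weights/dino.onnx` was produced by step 5 above):

```python
import onnxruntime as ort

# Load the exported graph and print the inputs and outputs it expects.
sess = ort.InferenceSession("weights/dino.onnx")
for inp in sess.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print("output:", out.name, out.shape)
# Expected inputs: image, input_ids, attention_mask, position_ids,
# token_type_ids, text_token_mask. Expected outputs: proba, boxes.
```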
55 changes: 15 additions & 40 deletions groundingdino/models/GroundingDINO/groundingdino.py
@@ -15,7 +15,7 @@
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
from typing import List
from typing import List, Optional

import torch
import torch.nn.functional as F
@@ -224,7 +224,12 @@ def set_image_features(self, features , poss):
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)

def forward(self, samples: NestedTensor, targets: List = None, **kw):
def forward(self, samples: NestedTensor,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
text_self_attention_masks: Optional[torch.Tensor] = None):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
@@ -239,45 +244,17 @@ def forward(self, samples: NestedTensor, targets: List = None, **kw):
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
if targets is None:
captions = kw["captions"]
else:
captions = [t["caption"] for t in targets]

# encoder texts
tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
samples.device
)
(
text_self_attention_masks,
position_ids,
cate_to_token_mask_list,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, self.specical_tokens, self.tokenizer
)

if text_self_attention_masks.shape[1] > self.max_text_len:
text_self_attention_masks = text_self_attention_masks[
:, : self.max_text_len, : self.max_text_len
]
position_ids = position_ids[:, : self.max_text_len]
tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]

# extract text embeddings
if self.sub_sentence_present:
tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
tokenized_for_encoder["position_ids"] = position_ids
else:
# import ipdb; ipdb.set_trace()
tokenized_for_encoder = tokenized
tokenized_for_encoder = {}

tokenized_for_encoder["input_ids"] = input_ids
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
tokenized_for_encoder["position_ids"] = position_ids
tokenized_for_encoder["token_type_ids"] = token_type_ids

bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768

encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
text_token_mask = tokenized.attention_mask.bool() # bs, 195
text_token_mask = attention_mask.bool() # bs, 195
# text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask

@@ -359,9 +336,7 @@ def forward(self, samples: NestedTensor, targets: List = None, **kw):
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
unset_image_tensor = kw.get('unset_image_tensor', True)
if unset_image_tensor:
self.unset_image_tensor() ## If necessary

return out

@torch.jit.unused
105 changes: 105 additions & 0 deletions onnx/export.py
@@ -0,0 +1,105 @@
import os
import argparse

import torch

from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict


class Wrapper(torch.nn.Module):
def __init__(self, model: torch.nn.Module) -> None:
super().__init__()
self.model = model

def forward(self, *args, **kwargs):
outputs = self.model(*args, **kwargs)
proba = outputs["pred_logits"].sigmoid()
boxes = outputs["pred_boxes"]
return proba, boxes


def load_model(model_config_path: str, model_checkpoint_path: str, cpu_only: bool = False) -> torch.nn.Module:
args = SLConfig.fromfile(model_config_path)
args.device = "cuda" if not cpu_only else "cpu"

args.use_checkpoint = False
args.use_transformer_ckpt = False

model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.eval()
return Wrapper(model)


def export(model: torch.nn.Module, output_dir: str) -> None:
caption = "the running dog ."
input_ids = model.model.tokenizer([caption], return_tensors="pt")["input_ids"]
position_ids = torch.tensor([[0, 0, 1, 2, 3, 0]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0]])
attention_mask = torch.tensor([[True, True, True, True, True, True]])
text_token_mask = torch.tensor(
[
[
[True, False, False, False, False, False],
[False, True, True, True, True, False],
[False, True, True, True, True, False],
[False, True, True, True, True, False],
[False, True, True, True, True, False],
[False, False, False, False, False, True],
]
]
)

image = torch.randn(1, 3, 640, 800)
dynamic_axes = {
"input_ids": {1: "seq_len"},
"attention_mask": {1: "seq_len"},
"position_ids": {1: "seq_len"},
"token_type_ids": {1: "seq_len"},
"text_token_mask": {1: "seq_len", 2: "seq_len"},
}
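# Only the text inputs get a dynamic sequence length; the image input is traced at a
# fixed 1x3x640x800, so inference must resize images to 640x800 (as onnx/inference.py does).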

torch.onnx.export(
model,
f=f"{output_dir}/dino.onnx",
args=(
image,
input_ids,
attention_mask,
position_ids,
token_type_ids,
text_token_mask,
),
input_names=[
"image",
"input_ids",
"attention_mask",
"position_ids",
"token_type_ids",
"text_token_mask",
],
output_names=["proba", "boxes"],
dynamic_axes=dynamic_axes,
opset_version=16,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser("Export Grounding DINO Model to IR", add_help=True)
parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
parser.add_argument("--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file")
parser.add_argument(
"--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
)
args = parser.parse_args()
# cfg
config_file = args.config_file # change the path of the model config file
checkpoint_path = args.checkpoint_path # change the path of the model
output_dir = args.output_dir

os.makedirs(output_dir, exist_ok=True)
model = load_model(config_file, checkpoint_path, cpu_only=True)
export(model, output_dir)
170 changes: 170 additions & 0 deletions onnx/inference.py
@@ -0,0 +1,170 @@
import os
import argparse
import typing as T

import cv2
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer


def draw_boxes_to_image(
image: np.ndarray,
boxes: np.ndarray,
phrases: T.List[str],
confs: T.List[float],
) -> None:
colors = set()
for box, phrase, conf in zip(boxes, phrases, confs):
x1, y1, x2, y2 = box.astype(np.int32)

color = np.random.randint(0, 256, 3)
while tuple(color) in colors:
color = np.random.randint(0, 256, 3)
colors.add(tuple(color))

image = cv2.rectangle(image, (x1, y1), (x2, y2), color.tolist(), 2)
image = cv2.putText(
image, f"{phrase} ({conf:.2f})", (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 255), 1
)


def get_phrases_from_posmap(
posmap: np.ndarray,
tokenized: T.Dict,
tokenizer: AutoTokenizer,
left_idx: int = 0,
right_idx: int = 255,
) -> str:
assert isinstance(posmap, np.ndarray), "posmap must be numpy array"
assert posmap.ndim == 1, "posmap must be 1-dim"

posmap[0 : left_idx + 1] = False
posmap[right_idx:] = False
non_zero_idx = np.nonzero(posmap)[0]
token_ids = tokenized["input_ids"][0, non_zero_idx]
return tokenizer.decode(token_ids)


def generate_masks_with_special_tokens_and_transfer_map(
tokenized: T.Dict[str, np.ndarray], special_tokens_list: T.List[int]
):
"""Generate attention mask between each pair of special tokens"""
input_ids = tokenized["input_ids"]
bs, num_token = input_ids.shape
assert bs == 1, "Batch size must be 1"
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
special_tokens_mask = np.zeros((1, num_token), dtype=bool) # [bs, num_token]
for special_token in special_tokens_list:
special_tokens_mask = np.logical_or(special_tokens_mask, input_ids == special_token)

# get indexes of special tokens
rows, cols = np.nonzero(special_tokens_mask)

# generate attention mask and positional ids
attention_mask = np.expand_dims(np.eye(num_token, dtype=bool), axis=0) # [bs, num_token, num_token]
position_ids = np.zeros((1, num_token), dtype=np.int64) # [bs, num_token]
previous_col = 0
for row, col in zip(rows, cols):
if col == 0 or col == num_token - 1:
attention_mask[row, col, col] = True
position_ids[row, col] = 0
else:
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
position_ids[row, previous_col + 1 : col + 1] = np.arange(0, col - previous_col)
previous_col = col
return attention_mask, position_ids


def infer(args) -> None:
# Load model
sess = ort.InferenceSession(args.model_path)

# Load image
image = cv2.imread(args.image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
h, w = 640, 800
image = cv2.resize(image, (w, h))

# Preprocess image
x = image / 255.0
x = (x - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
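# ImageNet mean/std, matching the normalization GroundingDINO applies before the backbone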
x = np.transpose(x, (2, 0, 1))
x = np.expand_dims(x, axis=0).astype(np.float32)

# Prompt text
caption = args.text_prompt
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."

# Preprocess text
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=True)
tokenized = tokenizer(caption, padding="longest", return_tensors="np")
specical_tokens = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])

max_text_len = 256
if tokenized["input_ids"].shape[1] > max_text_len:
for k in tokenized:
tokenized[k] = tokenized[k][:, :max_text_len]
tokenized["attention_mask"] = tokenized["attention_mask"].astype(bool)

(text_token_mask, position_ids) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, specical_tokens
)

# Run model
output = sess.run(
None,
{
"image": x,
"input_ids": tokenized["input_ids"],
"attention_mask": tokenized["attention_mask"],
"position_ids": position_ids,
"token_type_ids": tokenized["token_type_ids"],
"text_token_mask": text_token_mask,
},
)

proba = output[0][0] # (nq, 256)
boxes = output[1][0] # (nq, 4)

# filter output
mask = proba.max(axis=1) > args.box_threshold
proba = proba[mask]
boxes = boxes[mask]

# get phrase
phrases, confs = [], []
for i, prob in enumerate(proba):
confs.append(prob.max())
phrase = get_phrases_from_posmap(prob > args.text_threshold, tokenized, tokenizer)
phrases.append(phrase)
# from 0..1 to 0..W, 0..H
boxes[i] = boxes[i] * [w, h, w, h]
# from cxcywh (box center, width, height) to xyxy corners
boxes[i][:2] -= boxes[i][2:] / 2
boxes[i][2:] += boxes[i][:2]

# Draw boxes
draw_boxes_to_image(image, boxes, phrases, confs)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
os.makedirs(args.output_dir, exist_ok=True)
cv2.imwrite(os.path.join(args.output_dir, "output.jpg"), image)


if __name__ == "__main__":
parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
parser.add_argument("--model_path", "-p", type=str, required=True, help="path to onnx file")
parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file")
parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt")
parser.add_argument(
"--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
)

parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
args = parser.parse_args()

infer(args)