Skip to content

Commit

Permalink
started adding YOLOX Model
Browse files Browse the repository at this point in the history
  • Loading branch information
sahilshah379 committed Oct 16, 2023
1 parent 74d7b38 commit daea443
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ns_vfs/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
config = omegaconf.load_config_from_yaml(config_path)

config.VERSION_AND_PATH.ROOT_PATH = os.path.join(ROOT_DIR)
config.VERSION_AND_PATH.ARTIFACTS_PATH = os.path.join(ROOT_DIR, "store/nsvs_artifact")
config.VERSION_AND_PATH.ARTIFACTS_PATH = os.path.join(ROOT_DIR, "artifacts")
3 changes: 3 additions & 0 deletions ns_vfs/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ GROUNDING_DINO:

YOLO:
YOLO_CHECKPOINT_PATH:

YOLOX:
YOLOX_CHECKPOINT_PATH:
5 changes: 5 additions & 0 deletions ns_vfs/config/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,9 @@ def load_config():
"weights",
"yolov8n.pt",
)
config.YOLOX.YOLOX_CHECKPOINT_PATH = os.path.join(
config.VERSION_AND_PATH.ARTIFACTS_PATH,
"weights",
"yolox_x.pth",
)
return config
181 changes: 181 additions & 0 deletions ns_vfs/model/vision/yolox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from __future__ import annotations

import warnings

import os
import cv2
import torch
import numpy as np
import supervision as sv
from loguru import logger
from omegaconf import DictConfig
from ultralytics import YOLO
from yolox.data.data_augment import ValTransform
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis


from ns_vfs.model.vision._base import ComputerVisionDetector

warnings.filterwarnings("ignore")


class YoloX(ComputerVisionDetector):
"""YoloX."""

def __init__(self, config: DictConfig, weight_path: str) -> None:
self.model = self.load_model(weight_path)
self._config = config
self._classes_reversed = {v: k for k, v in self.model.names.items()}

def load_model(self, weight_path) -> YOLO:
"""Load weight.
Args:
weight_path (str): Path to weight file.
Returns:
None
"""
exp = None
exp.test_conf = 0.25
exp.nmsthre = 0.45
exp.test_size = (640, 640)

model = exp.get_model()
model.eval()
ckpt = torch.load(weight_path, map_location="cpu")
model.load_state_dict(ckpt["model"])

return Predictor(model, exp)

def _parse_class_name(self, class_names: list[str]) -> list[str]:
"""Parse class name.
Args:
class_names (list[str]): List of class names.
Returns:
list[str]: List of class names.
"""
return [f"all {class_name}s" for class_name in class_names]

def detect(self, frame_img: np.ndarray, classes: list) -> any:
"""Detect object in frame.
Args:
frame_img (np.ndarray): Frame image.
classes (list[str]): List of class names.
Returns:
any: Detections.
"""
class_ids = [self._classes_reversed[c.replace("_", " ")] for c in classes]
detected_obj = self.model.predict(source=frame_img, classes=class_ids)

self._labels = []
for i in range(len(detected_obj[0].boxes)):
class_id = int(detected_obj[0].boxes.cls[i])
confidence = float(detected_obj[0].boxes.conf[i])
self._labels.append(
f"{detected_obj[0].names[class_id] if class_id is not None else None} {confidence:0.2f}"
)

self._detection = sv.Detections(xyxy=detected_obj[0].boxes.xyxy.cpu().detach().numpy())

self._confidence = detected_obj[0].boxes.conf.cpu().detach().numpy()

self._size = len(detected_obj[0].boxes)

return detected_obj

def get_confidence_score(self, frame_img: np.ndarray, true_label: str) -> any:
max_conf = 0
class_ids = [self._classes_reversed[c.replace("_", " ")] for c in [true_label]]
detected_obj = self.model.predict(source=frame_img, classes=class_ids)[0]
all_detected_object_list = detected_obj.boxes.cls
all_detected_object_confidence = detected_obj.boxes.conf

for i in range(len(all_detected_object_list)):
if all_detected_object_list[i] == class_ids[0]:
if all_detected_object_confidence[i] > max_conf:
max_conf = all_detected_object_confidence[i].cpu().item()
return max_conf


class Predictor(object):
def __init__(
self,
model,
exp,
cls_names=COCO_CLASSES,
decoder=None,
device="cpu",
fp16=False,
legacy=False,
):
self.model = model
self.cls_names = cls_names
self.decoder = decoder
self.num_classes = exp.num_classes
self.confthre = exp.test_conf
self.nmsthre = exp.nmsthre
self.test_size = exp.test_size
self.device = device
self.fp16 = fp16
self.preproc = ValTransform(legacy=legacy)

def inference(self, img):
img_info = {"id": 0}
if isinstance(img, str):
img_info["file_name"] = os.path.basename(img)
img = cv2.imread(img)
else:
img_info["file_name"] = None

height, width = img.shape[:2]
img_info["height"] = height
img_info["width"] = width
img_info["raw_img"] = img

ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
img_info["ratio"] = ratio

img, _ = self.preproc(img, None, self.test_size)
img = torch.from_numpy(img).unsqueeze(0)
img = img.float()
if self.device == "gpu":
img = img.cuda()
if self.fp16:
img = img.half() # to FP16

with torch.no_grad():
t0 = time.time()
outputs = self.model(img)
if self.decoder is not None:
outputs = self.decoder(outputs, dtype=outputs.type())
outputs = postprocess(
outputs, self.num_classes, self.confthre,
self.nmsthre, class_agnostic=True
)
logger.info("Infer time: {:.4f}s".format(time.time() - t0))
return outputs, img_info

def visual(self, output, img_info, cls_conf=0.35):
ratio = img_info["ratio"]
img = img_info["raw_img"]
if output is None:
return img
output = output.cpu()

bboxes = output[:, 0:4]

# preprocessing: resize
bboxes /= ratio

cls = output[:, 6]
scores = output[:, 4] * output[:, 5]

vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
return vis_res
12 changes: 9 additions & 3 deletions run_scripts/run_frame_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ns_vfs.frame_searcher import FrameSearcher
from ns_vfs.model.vision.grounding_dino import GroundingDino
from ns_vfs.model.vision.yolo import Yolo
from ns_vfs.model.vision.yolox import YoloX
from ns_vfs.processor.benchmark_video_processor import BenchmarkVideoFrameProcessor
from ns_vfs.processor.video_processor import (
VideoFrameProcessor,
Expand All @@ -14,14 +15,14 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--cv_model", type=str, default="grounding_dino")
parser.add_argument("--cv_model", type=str, default="yolo")
parser.add_argument(
"--video_processor", type=str, default="regular_video", choices=["regular_video", "benchmark_video"]
)
parser.add_argument(
"--video_path",
type=str,
default="/opt/Neuro-Symbolic-Video-Frame-Search/artifacts/data/nyc_street/nyc_stree_53sec.mp4",
default="../VIRAT_S_050201_05_000890_000944.mp4",
)
parser.add_argument(
"--proposition_set",
Expand Down Expand Up @@ -79,11 +80,16 @@
weight_path=config.GROUNDING_DINO.GROUNDING_DINO_CHECKPOINT_PATH,
config_path=config.GROUNDING_DINO.GROUNDING_DINO_CONFIG_PATH,
)
else:
elif args.cv_model == "yolo":
cv_model = Yolo(
config=config.YOLO,
weight_path=config.YOLO.YOLO_CHECKPOINT_PATH,
)
else:
cv_model = YoloX(
config=config.YOLOX,
weight_path=config.YOLOX.YOLOX_CHECKPOINT_PATH,
)

video_automata_builder = VideotoAutomaton(
detector=cv_model,
Expand Down

0 comments on commit daea443

Please sign in to comment.