Skip to content

Commit

Permalink
Refactor frame handling: Read frames directly from video file instead…
Browse files Browse the repository at this point in the history
… of extracting them to disk
  • Loading branch information
healthonrails committed Sep 18, 2024
1 parent 2f138fa commit 3695000
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions annolid/segmentation/SAM/sam_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
# Enable CPU fallback for unsupported MPS ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import glob
import torch
import numpy as np
import cv2
import numpy as np
import torch
import glob

from annolid.utils.files import download_file
from annolid.annotation.label_processor import LabelProcessor
from annolid.gui.shape import MaskShape
from annolid.annotation.keypoints import save_labels
from annolid.utils.devices import get_device
from annolid.utils.videos import extract_frames_with_opencv
from annolid.annotation.keypoints import save_labels
from annolid.gui.shape import MaskShape
from annolid.annotation.label_processor import LabelProcessor
from annolid.utils.files import download_file


class SAM2VideoProcessor:
Expand Down Expand Up @@ -48,7 +47,7 @@ def __init__(self, video_dir, id_to_labels,
self.epsilon_for_polygon = epsilon_for_polygon
self.frame_names = self._load_frame_names()
self.predictor = self._initialize_predictor()

self._handle_device_specific_settings()

def _initialize_predictor(self):
Expand Down Expand Up @@ -91,12 +90,34 @@ def _enable_cuda_optimizations(self):
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

def _normalize_video_path(self, path):
"""
Normalizes the video path to ensure it's a directory path.
If a file path is provided, it removes the file extension.
Args:
path (str): The input video path.
Returns:
str: The normalized video directory path.
"""
if os.path.isfile(path):
# If the path is a file, remove the extension to get the directory
return os.path.splitext(path)[0]
elif os.path.isdir(path):
# If the path is already a directory, return it as is
return path
else:
raise ValueError(f"Invalid path: {path}")


def _load_frame_names(self):
"""Loads and sorts JPEG frame names from the specified directory."""
try:
video_dir = self._normalize_video_path(self.video_dir)
frame_names = [
p for p in os.listdir(self.video_dir)
if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg"]
p for p in os.listdir(video_dir)
if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg", '.png']
]
frame_names.sort(key=lambda p: (os.path.splitext(p)[0]))
return frame_names
Expand All @@ -106,7 +127,8 @@ def _load_frame_names(self):

def get_frame_shape(self):
"""Returns the shape of the first frame in the video directory."""
first_frame_path = os.path.join(self.video_dir, self.frame_names[0])
video_dir = self._normalize_video_path(self.video_dir)
first_frame_path = os.path.join(video_dir, self.frame_names[0])
first_frame = cv2.imread(first_frame_path)
if first_frame is None:
raise ValueError(
Expand Down Expand Up @@ -193,7 +215,8 @@ def _propagate(self, inference_state):
"""Runs mask propagation and visualizes the results every few frames."""
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
mask_dict = {}
filename = os.path.join(self.video_dir, f'{out_frame_idx:09}.json')
video_dir = self._normalize_video_path(self.video_dir)
filename = os.path.join(video_dir, f'{out_frame_idx:09}.json')
for i, out_obj_id in enumerate(out_obj_ids):
_obj_mask = (out_mask_logits[i] > 0.0).cpu().numpy().squeeze()
mask_dict[str(out_obj_id)] = _obj_mask
Expand Down Expand Up @@ -239,8 +262,7 @@ def process_video(video_path,
model_config (str, optional): Path to the model configuration file.
epsilon_for_polygon (float, optional): Epsilon value for polygon approximation.
"""
# Extract frames from the video
video_dir = extract_frames_with_opencv(video_path)
video_dir = os.path.splitext(video_path)[0]

# Find all JSON annotation files in the directory
anno_jsons = glob.glob(os.path.join(video_dir, "*.json"))
Expand All @@ -264,15 +286,16 @@ def process_video(video_path,
label_processor = LabelProcessor(anno_json)

# Convert shapes to the custom annotations format
annotations = label_processor.convert_shapes_to_annotations(ann_frame_idx)
annotations = label_processor.convert_shapes_to_annotations(
ann_frame_idx)
all_annotations.extend(annotations)

# Update the mapping of object IDs to labels
id_to_labels.update(label_processor.get_id_to_labels())

# Initialize the analyzer
analyzer = SAM2VideoProcessor(
video_dir=video_dir,
video_dir=video_path,
id_to_labels=id_to_labels,
checkpoint_path=checkpoint_path,
model_config=model_config,
Expand Down

0 comments on commit 3695000

Please sign in to comment.