diff --git a/annolid/gui/app.py b/annolid/gui/app.py index 4adf342..b8400cf 100644 --- a/annolid/gui/app.py +++ b/annolid/gui/app.py @@ -54,6 +54,7 @@ import qimage2ndarray from annolid.gui.widgets.video_slider import VideoSlider, VideoSliderMark from annolid.gui.widgets.step_size_widget import StepSizeWidget +from annolid.gui.widgets.downsample_videos_dialog import VideoRescaleWidget from annolid.postprocessing.quality_control import pred_dict_to_labelme from annolid.annotation.timestamps import convert_frame_number_to_time from annolid.segmentation.SAM.edge_sam_bg import VideoProcessor @@ -348,6 +349,14 @@ def __init__(self, self.tr("Open Audio") ) + downsample_video = action( + self.tr("&Downsample Videos"), + self.downsample_videos, + None, + "Downsample Videos", + self.tr("Downsample Videos") + ) + step_size = QtWidgets.QWidgetAction(self) step_size.setIcon(QtGui.QIcon( str( @@ -500,6 +509,7 @@ def __init__(self, utils.addActions(self.menus.file, (tracks,)) utils.addActions(self.menus.file, (quality_control,)) utils.addActions(self.menus.file, (segment_cells,)) + utils.addActions(self.menus.file, (downsample_video,)) utils.addActions(self.menus.file, (advance_params,)) utils.addActions(self.menus.view, (glitter2,)) @@ -543,7 +553,8 @@ def __init__(self, self._selectAiModelComboBox.setCurrentIndex(model_index) self._selectAiModelComboBox.currentIndexChanged.connect( lambda: self.canvas.initializeAiModel( - name=self._selectAiModelComboBox.currentText() + name=self._selectAiModelComboBox.currentText(), + _custom_ai_models=self.custom_ai_model_names, ) if self.canvas.createMode in ["ai_polygon", "ai_mask"] else None @@ -553,8 +564,17 @@ def __init__(self, def _grounding_sam(self): self.toggleDrawMode(False, createMode="grounding_sam") - self.canvas.predictAiRectangle( - self.aiRectangle._aiRectanglePrompt.text()) + prompt_text = self.aiRectangle._aiRectanglePrompt.text().lower() + if len(prompt_text) < 1: + logger.info(f"Invalid text prompt {prompt_text}") + return + if prompt_text.startswith('flags:') and ',' in prompt_text: + flags = {k: False for k in prompt_text.replace( + 'flags:', '').split(',') or []} + self.loadFlags(flags) + else: + self.canvas.predictAiRectangle( + self.aiRectangle._aiRectanglePrompt.text()) def update_step_size(self, value): self.step_size = value @@ -563,6 +583,11 @@ def update_step_size(self, value): def flag_item_clicked(self, item): item_text = item.text() self.event_type = item_text + logger.info(f"Selected event {self.event_type}.") + + def downsample_videos(self): + video_downsample_widget = VideoRescaleWidget() + video_downsample_widget.exec_() def openAudio(self): if self.video_file: diff --git a/annolid/gui/widgets/canvas.py b/annolid/gui/widgets/canvas.py index b478a30..b053089 100644 --- a/annolid/gui/widgets/canvas.py +++ b/annolid/gui/widgets/canvas.py @@ -161,8 +161,9 @@ def createMode(self, value): self.sam_mask = MaskShape() self.current = None - def initializeAiModel(self, name): - if name not in [model.name for model in labelme.ai.MODELS]: + def initializeAiModel(self, name, _custom_ai_models=None): + if (name not in [model.name for model in labelme.ai.MODELS] + and name not in _custom_ai_models): logger.warning("Unsupported ai model: %s" % name) model = labelme.ai.MODELS[3] else: diff --git a/annolid/gui/widgets/downsample_videos_dialog.py b/annolid/gui/widgets/downsample_videos_dialog.py new file mode 100644 index 0000000..e71762f --- /dev/null +++ b/annolid/gui/widgets/downsample_videos_dialog.py @@ -0,0 +1,154 @@ +from annolid.utils.videos import (collect_video_metadata, + compress_and_rescale_video, + save_metadata_to_csv) +from qtpy.QtWidgets import (QApplication, QDialog, + QLabel, QPushButton, + QVBoxLayout, QFileDialog, + QCheckBox, QSlider, + QLineEdit, QMessageBox + ) +from qtpy.QtCore import Qt + + +class VideoRescaleWidget(QDialog): + """ + Widget for rescaling video files. + + Allows users to select input and output folders, specify a scale factor, + choose whether to rescale or downsample the video, and collect metadata. + """ + + def __init__(self): + super().__init__() + self.setWindowTitle('Video Rescaling') + + self.init_ui() + + def init_ui(self): + """ + Initialize the user interface. + """ + self.input_folder_label = QLabel('Input Folder:') + self.input_folder_button = QPushButton('Select Folder') + self.input_folder_button.clicked.connect(self.select_input_folder) + + self.output_folder_label = QLabel('Output Folder:') + self.output_folder_button = QPushButton('Select Folder') + self.output_folder_button.clicked.connect(self.select_output_folder) + + self.scale_factor_label = QLabel('Scale Factor:') + self.scale_factor_slider = QSlider(Qt.Horizontal) + self.scale_factor_slider.setMinimum(0) + self.scale_factor_slider.setMaximum(100) + self.scale_factor_slider.setValue(25) + self.scale_factor_slider.setTickInterval(25) + self.scale_factor_slider.setTickPosition(QSlider.TicksBelow) + self.scale_factor_slider.valueChanged.connect( + self.update_scale_factor_from_slider) + + self.scale_factor_text = QLineEdit() + self.scale_factor_text.setText("0.25") # Default scale factor + self.scale_factor_text.editingFinished.connect( + self.update_scale_factor_from_text) + + self.rescale_checkbox = QCheckBox('Rescale Video') + + self.collect_only_checkbox = QCheckBox('Collect Metadata Only') + + self.run_button = QPushButton('Run Rescaling') + self.run_button.clicked.connect(self.run_rescaling) + + layout = QVBoxLayout() + layout.addWidget(self.input_folder_label) + layout.addWidget(self.input_folder_button) + layout.addWidget(self.output_folder_label) + layout.addWidget(self.output_folder_button) + layout.addWidget(self.scale_factor_label) + layout.addWidget(self.scale_factor_slider) + layout.addWidget(self.scale_factor_text) + layout.addWidget(self.rescale_checkbox) + layout.addWidget(self.collect_only_checkbox) + layout.addWidget(self.run_button) + + self.setLayout(layout) + + def select_input_folder(self): + """ + Open a file dialog to select the input folder. + """ + folder = QFileDialog.getExistingDirectory(self, 'Select Input Folder') + if folder: + self.input_folder_label.setText(f'Input Folder: {folder}') + + def select_output_folder(self): + """ + Open a file dialog to select the output folder. + """ + folder = QFileDialog.getExistingDirectory(self, 'Select Output Folder') + if folder: + self.output_folder_label.setText(f'Output Folder: {folder}') + + def update_scale_factor_from_slider(self): + """ + Update the scale factor when the slider is dragged. + """ + scale_factor = self.scale_factor_slider.value() / 100 + self.scale_factor_text.setText(str(scale_factor)) + + def update_scale_factor_from_text(self): + """ + Update the slider when the text field is edited. + """ + scale_factor_text = self.scale_factor_text.text() + try: + scale_factor = float(scale_factor_text) + if 0.0 <= scale_factor <= 1.0: + self.scale_factor_slider.setValue(int(scale_factor * 100)) + else: + self.scale_factor_text.setText("Invalid Value") + except ValueError: + self.scale_factor_text.setText("Invalid Value") + + def run_rescaling(self): + """ + Run the rescaling process based on user inputs. + """ + # Disable the button during processing + self.run_button.setEnabled(False) + self.run_button.setText('Processing...') + + input_folder = self.input_folder_label.text().split(': ')[-1] + output_folder = self.output_folder_label.text().split(': ')[-1] + scale_factor = float(self.scale_factor_text.text()) + rescale = self.rescale_checkbox.isChecked() + collect_only = self.collect_only_checkbox.isChecked() + + if collect_only: + metadata = collect_video_metadata(input_folder) + output_csv, _ = QFileDialog.getSaveFileName( + self, 'Select Output CSV') + if output_csv: + save_metadata_to_csv(metadata, output_csv) + QMessageBox.information( + self, 'Done', 'Metadata collection is done.') + elif rescale: + compress_and_rescale_video( + input_folder, output_folder, scale_factor) + if output_folder: + metadata = collect_video_metadata(output_folder) + output_csv, _ = QFileDialog.getSaveFileName( + self, 'Select Output CSV') + if output_csv: + save_metadata_to_csv(metadata, output_csv) + QMessageBox.information(self, 'Done', 'Rescaling is done.') + + # Enable the button and change its text back to original + self.run_button.setEnabled(True) + self.run_button.setText('Run Rescaling') + + +if __name__ == '__main__': + app = QApplication([]) + widget = VideoRescaleWidget() + widget.exec_() + app.exec_() diff --git a/annolid/gui/widgets/video_slider.py b/annolid/gui/widgets/video_slider.py index 85ff733..cb20d18 100644 --- a/annolid/gui/widgets/video_slider.py +++ b/annolid/gui/widgets/video_slider.py @@ -238,8 +238,31 @@ def __init__( self.headerSeries = dict() self._draw_header() + # Adding QLineEdit for input value + self.input_value = QtWidgets.QLineEdit(str(self.value()), self) + self.input_value.setFixedWidth(60) + self.input_value.setAlignment(QtCore.Qt.AlignCenter) + self.input_value.editingFinished.connect(self.updateValueFromInput) + self.input_value.move(2, 2) + # Methods to match API for QSlider + def updateValueFromInput(self): + # Get the input text + input_text = self.input_value.text() + try: + # Try converting the input to float + input_val = float(input_text) + # Check if the input value is within the range + if self._val_min <= input_val <= self._val_max: + self.setValue(input_val) + else: + # Reset the input text if it's out of range + self.input_value.setText(str(self._val_main)) + except ValueError: + # Reset the input text if it's not a valid float + self.input_value.setText(str(self._val_main)) + def value(self) -> float: """Returns value of slider.""" return self._val_main @@ -250,6 +273,9 @@ def setValue(self, val: float) -> float: x = self._toPos(val) self.handle.setPos(x, 0) self.ensureVisible(x, 0, self._handle_width, 0, 3, 0) + if hasattr(self, 'input_value'): + # Update input text value with slider's current value + self.input_value.setText(str(self._val_main)) def setMinimum(self, min: float) -> float: """Sets minimum value for slider.""" @@ -1064,6 +1090,9 @@ def done(x, y): self.mouseMoved.emit(scenePos.x(), scenePos.y()) self.mousePressed.emit(scenePos.x(), scenePos.y()) + # Update input text value with slider's current value + self.input_value.setText(str(self._val_main)) + def mouseMoveEvent(self, event): """Override method to emit mouseMoved signal on drag.""" scenePos = self.mapToScene(event.pos()) diff --git a/annolid/postprocessing/glitter.py b/annolid/postprocessing/glitter.py index c4c9926..a3b7aee 100644 --- a/annolid/postprocessing/glitter.py +++ b/annolid/postprocessing/glitter.py @@ -119,7 +119,7 @@ def animal_in_zone(animal_mask, ): if animal_mask is not None and zone_mask is not None: overlap = mask_util.iou([animal_mask], [zone_mask], [ - False, False]).flatten()[0] + 0]).flatten()[0] return overlap > threshold else: return False @@ -148,7 +148,7 @@ def keypoint_in_body_mask( if keypoint_seg and body_seg: overlap = mask_util.iou([body_seg], [keypoint_seg], [ - False, False]).flatten()[0] + 0]).flatten()[0] return overlap > 0 else: return False @@ -172,7 +172,7 @@ def left_right_interact(fn, left_instance]['segmentation'].values[0] left_instance_seg = ast.literal_eval(left_instance_seg) left_interact = mask_util.iou([left_instance_seg], [subject_instance_seg], [ - False, False]).flatten()[0] + 0]).flatten()[0] except IndexError: left_interact = 0.0 try: @@ -180,7 +180,7 @@ def left_right_interact(fn, right_instance]['segmentation'].values[0] right_instance_seg = ast.literal_eval(right_instance_seg) right_interact = mask_util.iou([right_instance_seg], [subject_instance_seg], [ - False, False]).flatten()[0] + 0]).flatten()[0] except IndexError: right_interact = 0.0 diff --git a/annolid/postprocessing/tracking_results_analyzer.py b/annolid/postprocessing/tracking_results_analyzer.py new file mode 100644 index 0000000..ce8cebb --- /dev/null +++ b/annolid/postprocessing/tracking_results_analyzer.py @@ -0,0 +1,183 @@ +import pandas as pd +import json +import itertools +import matplotlib.pyplot as plt +from shapely.geometry import Point +from shapely.geometry.polygon import Polygon + + +class TrackingResultsAnalyzer: + """ + A class to analyze tracking results + and visualize time spent in zones for instances. + + Attributes: + video_name (str): The name of the video. + zone_file (str): The path to the JSON file containing zone information. + tracking_csv (str): The path to the tracking CSV file. + tracked_csv (str): The path to the tracked CSV file. + tracking_df (DataFrame): DataFrame containing tracking data. + tracked_df (DataFrame): DataFrame containing tracked data. + merged_df (DataFrame): DataFrame containing merged tracking and tracked data. + distances_df (DataFrame): DataFrame containing distances between instances. + zone_data (dict): Dictionary containing zone information loaded from the zone JSON file. + """ + + def __init__(self, video_name, zone_file): + """ + Initialize the TrackingResultsAnalyzer. + + Args: + video_name (str): The name of the video. + zone_file (str): The path to the JSON file containing zone information. + """ + self.video_name = video_name + self.tracking_csv = f"{video_name}_tracking.csv" + self.tracked_csv = f"{video_name}_tracked.csv" + self.zone_file = zone_file + + def read_csv_files(self): + """Read tracking and tracked CSV files into DataFrames.""" + self.tracking_df = pd.read_csv(self.tracking_csv) + self.tracked_df = pd.read_csv(self.tracked_csv) + + def merge_and_calculate_distance(self): + """Merge tracking and tracked dataframes based on + frame number and instance name, and calculate distances.""" + self.read_csv_files() + + # Merge DataFrames based on frame number and instance name + self.merged_df = pd.merge(self.tracking_df, self.tracked_df, + on=['frame_number', 'instance_name'], + suffixes=('_tracking', '_tracked')) + + # Calculate distance between different instances in the same frame + distances = [] + for frame_number, frame_group in self.merged_df.groupby('frame_number'): + instances_in_frame = frame_group['instance_name'].unique() + instance_combinations = itertools.combinations( + instances_in_frame, 2) + for instance_combination in instance_combinations: + instance1 = instance_combination[0] + instance2 = instance_combination[1] + instance1_data = frame_group[frame_group['instance_name'] == instance1] + instance2_data = frame_group[frame_group['instance_name'] == instance2] + for _, row1 in instance1_data.iterrows(): + for _, row2 in instance2_data.iterrows(): + distance = self.calculate_distance(row1['cx_tracking'], + row1['cy_tracking'], + row2['cx_tracked'], + row2['cy_tracked']) + distances.append({ + 'frame_number': frame_number, + 'instance_name_1': instance1, + 'instance_name_2': instance2, + 'distance': distance + }) + + self.distances_df = pd.DataFrame(distances) + + def calculate_distance(self, x1, y1, x2, y2): + """ + Calculate the Euclidean distance between two points. + + Args: + x1 (float): X-coordinate of the first point. + y1 (float): Y-coordinate of the first point. + x2 (float): X-coordinate of the second point. + y2 (float): Y-coordinate of the second point. + + Returns: + float: The Euclidean distance between the two points. + """ + return ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5 + + def load_zone_json(self): + """Load zone information from the JSON file.""" + with open(self.zone_file, 'r') as f: + self.zone_data = json.load(f) + + def determine_time_in_zone(self, instance_label): + """ + Determine the time spent by an instance in each zone. + + Args: + instance_label (str): The label of the instance. + + Returns: + dict: A dictionary containing the time spent by the instance in each zone. + """ + self.load_zone_json() + + # Filter merged DataFrame for given instance + instance_df = self.merged_df[self.merged_df['instance_name'] + == instance_label] + + zone_time_dict = {shape['label'] + : 0 for shape in self.zone_data['shapes']} + + for shape in self.zone_data['shapes']: + zone_label = shape['label'] + zone_time = 0 + # Check if instance points are within the zone + for _, row in instance_df.iterrows(): + if self.is_point_in_polygon([row['cx_tracked'], + row['cy_tracked']], shape['points']): + zone_time += 1 + + zone_time_dict[zone_label] = zone_time + + return zone_time_dict + + def is_point_in_polygon(self, point, polygon_points): + """ + Check if a point is inside a polygon. + + Args: + point (tuple): The coordinates of the point (x, y). + polygon_points (list): List of tuples representing the polygon vertices. + + Returns: + bool: True if the point is inside the polygon, False otherwise. + """ + # Create a Shapely Point object + point = Point(point[0], point[1]) + + # Create a Shapely Polygon object + polygon = Polygon(polygon_points) + + # Check if the point is within the polygon + return polygon.contains(point) + + def plot_time_in_zones(self, instance_label): + """ + Plot the time spent by an instance in each zone. + + Args: + instance_label (str): The label of the instance. + """ + zone_time_dict = self.determine_time_in_zone(instance_label) + + plt.bar(zone_time_dict.keys(), zone_time_dict.values()) + plt.xlabel('Zone') + plt.ylabel('Time (frames)') + plt.title(f'Time Spent in Each Zone for {instance_label}') + plt.show() + + +if __name__ == '__main__': + import argparse + + # Parse command-line arguments + parser = argparse.ArgumentParser(description='Track results analyzer') + parser.add_argument('video_name', type=str, help='Name of the video') + parser.add_argument('zone_file', type=str, + help='Path to the zone JSON file') + args = parser.parse_args() + + # Create and run the analyzer + analyzer = TrackingResultsAnalyzer(args.video_name, args.zone_file) + analyzer.merge_and_calculate_distance() + time_in_zone_mouse = analyzer.determine_time_in_zone("mouse_0") + print("Time in zone for mouse:", time_in_zone_mouse) + analyzer.plot_time_in_zones("mouse_0") diff --git a/annolid/utils/shapes.py b/annolid/utils/shapes.py index 8538c77..394b777 100644 --- a/annolid/utils/shapes.py +++ b/annolid/utils/shapes.py @@ -49,7 +49,7 @@ def masks_to_bboxes(masks): where = np.argwhere(mask) if where.size > 0: # Check if where array is not empty (y1, x1), (y2, x2) = where.min(0), where.max(0) + 1 - bboxes.append((y1, x1, y2, x2)) + bboxes.append((x1, y1, x2, y2)) bboxes = np.asarray(bboxes, dtype=np.float32) return bboxes diff --git a/annolid/utils/videos.py b/annolid/utils/videos.py index a7ae293..e79a4a0 100644 --- a/annolid/utils/videos.py +++ b/annolid/utils/videos.py @@ -115,7 +115,9 @@ def compress_and_rescale_video(input_folder, output_folder, scale_factor): for video_file in video_files: input_path = os.path.join(input_folder, video_file) output_path = os.path.join(output_folder, video_file) - + # Update extension to .mp4 + root, extension = os.path.splitext(output_path) + output_path = root + '.mp4' cmd = [ 'ffmpeg', '-i', input_path, '-vf', f'scale=iw*{scale_factor}:ih*{scale_factor}', @@ -176,7 +178,7 @@ def main(args): parser.add_argument('--output_csv', type=str, help='Output CSV file path for metadata.') parser.add_argument('--scale_factor', type=float, - default=1.0, help='Scale factor for resizing videos.') + default=0.5, help='Scale factor for resizing videos.') parser.add_argument('--collect_only', action='store_true', help='Collect metadata only, do not compress and rescale.')