Skip to content

Commit

Permalink
[stats_manager] Simplify metric key registration
Browse files Browse the repository at this point in the history
Fixes #334
  • Loading branch information
Breakthrough committed Jan 27, 2024
1 parent 14daf89 commit e99a1bc
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 96 deletions.
9 changes: 1 addition & 8 deletions scenedetect/scene_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,14 +647,7 @@ def add_detector(self, detector: SceneDetector) -> None:

detector.stats_manager = self._stats_manager
if self._stats_manager is not None:
try:
self._stats_manager.register_metrics(detector.get_metrics())
except FrameMetricRegistered:
# Allow multiple detection algorithms of the same type to be added
# by suppressing any FrameMetricRegistered exceptions due to attempts
# to re-register the same frame metric keys.
# TODO(#334): Fix this, this should not be part of regular control flow.
pass
self._stats_manager.register_metrics(detector.get_metrics())

if not issubclass(type(detector), SparseSceneDetector):
self._detector_list.append(detector)
Expand Down
83 changes: 32 additions & 51 deletions scenedetect/stats_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

import csv
from logging import getLogger
import typing as ty
# TODO: Replace below imports with `ty.` prefix.
from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Union
import os.path

Expand All @@ -47,25 +49,13 @@


class FrameMetricRegistered(Exception):
""" Raised when attempting to register a frame metric key which has
already been registered. """

def __init__(self,
metric_key: str,
message: str = "Attempted to re-register frame metric key."):
super().__init__(message)
self.metric_key = metric_key
"""[DEPRECATED - DO NOT USE] No longer used."""
pass


class FrameMetricNotRegistered(Exception):
""" Raised when attempting to call get_metrics(...)/set_metrics(...) with a
frame metric that does not exist, or has not been registered. """

def __init__(self,
metric_key: str,
message: str = "Attempted to get/set frame metrics for unregistered metric key."):
super().__init__(message)
self.metric_key = metric_key
"""[DEPRECATED - DO NOT USE] No longer used."""
pass


class StatsFileCorrupt(Exception):
Expand Down Expand Up @@ -107,30 +97,21 @@ def __init__(self, base_timecode: FrameTimecode = None):
# Frame metrics is a dict of frame (int): metric_dict (Dict[str, float])
# of each frame metric key and the value it represents (usually float).
self._frame_metrics: Dict[FrameTimecode, Dict[str, float]] = dict()
self._registered_metrics: Set[str] = set() # Set of frame metric keys.
self._loaded_metrics: Set[str] = set() # Metric keys loaded from stats file.
self._metric_keys: Set[str] = set()
self._metrics_updated: bool = False # Flag indicating if metrics require saving.
self._base_timecode: Optional[FrameTimecode] = base_timecode # Used for timing calculations.

def register_metrics(self, metric_keys: Iterable[str]) -> None:
"""Register a list of metric keys that will be used by the detector.
Used to ensure that multiple detector keys don't overlap.
@property
def metric_keys(self) -> ty.Iterable[str]:
return self._metric_keys

Raises:
FrameMetricRegistered: A particular metric_key has already been registered/added
to the StatsManager. Only if the StatsManager is being used for read-only
access (i.e. all frames in the video have already been processed for the given
metric_key in the exception) is this behavior desirable.
"""
for metric_key in metric_keys:
if metric_key not in self._registered_metrics:
self._registered_metrics.add(metric_key)
else:
raise FrameMetricRegistered(metric_key)
def register_metrics(self, metric_keys: Iterable[str]) -> None:
"""Register a list of metric keys that will be used by the detector."""
self._metric_keys = self._metric_keys.union(set(metric_keys))

# TODO(v1.0): Change frame_number to a FrameTimecode now that it is just a hash and will
# be required for VFR support.
# be required for VFR support. This API is also really difficult to use; this type should just
# function like a dictionary.
def get_metrics(self, frame_number: int, metric_keys: Iterable[str]) -> List[Any]:
"""Return the requested statistics/metrics for a given frame.
Expand Down Expand Up @@ -189,16 +170,12 @@ def save_to_csv(self,
"""
# TODO(v0.7): Replace with DeprecationWarning that `base_timecode` will be removed in v0.8.
if base_timecode is not None:
logger.error('base_timecode is deprecated.')
logger.error('base_timecode is deprecated and has no effect.')

# Ensure we need to write to the file, and that we have data to do so with.
if not ((self.is_save_required() or force_save) and self._registered_metrics
and self._frame_metrics):
logger.info("No metrics to save.")
if not (force_save or self.is_save_required()):
logger.info("No metrics to write.")
return

assert self._base_timecode is not None

# If we get a path instead of an open file handle, recursively call ourselves
# again but with file handle instead of path.
if isinstance(csv_file, (str, bytes)):
Expand All @@ -207,7 +184,7 @@ def save_to_csv(self,
return

csv_writer = csv.writer(csv_file, lineterminator='\n')
metric_keys = sorted(list(self._registered_metrics.union(self._loaded_metrics)))
metric_keys = sorted(list(self._metric_keys))
csv_writer.writerow([COLUMN_NAME_FRAME_NUMBER, COLUMN_NAME_TIMECODE] + metric_keys)
frame_keys = sorted(self._frame_metrics.keys())
logger.info("Writing %d frames to CSV...", len(frame_keys))
Expand All @@ -234,7 +211,8 @@ def valid_header(row: List[str]) -> bool:
return False
return True

# TODO(v1.0): Remove.
# TODO(v1.0): Create a replacement for a calculation cache that functions like load_from_csv
# did, but is better integrated with detectors for cached calculations instead of statistics.
def load_from_csv(self, csv_file: Union[str, bytes, TextIO]) -> Optional[int]:
"""[DEPRECATED] DO NOT USE
Expand Down Expand Up @@ -285,29 +263,32 @@ def load_from_csv(self, csv_file: Union[str, bytes, TextIO]) -> Optional[int]:
num_metrics = num_cols - 2
if not num_metrics > 0:
raise StatsFileCorrupt('No metrics defined in CSV file.')
self._loaded_metrics = row[2:]
loaded_metrics = list(row[2:])
num_frames = 0
for row in csv_reader:
metric_dict = {}
if not len(row) == num_cols:
raise StatsFileCorrupt('Wrong number of columns detected in stats file row.')
for i, metric_str in enumerate(row[2:]):
if metric_str and metric_str != 'None':
try:
metric_dict[self._loaded_metrics[i]] = float(metric_str)
except ValueError:
raise StatsFileCorrupt('Corrupted value in stats file: %s' %
metric_str) from ValueError
frame_number = int(row[0])
# Switch from 1-based to 0-based frame numbers.
if frame_number > 0:
frame_number -= 1
self.set_metrics(frame_number, metric_dict)
for i, metric in enumerate(row[2:]):
if metric and metric != 'None':
try:
self._set_metric(frame_number, loaded_metrics[i], float(metric))
except ValueError:
raise StatsFileCorrupt('Corrupted value in stats file: %s' %
metric) from ValueError
num_frames += 1
self._metric_keys = self._metric_keys.union(set(loaded_metrics))
logger.info('Loaded %d metrics for %d frames.', num_metrics, num_frames)
self._metrics_updated = False
return num_frames

# TODO: Get rid of these functions and simplify the implementation of this class.

def _get_metric(self, frame_number: int, metric_key: str) -> Optional[Any]:
if self._metric_exists(frame_number, metric_key):
return self._frame_metrics[frame_number][metric_key]
Expand Down
4 changes: 1 addition & 3 deletions tests/test_backwards_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ def test_backwards_compatibility_with_stats(test_video_file: str):
"""Runs equivalent code to `tests/api_test.py` from v0.5 twice to also
exercise loading a statsfile from disk."""
stats_file_path = test_video_file + '.csv'
try:
if os.path.exists(stats_file_path):
os.remove(stats_file_path)
except FileNotFoundError:
pass
scenes = validate_backwards_compatibility(test_video_file, stats_file_path)
assert scenes
assert os.path.exists(stats_file_path)
Expand Down
45 changes: 12 additions & 33 deletions tests/test_stats_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,11 @@
from scenedetect.detectors import ContentDetector

from scenedetect.stats_manager import StatsManager
from scenedetect.stats_manager import FrameMetricRegistered
from scenedetect.stats_manager import StatsFileCorrupt

from scenedetect.stats_manager import COLUMN_NAME_FRAME_NUMBER
from scenedetect.stats_manager import COLUMN_NAME_TIMECODE

# TODO(v1.0): Need to add test case which raises scenedetect.stats_manager.FrameMetricNotRegistered.

# TODO(v1.0): use https://docs.pytest.org/en/6.2.x/tmpdir.html
TEST_STATS_FILES = ['TEST_STATS_FILE'] * 4
TEST_STATS_FILES = [
Expand Down Expand Up @@ -77,8 +74,6 @@ def test_metrics():
stats.register_metrics(metric_keys)

assert not stats.is_save_required()
with pytest.raises(FrameMetricRegistered):
stats.register_metrics(metric_keys)

assert not stats.metrics_exist(frame_key, metric_keys)
assert stats.get_metrics(frame_key, metric_keys) == [None] * len(metric_keys)
Expand All @@ -101,28 +96,13 @@ def test_detector_metrics(test_video_file):
video = VideoStreamCv2(test_video_file)
stats_manager = StatsManager()
scene_manager = SceneManager(stats_manager)

assert not stats_manager._registered_metrics
scene_manager.add_detector(ContentDetector())
# add_detector should trigger register_metrics in the StatsManager.
assert stats_manager._registered_metrics

video_fps = video.frame_rate
duration = FrameTimecode('00:00:20', video_fps)

duration = FrameTimecode('00:00:05', video_fps)
scene_manager.auto_downscale = True
scene_manager.detect_scenes(video=video, duration=duration)

# Check that metrics were written to the StatsManager.
assert stats_manager._frame_metrics
frame_key = min(stats_manager._frame_metrics.keys())
assert stats_manager._frame_metrics[frame_key]
assert stats_manager.metrics_exist(frame_key, list(stats_manager._registered_metrics))

# Since we only added 1 detector, the number of metrics from get_metrics
# should equal the number of metric keys in _registered_metrics.
assert len(stats_manager.get_metrics(frame_key, list(
stats_manager._registered_metrics))) == len(stats_manager._registered_metrics)
assert stats_manager.get_metrics(0, ContentDetector.METRIC_KEYS)


def test_load_empty_stats():
Expand Down Expand Up @@ -178,27 +158,26 @@ def test_save_load_from_video(test_video_file):
scene_manager.add_detector(ContentDetector())

video_fps = video.frame_rate
duration = FrameTimecode('00:00:20', video_fps)
duration = FrameTimecode('00:00:05', video_fps)

scene_manager.auto_downscale = True
scene_manager.detect_scenes(video, duration=duration)

stats_manager.save_to_csv(csv_file=TEST_STATS_FILES[0])

metrics = stats_manager.metric_keys

stats_manager_new = StatsManager()

stats_manager_new.load_from_csv(TEST_STATS_FILES[0])

# Choose the first available frame key and compare all metrics in both.
frame_key = min(stats_manager._frame_metrics.keys())
metric_keys = list(stats_manager._registered_metrics)

assert stats_manager.metrics_exist(frame_key, metric_keys)
orig_metrics = stats_manager.get_metrics(frame_key, metric_keys)
new_metrics = stats_manager_new.get_metrics(frame_key, metric_keys)

for i, metric_val in enumerate(orig_metrics):
assert metric_val == pytest.approx(new_metrics[i])
# Compare the first 5 frames. Frame 0 won't have any metrics for this detector.
for frame in range(1, 5 + 1):
assert stats_manager.metrics_exist(frame, metrics)
orig_metrics = stats_manager.get_metrics(frame, metrics)
new_metrics = stats_manager_new.get_metrics(frame, metrics)
for i, metric_val in enumerate(orig_metrics):
assert metric_val == pytest.approx(new_metrics[i])


def test_load_corrupt_stats():
Expand Down
7 changes: 6 additions & 1 deletion website/pages/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ Releases
- [bugfix] Fix `AttributeError` thrown when accessing `aspect_ratio` on certain videos using `VideoStreamAv` [#355](https://github.com/Breakthrough/PySceneDetect/issues/355)
- [bugfix] Fix circular imports due to partially initialized module for some development environments [#350](https://github.com/Breakthrough/PySceneDetect/issues/350)
- [feature] Add `output_dir` argument to split_video_* functions to customize output directory [#298](https://github.com/Breakthrough/PySceneDetect/issues/298)
- [feature] Add `formatter` argument to split_video_ffmpeg to customize filename generation [#359](https://github.com/Breakthrough/PySceneDetect/issues/359)
- [feature] Add `formatter` argument to split_video_ffmpeg to customize filename generation [#359](https://github.com/Breakthrough/PySceneDetect/issues/359)
- [improvement] `scenedetect.stats_manager` module improvements:
    - The `StatsManager.register_metrics()` method no longer throws any exceptions
    - Add `StatsManager.metric_keys` property to query registered metric keys
    - Deprecate `FrameMetricRegistered` and `FrameMetricNotRegistered` exceptions (no longer used)


### 0.6.2 (July 23, 2023)
Expand Down

0 comments on commit e99a1bc

Please sign in to comment.