Skip to content

Commit

Permalink
Fix core/keypoints_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Dipet committed Mar 14, 2024
1 parent 86ed8c4 commit fea808f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 45 deletions.
75 changes: 33 additions & 42 deletions albumentations/core/keypoints_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import math
import warnings
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

from .transforms_interface import KeypointsArray, KeypointsInternalType, KeypointType
from .types import KeypointsArray, KeypointsInternalType, KeypointType
from .utils import DataProcessor, InternalDtype, Params, ensure_internal_format

__all__ = [
Expand Down Expand Up @@ -93,23 +92,24 @@ class KeypointParams(Params):
angle_in_degrees (bool): angle in degrees or radians in 'xya', 'xyas', 'xysa' keypoints
check_each_transform (bool): if `True`, then keypoints will be checked after each dual transform.
Default: `True`
"""

def __init__(
self,
format: str, # skipcq: PYL-W0622
format: str,
label_fields: Optional[Sequence[str]] = None,
remove_invisible: bool = True,
angle_in_degrees: bool = True,
check_each_transform: bool = True,
):
super(KeypointParams, self).__init__(format, label_fields)
super().__init__(format, label_fields)
self.remove_invisible = remove_invisible
self.angle_in_degrees = angle_in_degrees
self.check_each_transform = check_each_transform

def _to_dict(self) -> Dict[str, Any]:
data = super(KeypointParams, self)._to_dict()
def to_dict_private(self) -> Dict[str, Any]:
data = super().to_dict_private()
data.update(
{
"remove_invisible": self.remove_invisible,
Expand All @@ -130,7 +130,6 @@ def get_class_fullname(cls) -> str:

class KeypointsProcessor(DataProcessor):
def __init__(self, params: KeypointParams, additional_targets: Optional[Dict[str, str]] = None):
assert isinstance(params, KeypointParams)
super().__init__(params, additional_targets)

def convert_to_internal_type(self, data):
Expand All @@ -157,35 +156,25 @@ def default_data_name(self) -> str:
return "keypoints"

def ensure_data_valid(self, data: Dict[str, Any]) -> None:
if self.params.label_fields:
if not all(i in data for i in self.params.label_fields):
raise ValueError(
"Your 'label_fields' are not valid - them must have same names as params in "
"'keypoint_params' dict"
)

def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
# IAA-based augmentations supports only transformation of xy keypoints.
# If your keypoints formats is other than 'xy' we emit warning to let user
# be aware that angle and size will not be modified.

try:
from albumentations.imgaug.transforms import DualIAATransform
except ImportError:
# imgaug is not installed so we skip imgaug checks.
return

if self.params.format is not None and self.params.format != "xy":
for transform in transforms:
if isinstance(transform, DualIAATransform):
warnings.warn(
f"{transform.__class__.__name__} transformation supports only 'xy' keypoints "
f"augmentation. You have '{self.params.format}' keypoints format. Scale "
"and angle WILL NOT BE transformed."
)
break

def filter(self, data: KeypointsArray, rows: int, cols: int, target_name: str):
if self.params.label_fields and not all(i in data for i in self.params.label_fields):
msg = "Your 'label_fields' are not valid - them must have same names as params in " "'keypoint_params' dict"
raise ValueError(msg)

def filter(self, data: KeypointsArray, rows: int, cols: int) -> KeypointsArray:
"""The function filters a sequence of data based on the number of rows and columns, and returns a
sequence of keypoints.
Args:
data: The `data` parameter is a sequence of sequences. Each inner sequence represents a
set of keypoints
rows: The `rows` parameter represents the number of rows in the data matrix. It specifies
the number of rows that will be used for filtering the keypoints
cols: The parameter "cols" represents the number of columns in the grid that the
keypoints will be filtered on
Returns:
KeypointsArray, a sequence of KeypointType objects.
"""
self.params: KeypointParams
data = filter_keypoints(data, rows, cols, remove_invisible=self.params.remove_invisible)
return data
Expand All @@ -194,23 +183,25 @@ def check(self, data: KeypointsArray, rows: int, cols: int) -> None:
check_keypoints(data, rows, cols)

def convert_from_albumentations(self, data: KeypointsArray, rows: int, cols: int):
params = self.params
return convert_keypoints_from_albumentations(
data,
self.params.format,
params.format,
rows,
cols,
check_validity=self.params.remove_invisible,
angle_in_degrees=self.params.angle_in_degrees,
check_validity=params.remove_invisible,
angle_in_degrees=params.angle_in_degrees,
)

def convert_to_albumentations(self, data: KeypointsArray, rows: int, cols: int):
params = self.params
return convert_keypoints_to_albumentations(
data,
self.params.format,
params.format,
rows,
cols,
check_validity=self.params.remove_invisible,
angle_in_degrees=self.params.angle_in_degrees,
check_validity=params.remove_invisible,
angle_in_degrees=params.angle_in_degrees,
)


Expand Down
3 changes: 0 additions & 3 deletions albumentations/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ def default_data_name(self) -> str:
def ensure_data_valid(self, data: Dict[str, Any]) -> None:
pass

def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
pass

@abstractmethod
def convert_to_internal_type(self, data: Any) -> InternalDtype: # type: ignore[type-var]
raise NotImplementedError
Expand Down

0 comments on commit fea808f

Please sign in to comment.