Skip to content

Commit

Permalink
Fix core utils (#1967)
Browse files Browse the repository at this point in the history
* Empty-Commit

* Cleanup

* Updated to_tuple

* Fix in LabelEncoder

* Do not encode numerical labels

* Fix in encode labels

* Fix for mixed labels

* Cleanup
  • Loading branch information
ternaus authored Oct 6, 2024
1 parent a4e50e2 commit b8648ff
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 80 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-added-large-files
- id: check-ast
Expand Down Expand Up @@ -51,7 +51,7 @@ repos:
files: setup.py
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.8
rev: v0.6.9
hooks:
# Run the linter.
- id: ruff
Expand Down
4 changes: 2 additions & 2 deletions albumentations/augmentations/blur/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def process_blur_limit(value: ScaleIntType, info: ValidationInfo, min_value: flo
for v in result:
if v != 0 and v % 2 != 1:
raise ValueError(f"Blur limit must be 0 or odd. Got: {result}")
return cast(Tuple[int, int], result)
return result


class BlurInitSchema(BaseTransformInitSchema):
Expand Down Expand Up @@ -171,7 +171,7 @@ class InitSchema(BaseTransformInitSchema):

@model_validator(mode="after")
def process_blur(self) -> Self:
self.blur_limit = cast(Tuple[int, int], to_tuple(self.blur_limit, 3))
self.blur_limit = to_tuple(self.blur_limit, 3)

if self.allow_shifted and isinstance(self.blur_limit, tuple) and any(x % 2 != 1 for x in self.blur_limit):
raise ValueError(f"Blur limit must be odd when centered=True. Got: {self.blur_limit}")
Expand Down
4 changes: 2 additions & 2 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2042,9 +2042,9 @@ class InitSchema(BaseTransformInitSchema):
@classmethod
def validate_num_bits(cls, num_bits: Any) -> tuple[int, int] | list[tuple[int, int]]:
if isinstance(num_bits, int):
return cast(Tuple[int, int], to_tuple(num_bits, num_bits))
return to_tuple(num_bits, num_bits)
if isinstance(num_bits, Sequence) and len(num_bits) == NUM_BITS_ARRAY_LENGTH:
return [cast(Tuple[int, int], to_tuple(i, 0)) for i in num_bits]
return [to_tuple(i, 0) for i in num_bits]
return cast(Tuple[int, int], to_tuple(num_bits, 0))

def __init__(
Expand Down
14 changes: 11 additions & 3 deletions albumentations/core/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from typing import Tuple
from typing import Tuple, overload

import cv2
from pydantic import Field
from pydantic.functional_validators import AfterValidator
from typing_extensions import Annotated

from albumentations.core.types import NumericType, ScalarType, ScaleType
from albumentations.core.types import NumericType, ScalarType, ScaleFloatType, ScaleIntType, ScaleType
from albumentations.core.utils import to_tuple

valid_interpolations = {
Expand Down Expand Up @@ -68,7 +68,15 @@ def float2int(value: tuple[float, float]) -> tuple[int, int]:
NonNegativeIntRangeType = Annotated[ScaleType, AfterValidator(process_non_negative_range), AfterValidator(float2int)]


def create_symmetric_range(value: ScaleType) -> tuple[float, float]:
@overload
def create_symmetric_range(value: ScaleIntType) -> tuple[int, int]: ...


@overload
def create_symmetric_range(value: ScaleFloatType) -> tuple[float, float]: ...


def create_symmetric_range(value: ScaleType) -> tuple[int, int] | tuple[float, float]:
return to_tuple(value)


Expand Down
1 change: 1 addition & 0 deletions albumentations/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

ScaleIntType = Union[int, Tuple[int, int]]
ScaleFloatType = Union[float, Tuple[float, float]]

ScaleType = Union[ScaleIntType, ScaleFloatType]

NumType = Union[ScalarType, np.ndarray]
Expand Down
Loading

0 comments on commit b8648ff

Please sign in to comment.