Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Template transform moved to Domain Adaptation #1982

Merged
merged 3 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions albumentations/augmentations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from .blur.transforms import *
from .crops.functional import *
from .crops.transforms import *
from .domain_adaptation import *
from .domain_adaptation_functional import *
from .domain_adaptation.functional import *
from .domain_adaptation.transforms import *
from .dropout.channel_dropout import *
from .dropout.coarse_dropout import *
from .dropout.functional import *
Expand Down
2 changes: 2 additions & 0 deletions albumentations/augmentations/domain_adaptation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .functional import *
from .transforms import *
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@

import cv2
import numpy as np
from pydantic import AfterValidator, field_validator
from albucore import add_weighted, get_num_channels
from pydantic import AfterValidator, Field, field_validator
from typing_extensions import Annotated

from albumentations.augmentations.domain_adaptation_functional import (
import albumentations.augmentations.geometric.functional as fgeometric
from albumentations.augmentations.domain_adaptation.functional import (
adapt_pixel_distribution,
apply_histogram,
fourier_domain_adaptation,
)
from albumentations.augmentations.utils import read_rgb_image
from albumentations.core.composition import Compose
from albumentations.core.pydantic import ZeroOneRangeType, check_01, nondecreasing
from albumentations.core.transforms_interface import BaseTransformInitSchema, ImageOnlyTransform
from albumentations.core.transforms_interface import BaseTransformInitSchema, BasicTransform, ImageOnlyTransform
from albumentations.core.types import ScaleFloatType

# Names exported from this module when it is star-imported
# (`from ... import *`); each entry must match a class defined in this file.
__all__ = [
"HistogramMatching",
"FDA",
"PixelDistributionAdaptation",
"TemplateTransform",
]

# NOTE(review): the code that reads this constant is outside the visible hunk;
# presumably it is the upper bound accepted for FDA's beta limit — confirm.
MAX_BETA_LIMIT = 0.5
Expand Down Expand Up @@ -339,3 +343,194 @@ def get_transform_init_args_names(self) -> tuple[str, str, str, str]:
def to_dict_private(self) -> dict[str, Any]:
    """Refuse to serialize this transform.

    Raises:
        NotImplementedError: Always; no dictionary is ever returned.
    """
    raise NotImplementedError("PixelDistributionAdaptation can not be serialized.")


class TemplateTransform(ImageOnlyTransform):
    """Apply blending of input image with specified templates.

    This transform overlays one or more template images onto the input image using alpha blending.
    It allows for creating complex composite images or simulating various visual effects.

    Args:
        templates (numpy array | list[np.ndarray] | tuple[np.ndarray, ...]): Images to use as templates
            for the transform.
            If a single numpy array is provided, it will be used as the only template.
            If a list or tuple of numpy arrays is provided, one will be randomly chosen for each application.
            An empty list or tuple is rejected during validation.

        img_weight (tuple[float, float] | float): Weight of the original image in the blend.
            If a single float, that value will always be used.
            If a tuple (min, max), the weight will be randomly sampled from the range [min, max] for each
            application.
            To use a fixed weight, use (weight, weight).
            Default: (0.5, 0.5).

        template_weight (None): Deprecated and ignored. The template weight is always computed
            as (1 - img_weight).

        template_transform (A.Compose | None): A composition of Albumentations transforms to apply to the template
            before blending.
            This should be an instance of A.Compose containing one or more Albumentations transforms.
            Default: None.

        name (str | None): Name of the transform instance. Used for serialization purposes.
            Default: None.

        always_apply (bool | None): Forwarded unchanged to the base class constructor. Default: None.

        p (float): Probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32

    Number of channels:
        Any

    Note:
        - The template(s) must have the same number of channels as the input image or be single-channel.
        - If a single-channel template is used with a multi-channel image, the template will be replicated across
          all channels.
        - The template(s) must have the same dtype as the input image.
        - The template(s) will be resized to match the input image size if they differ.
        - To make this transform serializable, provide a name when initializing it.

    Mathematical Formulation:
        Given:
        - I: Input image
        - T: Template image
        - w_i: Weight of input image (sampled from img_weight)

        The blended image B is computed as:

        B = w_i * I + (1 - w_i) * T

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
        >>> template = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)

        # Apply template transform with a single template
        >>> transform = A.TemplateTransform(templates=template, name="my_template_transform", p=1.0)
        >>> blended_image = transform(image=image)['image']

        # Apply template transform with multiple templates and custom weights
        >>> templates = [np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8) for _ in range(3)]
        >>> transform = A.TemplateTransform(
        ...     templates=templates,
        ...     img_weight=(0.3, 0.7),
        ...     name="multi_template_transform",
        ...     p=1.0
        ... )
        >>> blended_image = transform(image=image)['image']

        # Apply template transform with additional transforms on the template
        >>> template_transform = A.Compose([A.RandomBrightnessContrast(p=1)])
        >>> transform = A.TemplateTransform(
        ...     templates=template,
        ...     img_weight=0.6,
        ...     template_transform=template_transform,
        ...     name="transformed_template",
        ...     p=1.0
        ... )
        >>> blended_image = transform(image=image)['image']

    References:
        - Alpha compositing: https://en.wikipedia.org/wiki/Alpha_compositing
        - Image blending: https://en.wikipedia.org/wiki/Image_blending
    """

    class InitSchema(BaseTransformInitSchema):
        templates: np.ndarray | Sequence[np.ndarray]
        img_weight: ZeroOneRangeType
        # Kept only so existing callers that still pass it get a deprecation notice.
        template_weight: ZeroOneRangeType | None = Field(
            deprecated="Template_weight is deprecated. Computed automatically as (1 - img_weight)",
        )
        template_transform: Compose | BasicTransform | None = None
        name: str | None

        @field_validator("templates")
        @classmethod
        def validate_templates(cls, v: np.ndarray | Sequence[np.ndarray]) -> list[np.ndarray]:
            """Normalize ``templates`` to a non-empty list of numpy arrays.

            Args:
                v: A single array, or a list/tuple of arrays.

            Returns:
                A list of arrays; a single array is wrapped into a one-element list.

            Raises:
                ValueError: If the list/tuple is empty or contains a non-array item.
                TypeError: If ``v`` is neither an array nor a list/tuple.
            """
            if isinstance(v, np.ndarray):
                return [v]
            # Tuples are accepted as well, matching the ``Sequence`` annotation on the field.
            if isinstance(v, (list, tuple)):
                if not v:
                    # Fail here instead of with an IndexError from random.choice at call time.
                    msg = "At least one template must be provided."
                    raise ValueError(msg)
                if not all(isinstance(item, np.ndarray) for item in v):
                    msg = "All templates must be numpy arrays."
                    raise ValueError(msg)
                return v if isinstance(v, list) else list(v)
            msg = "Templates must be a numpy array or a list of numpy arrays."
            raise TypeError(msg)

    def __init__(
        self,
        templates: np.ndarray | list[np.ndarray],
        img_weight: ScaleFloatType = (0.5, 0.5),
        template_weight: None = None,
        template_transform: Compose | BasicTransform | None = None,
        name: str | None = None,
        always_apply: bool | None = None,
        p: float = 0.5,
    ):
        super().__init__(p=p, always_apply=always_apply)
        self.templates = templates
        self.img_weight = cast(Tuple[float, float], img_weight)
        self.template_transform = template_transform
        self.name = name
        # template_weight is intentionally not stored: it is deprecated and unused.

    def apply(
        self,
        img: np.ndarray,
        template: np.ndarray,
        img_weight: float,
        **params: Any,
    ) -> np.ndarray:
        """Blend ``img`` with ``template`` as ``img_weight * img + (1 - img_weight) * template``.

        Args:
            img: Input image.
            template: Template prepared by ``get_params_dependent_on_data``.
            img_weight: Weight of the input image in the blend.
            **params: Additional parameters; unused here.

        Returns:
            The blended image. For weights of exactly 0 or 1 the template or the
            image, respectively, is returned as-is without blending.
        """
        if img_weight == 0:
            return template
        if img_weight == 1:
            return img

        return add_weighted(img, img_weight, template, 1 - img_weight)

    def get_params(self) -> dict[str, float]:
        """Sample the image weight for one application from ``self.img_weight``."""
        return {
            "img_weight": random.uniform(*self.img_weight),
        }

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        """Pick a template and make it compatible with the input image.

        Args:
            params: Parameters sampled so far; unused here.
            data: Input data; the image is read from ``data["image"]`` or,
                if that key is absent, from ``data["images"][0]``.

        Returns:
            A dict with the key ``"template"`` holding an array of the same shape as the image.

        Raises:
            ValueError: If the template's channel count is neither 1 nor the image's,
                or if the template and image dtypes differ.
        """
        img = data["image"] if "image" in data else data["images"][0]
        img_channels = get_num_channels(img)

        template = random.choice(self.templates)

        if self.template_transform is not None:
            template = self.template_transform(image=template)["image"]

        template_channels = get_num_channels(template)
        if template_channels not in (1, img_channels):
            msg = (
                "Template must be a single channel or "
                "has the same number of channels as input "
                f"image ({img_channels}), got {template_channels}"
            )
            raise ValueError(msg)

        if template.dtype != img.dtype:
            msg = "Image and template must be the same image type"
            raise ValueError(msg)

        if img.shape[:2] != template.shape[:2]:
            template = fgeometric.resize(template, img.shape[:2], interpolation=cv2.INTER_AREA)

        # Replicate a single-channel template across all channels of the image.
        if get_num_channels(template) == 1 and img_channels > 1:
            template = np.stack((template,) * img_channels, axis=-1)

        # in order to support grayscale image with dummy dim
        template = template.reshape(img.shape)

        return {"template": template}

    @classmethod
    def is_serializable(cls) -> bool:
        """Return False.

        NOTE(review): ``to_dict_private`` still produces a dict when ``name`` is set;
        how the library combines the two is outside this view — confirm.
        """
        return False

    def to_dict_private(self) -> dict[str, Any]:
        """Return a dict holding only the class full name and the instance ``name``.

        Raises:
            ValueError: If the transform was created without a ``name``.
        """
        if self.name is None:
            msg = (
                "To make a TemplateTransform serializable you should provide the `name` argument, "
                "e.g. `TemplateTransform(name='my_transform', ...)`."
            )
            raise ValueError(msg)
        return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name}
Loading
Loading