Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove scikit image #1988

Merged
merged 6 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 63 additions & 16 deletions albumentations/augmentations/domain_adaptation/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

import cv2
import numpy as np
from albucore import add_weighted, clip, clipped, from_float, get_num_channels, preserve_channel_dim, to_float
from skimage.exposure import match_histograms
from albucore import add_weighted, clip, clipped, from_float, get_num_channels, preserve_channel_dim, to_float, uint8_io
from typing_extensions import Protocol

import albumentations.augmentations.geometric.functional as fgeometric
from albumentations.augmentations.utils import PCA
from albumentations.core.types import MONO_CHANNEL_DIMENSIONS, NUM_MULTI_CHANNEL_DIMENSIONS
from albumentations.core.types import MONO_CHANNEL_DIMENSIONS

__all__ = [
"fourier_domain_adaptation",
Expand Down Expand Up @@ -345,16 +344,9 @@ def apply_histogram(img: np.ndarray, reference_image: np.ndarray, blend_ratio: f
Note:
- If the input and reference images have different sizes, the reference image
will be resized to match the input image's dimensions.
- The function uses `match_histograms` from scikit-image for the core histogram matching.
- The function uses a custom implementation of histogram matching based on OpenCV and NumPy.
- The @clipped and @preserve_channel_dim decorators ensure the output is within
the valid range and maintains the original number of dimensions.

Example:
>>> import numpy as np
>>> from albumentations.augmentations.domain_adaptation_functional import apply_histogram
>>> input_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> reference_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> result = apply_histogram(input_image, reference_image, blend_ratio=0.7)
"""
# Resize reference image only if necessary
if img.shape[:2] != reference_image.shape[:2]:
Expand All @@ -364,11 +356,66 @@ def apply_histogram(img: np.ndarray, reference_image: np.ndarray, blend_ratio: f
reference_image = np.squeeze(reference_image)

# Match histograms between the images
matched = match_histograms(
img,
reference_image,
channel_axis=2 if img.ndim == NUM_MULTI_CHANNEL_DIMENSIONS and img.shape[2] > 1 else None,
)
matched = match_histograms(img, reference_image)

# Blend the original image and the matched image
return add_weighted(matched, blend_ratio, img, 1 - blend_ratio)


@uint8_io
@preserve_channel_dim
def match_histograms(image: np.ndarray, reference: np.ndarray) -> np.ndarray:
"""Adjust an image so that its cumulative histogram matches that of another.

The adjustment is applied separately for each channel.

Args:
image: Input image. Can be gray-scale or in color.
reference: Image to match histogram of. Must have the same number of channels as image.
channel_axis: If None, the image is assumed to be a grayscale (single channel) image.
Otherwise, this parameter indicates which axis of the array corresponds to channels.

Returns:
np.ndarray: Transformed input image.

Raises:
ValueError: Thrown when the number of channels in the input image and the reference differ.
"""
if reference.dtype != np.uint8:
reference = from_float(reference, np.uint8)

if image.ndim != reference.ndim:
raise ValueError("Image and reference must have the same number of dimensions.")

# Expand dimensions for grayscale images
if image.ndim == 2: # noqa: PLR2004
image = np.expand_dims(image, axis=-1)
if reference.ndim == 2: # noqa: PLR2004
reference = np.expand_dims(reference, axis=-1)

matched = np.empty(image.shape, dtype=np.uint8)

num_channels = image.shape[-1]

for channel in range(num_channels):
matched_channel = _match_cumulative_cdf(image[..., channel], reference[..., channel]).astype(np.uint8)
matched[..., channel] = matched_channel

return matched


def _match_cumulative_cdf(source: np.ndarray, template: np.ndarray) -> np.ndarray:
src_lookup = source.reshape(-1)
src_counts = np.bincount(src_lookup)
tmpl_counts = np.bincount(template.reshape(-1))

# omit values where the count was 0
tmpl_values = np.nonzero(tmpl_counts)[0]
tmpl_counts = tmpl_counts[tmpl_values]

# calculate normalized quantiles for each array
src_quantiles = np.cumsum(src_counts) / source.size
tmpl_quantiles = np.cumsum(tmpl_counts) / template.size

interp_a_values = np.interp(src_quantiles, tmpl_quantiles, tmpl_values)
return interp_a_values[src_lookup].reshape(source.shape).astype(np.uint8)
44 changes: 44 additions & 0 deletions albumentations/augmentations/dropout/functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import cv2
import numpy as np
from albucore import MAX_VALUES_BY_DTYPE, is_grayscale_image, preserve_channel_dim
from typing_extensions import Literal
Expand Down Expand Up @@ -410,3 +411,46 @@ def mask_dropout_bboxes(
def mask_dropout_keypoints(keypoints: np.ndarray, dropout_mask: np.ndarray) -> np.ndarray:
keep_indices = np.array([not dropout_mask[int(kp[1]), int(kp[0])] for kp in keypoints])
return keypoints[keep_indices]


def label(mask: np.ndarray, return_num: bool = False, connectivity: int = 2) -> np.ndarray | tuple[np.ndarray, int]:
ternaus marked this conversation as resolved.
Show resolved Hide resolved
"""Label connected regions of an integer array.

This function uses OpenCV's connectedComponents under the hood but mimics
the behavior of scikit-image's label function.

Args:
mask (np.ndarray): The array to label. Must be of integer type.
return_num (bool): If True, return the number of labels (default: False).
connectivity (int): Maximum number of orthogonal hops to consider a pixel/voxel
as a neighbor. Accepted values are 1 or 2. Default is 2.

Returns:
np.ndarray | tuple[np.ndarray, int]: Labeled array, where all connected regions are
assigned the same integer value. If return_num is True, it also returns the number of labels.
"""
# Create a copy of the original mask
labeled = np.zeros_like(mask, dtype=np.int32)

# Get unique non-zero values from the original mask
unique_values = np.unique(mask[mask != 0])

# Label each unique value separately
next_label = 1
for value in unique_values:
binary_mask = (mask == value).astype(np.uint8)

# Set connectivity for OpenCV (4 or 8)
cv2_connectivity = 4 if connectivity == 1 else 8

# Use OpenCV's connectedComponents
num_labels, labels = cv2.connectedComponents(binary_mask, connectivity=cv2_connectivity)

# Assign new labels
for i in range(1, num_labels):
labeled[labels == i] = next_label
next_label += 1

num_labels = next_label - 1

return (labeled, num_labels) if return_num else labeled
3 changes: 1 addition & 2 deletions albumentations/augmentations/dropout/mask_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import cv2
import numpy as np
from skimage.measure import label

import albumentations.augmentations.dropout.functional as fdropout
from albumentations.core.bbox_utils import BboxProcessor, denormalize_bboxes, normalize_bboxes
Expand Down Expand Up @@ -98,7 +97,7 @@ def targets_as_params(self) -> list[str]:
def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
mask = data["mask"]

label_image, num_labels = label(mask, return_num=True)
label_image, num_labels = fdropout.label(mask, return_num=True)

if num_labels == 0:
dropout_mask = None
Expand Down
66 changes: 64 additions & 2 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,7 @@ def adjust_hue_torchvision(img: np.ndarray, factor: float) -> np.ndarray:
return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)


@uint8_io
@preserve_channel_dim
def superpixels(
image: np.ndarray,
Expand All @@ -1505,11 +1506,10 @@ def superpixels(
new_height, new_width = int(height * scale), int(width * scale)
image = fgeometric.resize(image, (new_height, new_width), interpolation)

segments = skimage.segmentation.slic(
segments = slic(
image,
n_segments=n_segments,
compactness=10,
channel_axis=-1 if image.ndim > MONO_CHANNEL_DIMENSIONS else None,
)

min_value = 0
Expand Down Expand Up @@ -1937,3 +1937,65 @@ def swap_tiles_on_keypoints(
new_keypoints[not_in_any_tile] = keypoints[not_in_any_tile]

return new_keypoints


def slic(image: np.ndarray, n_segments: int, compactness: float = 10.0, max_iterations: int = 10) -> np.ndarray:
ternaus marked this conversation as resolved.
Show resolved Hide resolved
"""Simple Linear Iterative Clustering (SLIC) superpixel segmentation using OpenCV and NumPy.

Args:
image (np.ndarray): Input image (2D or 3D numpy array).
n_segments (int): Approximate number of superpixels to generate.
compactness (float): Balance between color proximity and space proximity.
max_iterations (int): Maximum number of iterations for k-means.

Returns:
np.ndarray: Segmentation mask where each superpixel has a unique label.
"""
if image.ndim == MONO_CHANNEL_DIMENSIONS:
image = image[..., np.newaxis]

height, width = image.shape[:2]
num_pixels = height * width

# Normalize image to [0, 1] range
image_normalized = image.astype(np.float32) / np.max(image)

# Initialize cluster centers
grid_step = int((num_pixels / n_segments) ** 0.5)
x_range = np.arange(grid_step // 2, width, grid_step)
y_range = np.arange(grid_step // 2, height, grid_step)
centers = np.array([(x, y) for y in y_range for x in x_range if x < width and y < height])

# Initialize labels and distances
labels = -1 * np.ones((height, width), dtype=np.int32)
distances = np.full((height, width), np.inf)

for _ in range(max_iterations):
for i, center in enumerate(centers):
y, x = int(center[1]), int(center[0])

# Define the neighborhood
y_low, y_high = max(0, y - grid_step), min(height, y + grid_step + 1)
x_low, x_high = max(0, x - grid_step), min(width, x + grid_step + 1)

# Compute distances
crop = image_normalized[y_low:y_high, x_low:x_high]
color_diff = crop - image_normalized[y, x]
color_distance = np.sum(color_diff**2, axis=-1)

yy, xx = np.ogrid[y_low:y_high, x_low:x_high]
spatial_distance = ((yy - y) ** 2 + (xx - x) ** 2) / (grid_step**2)

distance = color_distance + compactness * spatial_distance

mask = distance < distances[y_low:y_high, x_low:x_high]
distances[y_low:y_high, x_low:x_high][mask] = distance[mask]
labels[y_low:y_high, x_low:x_high][mask] = i

# Update centers
for i in range(len(centers)):
mask = labels == i
if np.any(mask):
centers[i] = np.mean(np.argwhere(mask), axis=0)[::-1]

return labels
6 changes: 0 additions & 6 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3944,8 +3944,6 @@ def get_transform_init_args_names(self) -> tuple[str, str]:
class Superpixels(ImageOnlyTransform):
"""Transform images partially/completely to their superpixel representation.

This implementation uses skimage's version of the SLIC (Simple Linear Iterative Clustering) algorithm.

Args:
p_replace (tuple[float, float] | float): Defines for any segment the probability that the pixels within that
segment are replaced by their average color (otherwise, the pixels are not changed).
Expand Down Expand Up @@ -4030,10 +4028,6 @@ class Superpixels(ImageOnlyTransform):
... p=1.0
... )
>>> augmented_image = transform(image=image)['image']

References:
- SLIC Superpixels: https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic
- "SLIC Superpixels Compared to State-of-the-art Superpixel Methods" by Radhakrishna Achanta, et al.
"""

class InitSchema(BaseTransformInitSchema):
Expand Down
64 changes: 64 additions & 0 deletions tests/test_domain_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import pytest

from albumentations.augmentations.domain_adaptation.functional import PCA, MinMaxScaler, StandardScaler, apply_histogram
import numpy as np
import pytest
from skimage.exposure import match_histograms as skimage_match_histograms
from skimage.metrics import structural_similarity as ssim
from albumentations.augmentations.domain_adaptation.functional import match_histograms as our_match_histograms


@pytest.mark.parametrize(
Expand Down Expand Up @@ -356,3 +361,62 @@ def test_apply_histogram_identity():
result = apply_histogram(img, img, blend_ratio)

np.testing.assert_array_almost_equal(result, img)

def generate_random_image(shape, dtype=np.uint8):
ternaus marked this conversation as resolved.
Show resolved Hide resolved
if dtype == np.uint8:
return np.random.randint(0, 256, shape, dtype=dtype)
else: # Assume float32
return np.random.rand(*shape).astype(dtype)


@pytest.mark.parametrize("shape, channel_axis", [
((100, 100, 1), -1), # Grayscale uint8
((100, 100, 3), -1), # RGB uint8
((100, 100, 4), -1), # RGBA uint8
])
ternaus marked this conversation as resolved.
Show resolved Hide resolved
def test_match_histograms(shape, channel_axis):
dtype = np.uint8
source = generate_random_image(shape, dtype)
reference = generate_random_image(shape, dtype)

our_result = our_match_histograms(source, reference)
skimage_result = skimage_match_histograms(source, reference, channel_axis=channel_axis)

# Check shape and dtype
assert our_result.shape == skimage_result.shape
assert our_result.dtype == source.dtype

# Compare histograms

for channel in range(shape[channel_axis]):
our_hist, _ = np.histogram(np.take(our_result, channel, axis=channel_axis).ravel(), bins=256, range=(0, 1 if dtype == np.float32 else 255))
skimage_hist, _ = np.histogram(np.take(skimage_result, channel, axis=channel_axis).ravel(), bins=256, range=(0, 1 if dtype == np.float32 else 255))
np.testing.assert_allclose(our_hist, skimage_hist, rtol=1e-5, atol=1)
ternaus marked this conversation as resolved.
Show resolved Hide resolved

# Compare mean and standard deviation
np.testing.assert_allclose(our_result.mean(), skimage_result.mean(), rtol=1e-5)
np.testing.assert_allclose(our_result.std(), skimage_result.std(), rtol=1e-5)

# Compare structural similarity
similarity = ssim(our_result, skimage_result, channel_axis=channel_axis, data_range=255)
assert similarity > 0.99, f"SSIM should be > 0.99, got {similarity}"

# Compare pixel-wise differences
max_diff = np.max(np.abs(our_result.astype(np.float64) - skimage_result.astype(np.float64)))
assert max_diff <= 1e-5, f"Max pixel-wise difference should be <= 1e-5, got {max_diff}"


@pytest.mark.parametrize("shape, dtype", [
((100, 100), np.uint8),
((100, 100, 3), np.uint8),
])
def test_match_histograms_identity(shape, dtype):
image = generate_random_image(shape, dtype)
result = our_match_histograms(image, image)
np.testing.assert_allclose(result, image, rtol=1e-5, atol=1e-8)

def test_match_histograms_different_shapes():
source = generate_random_image((100, 100, 3), np.uint8)
reference = generate_random_image((50, 50, 3), np.uint8)
result = our_match_histograms(source, reference)
assert result.shape == source.shape
ternaus marked this conversation as resolved.
Show resolved Hide resolved
Loading