Skip to content

Commit

Permalink
Highly accurate boundaries segmentation using BASNet to keras 3.0 (Te…
Browse files Browse the repository at this point in the history
…nsorflow backend only) (#1942)

* Keras 3 migration basnet segmentation

* Fix format issues
  • Loading branch information
chunduriv authored Oct 23, 2024
1 parent 07b6a7e commit 9fdad44
Showing 1 changed file with 66 additions and 53 deletions.
119 changes: 66 additions & 53 deletions examples/vision/basnet_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Title: Highly accurate boundaries segmentation using BASNet
Author: [Hamid Ali](https://github.com/hamidriasat)
Date created: 2023/05/30
Last modified: 2023/07/13
Last modified: 2024/10/02
Description: Boundaries aware segmentation model trained on the DUTS dataset.
Accelerator: GPU
"""
Expand Down Expand Up @@ -38,14 +38,16 @@
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from glob import glob
import matplotlib.pyplot as plt

import keras_cv
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, backend
import keras
from keras import layers, ops

"""
## Define Hyperparameters
Expand All @@ -58,10 +60,10 @@
DATA_DIR = "./DUTS-TE/"

"""
## Create TensorFlow Dataset
## Create `PyDataset`s
We will use `load_paths()` to load and split 140 paths into train and validation sets, and
`load_dataset()` to convert paths into `tf.data.Dataset` object.
convert the paths into `PyDataset` objects.
"""


Expand All @@ -72,51 +74,64 @@ def load_paths(path, split_ratio):
return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])


def read_image(path, size, mode):
    """Load one image from disk as a float32 array scaled by 1/255.

    Args:
        path: Location of the image file.
        size: Passed to `keras.utils.load_img` as `target_size`.
        mode: Passed to `keras.utils.load_img` as `color_mode`.

    Returns:
        The pixel array divided by 255 and cast to `np.float32`.
    """
    image = keras.utils.load_img(path, target_size=size, color_mode=mode)
    pixels = keras.utils.img_to_array(image)
    return (pixels / 255.0).astype(np.float32)


def preprocess(x_batch, y_batch, img_size, out_classes):
    """Turn an (image path, mask path) pair into a pair of float32 tensors.

    Args:
        x_batch: Path of the input image.
        y_batch: Path of the matching mask.
        img_size: Side length both files are resized to.
        out_classes: Channel count declared for the mask tensor.

    Returns:
        `(images, masks)` with static shapes `[img_size, img_size, 3]` and
        `[img_size, img_size, out_classes]`.
    """
    target_size = (img_size, img_size)

    def _load_pair(image_path, mask_path):
        # The paths are decoded before use, as they reach this callback as bytes.
        image = read_image(image_path.decode(), target_size, mode="rgb")  # image
        mask = read_image(mask_path.decode(), target_size, mode="grayscale")  # mask
        return image, mask

    images, masks = tf.numpy_function(
        _load_pair, [x_batch, y_batch], [tf.float32, tf.float32]
    )
    # Declare the static shapes explicitly on the tensors returned above.
    images.set_shape([img_size, img_size, 3])
    masks.set_shape([img_size, img_size, out_classes])
    return images, masks


def load_dataset(image_paths, mask_paths, img_size, out_classes, batch, shuffle=True):
    """Build a batched, prefetched `tf.data.Dataset` of (image, mask) pairs.

    Args:
        image_paths: Paths of the input images.
        mask_paths: Paths of the masks, aligned with `image_paths`.
        img_size: Side length every image and mask is resized to.
        out_classes: Channel count declared for the mask tensors.
        batch: Number of samples per batch.
        shuffle: When true, cache the path pairs and shuffle them with a
            buffer of 1000 elements.

    Returns:
        The resulting `tf.data.Dataset`.
    """

    def _load(image_path, mask_path):
        return preprocess(image_path, mask_path, img_size, out_classes)

    pipeline = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        pipeline = pipeline.cache().shuffle(buffer_size=1000)
    return (
        pipeline.map(_load, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch)
        .prefetch(tf.data.AUTOTUNE)
    )
class Dataset(keras.utils.PyDataset):
    """Batched loader of (image, mask) pairs read from disk.

    Args:
        image_paths: Paths of the input images.
        mask_paths: Paths of the masks, aligned index-by-index with
            `image_paths`.
        img_size: Side length every image and mask is resized to.
        out_classes: Number of output classes; stored and passed to
            `preprocess()`, which does not use it.
        batch: Number of samples per batch.
        shuffle: When true, the path pairs are permuted once, at
            construction time (not once per epoch).
        **kwargs: Forwarded as keyword arguments to
            `keras.utils.PyDataset.__init__`.
    """

    def __init__(
        self,
        image_paths,
        mask_paths,
        img_size,
        out_classes,
        batch,
        shuffle=True,
        **kwargs,
    ):
        # Forward the options as keywords. Unpacking with a single `*` would
        # pass only the dictionary *keys* as positional arguments.
        super().__init__(**kwargs)
        if shuffle:
            # Apply one permutation to both lists so pairs stay aligned.
            perm = np.random.permutation(len(image_paths))
            image_paths = [image_paths[i] for i in perm]
            mask_paths = [mask_paths[i] for i in perm]
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.img_size = img_size
        self.out_classes = out_classes
        self.batch_size = batch

    def __len__(self):
        """Return the number of full batches; a trailing partial batch is dropped."""
        return len(self.image_paths) // self.batch_size

    def __getitem__(self, idx):
        """Load batch `idx` and return it as stacked `(images, masks)` arrays."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x, batch_y = [], []
        for i in range(start, stop):
            x, y = self.preprocess(
                self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes
            )
            batch_x.append(x)
            batch_y.append(y)
        # Stack the per-sample arrays along a new leading batch axis.
        return np.stack(batch_x, axis=0), np.stack(batch_y, axis=0)

    def read_image(self, path, size, mode):
        """Load one image as a float32 array scaled by 1/255.

        `size` and `mode` are passed to `keras.utils.load_img` as
        `target_size` and `color_mode`.
        """
        x = keras.utils.load_img(path, target_size=size, color_mode=mode)
        x = keras.utils.img_to_array(x)
        return (x / 255.0).astype(np.float32)

    def preprocess(self, x_batch, y_batch, img_size, out_classes):
        """Read one image path (RGB) and one mask path (grayscale).

        Despite the `_batch` suffix, `__getitem__` calls this with a single
        path per argument. `out_classes` is accepted but not used here.
        """
        size = (img_size, img_size)
        images = self.read_image(x_batch, size, mode="rgb")  # image
        masks = self.read_image(y_batch, size, mode="grayscale")  # mask
        return images, masks


train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)

train_dataset = load_dataset(
train_dataset = Dataset(
train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
)
val_dataset = load_dataset(
val_dataset = Dataset(
val_paths[0], val_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=False
)

print(f"Train Dataset: {train_dataset}")
print(f"Validation Dataset: {val_dataset}")

"""
## Visualize Data
"""
Expand All @@ -133,7 +148,7 @@ def display(display_list):
plt.show()


for image, mask in val_dataset.take(1):
for (image, mask), _ in zip(val_dataset, range(1)):
display([image[0], mask[0]])

"""
Expand Down Expand Up @@ -265,7 +280,7 @@ def basnet_predict(input_shape, out_classes):
decoder_blocks = []
for i in reversed(range(num_stages)):
if i != (num_stages - 1): # Except first, scale other decoder stages.
shape = keras.backend.int_shape(x)
shape = x.shape
x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)

x = layers.concatenate([encoder_blocks[i], x], axis=-1)
Expand Down Expand Up @@ -318,7 +333,7 @@ def basnet_rrm(base_model, out_classes):

# -------------Decoder--------------
for i in reversed(range(num_stages)):
shape = keras.backend.int_shape(x)
shape = x.shape
x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
x = layers.concatenate([encoder_blocks[i], x], axis=-1)
x = convolution_block(x, filters=filters)
Expand All @@ -345,7 +360,7 @@ def basnet(input_shape, out_classes):
# Refinement model.
refine_model = basnet_rrm(predict_model, out_classes)

output = [refine_model.output] # Combine outputs.
output = refine_model.outputs # Combine outputs.
output.extend(predict_model.output)

output = [layers.Activation("sigmoid")(_) for _ in output] # Activations.
Expand Down Expand Up @@ -382,18 +397,16 @@ def calculate_iou(
y_pred,
):
"""Calculate intersection over union (IoU) between images."""
intersection = backend.sum(backend.abs(y_true * y_pred), axis=[1, 2, 3])
union = backend.sum(y_true, [1, 2, 3]) + backend.sum(y_pred, [1, 2, 3])
intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])
union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])
union = union - intersection
return backend.mean(
(intersection + self.smooth) / (union + self.smooth), axis=0
)
return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)

def call(self, y_true, y_pred):
cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)

ssim_value = self.ssim_value(y_true, y_pred, max_val=1)
ssim_loss = backend.mean(1 - ssim_value + self.smooth, axis=0)
ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)

iou_value = self.iou_value(y_true, y_pred)
iou_loss = 1 - iou_value
Expand All @@ -412,7 +425,7 @@ def call(self, y_true, y_pred):
basnet_model.compile(
loss=BasnetLoss(),
optimizer=optimizer,
metrics=[keras.metrics.MeanAbsoluteError(name="mae")],
metrics=[keras.metrics.MeanAbsoluteError(name="mae") for _ in basnet_model.outputs],
)

"""
Expand Down Expand Up @@ -453,6 +466,6 @@ def normalize_output(prediction):
### Make Predictions
"""

for image, mask in val_dataset.take(1):
for (image, mask), _ in zip(val_dataset, range(1)):
pred_mask = basnet_model.predict(image)
display([image[0], mask[0], normalize_output(pred_mask[0][0])])

0 comments on commit 9fdad44

Please sign in to comment.