Cut out objects with transparent background in the segmenter #47

Open · wants to merge 2 commits into base: main
82 changes: 47 additions & 35 deletions processing/segmenter/planktoscope/segmenter/__init__.py
@@ -146,6 +146,9 @@ def _manual_median(self, images_array):
def _save_image(self, image, path):
PIL.Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).save(path)

def _save_object_image(self, image, path):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))

def _save_mask(self, mask, path):
PIL.Image.fromarray(mask).save(path)
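The new `_save_object_image` routes cutouts through `cv2.imwrite` so the fourth (alpha) channel survives on disk; the JPEG output used before (note the `.jpg` to `.png` filename change further down) cannot carry transparency. A minimal sketch, not part of this diff and with illustrative names, of a 4-channel array round-tripping through OpenCV as a PNG:

```python
# Sketch only: a 4-channel (BGRA) uint8 array written as PNG keeps its alpha
# channel when read back with OpenCV.
import cv2
import numpy as np

bgra = np.zeros((8, 8, 4), dtype=np.uint8)
bgra[..., 2] = 255          # red channel (OpenCV uses B, G, R, A order)
bgra[2:6, 2:6, 3] = 255     # opaque square; everything else stays transparent

cv2.imwrite("cutout.png", bgra)                           # PNG keeps the alpha plane
roundtrip = cv2.imread("cutout.png", cv2.IMREAD_UNCHANGED)
assert roundtrip.shape == (8, 8, 4)
```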

@@ -209,6 +212,7 @@ def _open_and_apply_flat(self, filepath, flat_ref):
# logger.debug(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
# Read images
image = cv2.imread(filepath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# print(image)

# logger.debug(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
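The added `cvtColor` call converts frames right after reading because `cv2.imread` returns pixels in BGR order, while the PIL-based saves and the `COLOR_RGB2BGRA` conversion added below assume RGB. A minimal sketch, not part of this diff and with illustrative names, of what that reordering does:

```python
# Sketch only: cv2.imread yields BGR order; BGR2RGB swaps the first and last channels.
import cv2
import numpy as np

frame_bgr = np.zeros((4, 4, 3), dtype=np.uint8)
frame_bgr[..., 0] = 255                        # channel 0 is blue in BGR order
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
assert frame_rgb[0, 0, 2] == 255               # the same pixel is blue in channel 2 (RGB order)
```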
@@ -431,13 +435,43 @@ def pipe_full(conn):
return len(w) == 0

img_object = io.BytesIO()
PIL.Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).save(
img_object, format="JPEG"
PIL.Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)).save(
img_object, format="PNG"
)
logger.debug("Sending the object in the pipe!")
if not pipe_full(planktoscope.segmenter.streamer.sender):
planktoscope.segmenter.streamer.sender.send(img_object)
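The streamed preview is encoded as PNG instead of JPEG so the transparency produced by the cutout reaches the browser. A minimal sketch, not part of this diff and with an illustrative `cutout` array, of the in-memory PNG encoding used above:

```python
# Sketch only: encode an RGBA array to PNG bytes in memory; JPEG would drop the alpha.
import io
import numpy as np
import PIL.Image

cutout = np.zeros((32, 32, 4), dtype=np.uint8)   # stand-in for a segmented object
buffer = io.BytesIO()
PIL.Image.fromarray(cutout).save(buffer, format="PNG")
png_bytes = buffer.getvalue()                    # what would go through the streamer pipe
```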

def _crop_and_apply_mask(self, img, mask):
"""
Crops the image to the bounding box of the given mask and applies the mask as an alpha channel to create a cutout with transparency.

Args:
img (numpy.ndarray): The original image from which objects are to be extracted. Expected to be in RGB format.
mask (numpy.ndarray): The binary mask indicating the object's location in the image. Expected to be a 2D array of the same height and width as the image.

Returns:
numpy.ndarray: The cropped image with an alpha channel applied based on the mask. The output image will be in BGRA format, with the alpha channel representing the mask.
"""
# Find the bounding box of the mask
y_indices, x_indices = np.where(mask)
if len(y_indices) == 0 or len(x_indices) == 0:
return None # No object found in the mask
y_min, y_max = y_indices.min(), y_indices.max()
x_min, x_max = x_indices.min(), x_indices.max()

# Crop the image and mask to the bounding box
cropped_image = img[y_min:y_max+1, x_min:x_max+1]
cropped_mask = mask[y_min:y_max+1, x_min:x_max+1]

# Add alpha channel to the cropped image
cropped_image_bgra = cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGRA)

# Use the cropped mask as the alpha channel of the cropped image
cropped_image_bgra[:, :, 3] = cropped_mask.astype(np.uint8) * 255

return cropped_image_bgra
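`_crop_and_apply_mask` is the heart of the change: crop the image to the mask's bounding box, convert the crop to four channels, and let the mask drive the alpha plane. A minimal sketch, not part of this diff and with illustrative data, of the same steps on a toy image and mask:

```python
# Sketch only: bounding-box crop plus mask-as-alpha on a synthetic image.
import cv2
import numpy as np

img = np.full((10, 10, 3), 200, dtype=np.uint8)      # uniform RGB background
mask = np.zeros((10, 10), dtype=np.uint8)
mask[3:7, 2:9] = 1                                   # one object blob

ys, xs = np.where(mask)
y0, y1 = ys.min(), ys.max()
x0, x1 = xs.min(), xs.max()
cutout = cv2.cvtColor(img[y0:y1 + 1, x0:x1 + 1], cv2.COLOR_RGB2BGRA)
cutout[:, :, 3] = mask[y0:y1 + 1, x0:x1 + 1] * 255   # mask becomes the alpha channel
assert cutout.shape == (4, 7, 4)                     # 4 rows, 7 columns, BGRA
```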

def _slice_image(self, img, name, mask, start_count=0):
"""Slice a given image using give mask

@@ -451,29 +485,6 @@ def _slice_image(self, img, name, mask, start_count=0):
tuple: (Number of saved objects, original number of objects before size filtering)
"""

def __augment_slice(dim_slice, max_dims, size=10):
# transform tuple in list
dim_slice = list(dim_slice)
# dim_slice[0] is the vertical component
# dim_slice[1] is the horizontal component
# dim_slice[1].start,dim_slice[0].start is the top left corner
for i in range(2):
if dim_slice[i].start < size:
dim_slice[i] = slice(0, dim_slice[i].stop)
else:
dim_slice[i] = slice(dim_slice[i].start - size, dim_slice[i].stop)

# dim_slice[1].stop,dim_slice[0].stop is the bottom right corner
for i in range(2):
if dim_slice[i].stop + size == max_dims[i]:
dim_slice[i] = slice(dim_slice[i].start, max_dims[i])
else:
dim_slice[i] = slice(dim_slice[i].start, dim_slice[i].stop + size)

# transform back list in tuple
dim_slice = tuple(dim_slice)
return dim_slice

minMesh = self.__global_metadata.get("acq_minimum_mesh", 20) # microns
minESD = minMesh * 2
minArea = math.pi * (minESD / 2) * (minESD / 2)
@@ -498,18 +509,19 @@ def __augment_slice(dim_slice, max_dims, size=10):
f'{{"object_id":"{region.label}"}}',
)

# First extract to get all the metadata about the image
obj_image = img[region.slice]
colors = self._get_color_info(obj_image, region.filled_image)
metadata = self._extract_metadata_from_regionprop(region)
# Extract the object and get all the metadata about the image
individual_mask = (labels == region.label).astype(np.uint8)
obj_image = self._crop_and_apply_mask(img, individual_mask)
region_image = img[region.slice]
if obj_image is not None:
colors = self._get_color_info(region_image, region.filled_image)
metadata = self._extract_metadata_from_regionprop(region)

# Second extract to get a bigger image for saving
obj_image = img[__augment_slice(region.slice, labels.shape, 10)]
object_id = f"{name}_{i}"
object_fn = os.path.join(self.__working_obj_path, f"{object_id}.jpg")
object_id = f"{name}_{i}"
object_fn = os.path.join(self.__working_obj_path, f"{object_id}.png")

self._save_image(obj_image, object_fn)
self._stream(obj_image)
self._save_object_image(obj_image, object_fn)
self._stream(obj_image)

if self.__save_debug_img:
self._save_mask(
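For context on the loop above: each pass derives one binary mask per labelled region before calling `_crop_and_apply_mask`. A minimal sketch, not part of this diff and with illustrative data, of that labels-to-mask step using scikit-image:

```python
# Sketch only: build one binary mask per labelled region, as the loop above does.
import numpy as np
from skimage import measure

binary = np.zeros((10, 10), dtype=np.uint8)
binary[1:4, 1:4] = 1                     # first blob
binary[6:9, 5:9] = 1                     # second blob

labels = measure.label(binary)
for region in measure.regionprops(labels):
    individual_mask = (labels == region.label).astype(np.uint8)
    # individual_mask now isolates exactly one object, ready for crop-and-alpha
```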