diff --git a/megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py b/megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py
index 3aa7ab743..e6cc53fcf 100644
--- a/megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py
+++ b/megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py
@@ -157,6 +157,19 @@ def check_image(
 
 DIMENSIONS = [(1, 1), (2, 2), (1, 2), (1, 3), (1, 4), (2, 1), (3, 1), (4, 1)]
 MAX_SPLITS = 4
 
+def zigzag_order(scaled_image_np, new_height, new_width, box_length):
+    sub_images = []
+
+    for y in range(0, new_height, box_length):
+        if y // box_length % 2 == 0:  # Even rows are left-to-right
+            x_values = range(0, new_width, box_length)
+        else:  # Odd rows are right-to-left
+            x_values = range(new_width - box_length, -box_length, -box_length)
+
+        for x in x_values:
+            sub_images.append(scaled_image_np[y:y+box_length, x:x+box_length])
+
+    return sub_images
 
 def split_images_fn(image, box_length):
@@ -183,20 +196,24 @@ def split_images_fn(image, box_length):
 
     # Choose the first configuration (the one with lowest area difference and aspect ratio difference)
     chosen_configuration = sorted_configurations[0]["dim"]
-    new_width, new_height = chosen_configuration
-    scaled_image = torchvision.transforms.functional.resize(
-        image, [new_height, new_width], interpolation=InterpolationMode.BICUBIC
-    )
-
-    # Splitting the image into a grid of sub-images of size: box_length X box_length
-    scaled_image_np = np.array(scaled_image)
-    sub_images = [
-        scaled_image_np[i : i + box_length, j : j + box_length]
-        for i in range(0, new_height, box_length)
-        for j in range(0, new_width, box_length)
-    ]
-    # Converting numpy arrays back to images
-    sub_images = [Image.fromarray(sub_image) for sub_image in sub_images]
+    if chosen_configuration != (1 * box_length, 1 * box_length):
+        if random.random() < 0.5:  # Split roughly 50% of the time
+            new_width, new_height = chosen_configuration
+            scaled_image = torchvision.transforms.functional.resize(
+                image, [new_height, new_width], interpolation=InterpolationMode.BICUBIC
+            )
+            # Splitting the image into a grid of sub-images of size: box_length X box_length
+            scaled_image_np = np.array(scaled_image)
+            sub_images = zigzag_order(scaled_image_np, new_height, new_width, box_length)
+
+            # Converting numpy arrays back to images
+            sub_images = [Image.fromarray(sub_image) for sub_image in sub_images]
+        else:
+            sub_images = []
+    else:
+        sub_images = []
+
+    sub_images.append(torchvision.transforms.functional.resize(image, [box_length, box_length], interpolation=InterpolationMode.BICUBIC))
     # Returns a list of image objects
     return sub_images
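
For reviewers, a minimal self-contained sketch of the tile ordering the new `zigzag_order` helper produces. The function body is copied from the hunk above; the 4x4 toy array and tile size are illustrative, not part of the patch:

```python
import numpy as np

# Copied from the hunk above so the snippet runs standalone.
def zigzag_order(scaled_image_np, new_height, new_width, box_length):
    sub_images = []
    for y in range(0, new_height, box_length):
        if y // box_length % 2 == 0:  # Even rows are left-to-right
            x_values = range(0, new_width, box_length)
        else:  # Odd rows are right-to-left
            x_values = range(new_width - box_length, -box_length, -box_length)
        for x in x_values:
            sub_images.append(scaled_image_np[y:y+box_length, x:x+box_length])
    return sub_images

# A 4x4 "image" split into 2x2 tiles: the value at each tile's top-left
# corner shows the boustrophedon order (row 0 left-to-right, row 1
# right-to-left), rather than plain row-major order.
grid = np.arange(16).reshape(4, 4)
tiles = zigzag_order(grid, new_height=4, new_width=4, box_length=2)
print([int(t[0, 0]) for t in tiles])  # [0, 2, 10, 8]
```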
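Note on the behavioral change: previously every image was always resized and split into grid tiles; with this patch, splitting is skipped when the chosen grid is 1x1 and otherwise fires only about half the time, and the full image downscaled to a single box_length x box_length tile is always appended last, presumably so consumers always get a global view. A hedged usage sketch, assuming the module is importable and using an illustrative 800x400 test image and box_length of 224 (neither value comes from the patch):

```python
from PIL import Image
from megatron.data.streaming_dataset.interleaved_text_image.create_unified_interleaved_dataset import (
    split_images_fn,
)

img = Image.new("RGB", (800, 400))  # illustrative input
sub_images = split_images_fn(img, box_length=224)

# The global view, resized to box_length x box_length, is always the last
# element; grid tiles in zigzag order (if the ~50% split fired) come first.
assert sub_images[-1].size == (224, 224)
# DIMENSIONS allows at most 4 tiles, plus the appended global view.
assert 1 <= len(sub_images) <= 5
```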