From 0e1c698f7e48fcd909cb3e5c04742b33675b6f60 Mon Sep 17 00:00:00 2001
From: kshitij
Date: Thu, 4 Apr 2024 10:52:24 +0200
Subject: [PATCH] Add LLaVA dataset support

---
 .../create_unified_interleaved_dataset.py     | 105 ++++++++++++++----
 .../interleaved_text_image/dataloader.py      |   6 +-
 2 files changed, 89 insertions(+), 22 deletions(-)

diff --git a/megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py b/megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py
index e6cc53fcf..82f8cda14 100644
--- a/megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py
+++ b/megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py
@@ -35,6 +35,8 @@
 
 from megatron.tokenizer.tokenizer import build_tokenizer
 
+IMAGE_SIZE = 336
+
 IMAGE_UNDERSTANDING_TEXT_VARIANTS = [
     ("Describe this image", " "),
     ("The caption of the image", "is "),
@@ -533,6 +535,66 @@ def __iter__(self):
                 sample = {"images": [None], "text": [doc]}
                 yield sample
 
+class LLaVADataset(IterableDataset):
+    def __init__(self, path, group, image_start_text, image_end_text):
+        start, end = group
+        # Brace pattern covering the shard range, e.g. "{00000..00001}.tar".
+        fpath = f"{path}/{{{str(start).zfill(5)}..{str(end).zfill(5)}}}.tar"
+        self.dataset = iter(
+            wds.WebDataset(fpath)
+            .decode("pilrgb")
+            .to_tuple("jpg;png;jpeg;webp", "json", "__key__", "__url__")
+        )
+
+        self.seed_parquet_folder = "/p/fastdata/mmlaion/hummingbird/temp_llava_seed"
+        self.image_start_text = image_start_text
+        self.image_end_text = image_end_text
+        self.current_parquet_path = None
+        self.current_loaded_parquet = None
+
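+    # Each yielded sample is a pair of parallel, equal-length lists:
+    #   images: [None, tile_1, ..., tile_K, None]
+    #   text:   [image_start_text + seed tokens, None, ..., None,
+    #            image_end_text + question + " " + answer]
+    # Every position carries either text or one image tile, never both.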
+    def __iter__(self):
+        while True:
+            try:
+                image, metadata, key, url = next(self.dataset)
+                sample_id = str(metadata["id"])  # avoid shadowing the id() builtin
+                basename = os.path.basename(url)
+                parquet_shard = int(os.path.splitext(basename)[0])
+                parquet_path = f"{self.seed_parquet_folder}/{str(parquet_shard).zfill(5)}.parquet"
+                if self.current_parquet_path != parquet_path:
+                    # Cache the parquet shard so consecutive samples from the
+                    # same tar shard reuse one loaded DataFrame.
+                    self.current_parquet_path = parquet_path
+                    self.current_loaded_parquet = pd.read_parquet(parquet_path)
+                    self.current_loaded_parquet.set_index("id", inplace=True)
+                seed_tokens = np.frombuffer(self.current_loaded_parquet.loc[[sample_id]]["seed_tokens"].iloc[0], dtype=np.int64)
+                split_images = split_images_fn(image, IMAGE_SIZE)
+                conversations = metadata["conversations"]
+
+                for i in range(0, len(conversations), 2):
+                    question = conversations[i]["value"]
+                    answer = conversations[i + 1]["value"]
+                    if "<image>" in question:
+                        question = question.replace("<image>", "")
+
+                    text_portion = self.image_start_text + "".join(
+                        [f"<|seed_{seed_token}|>" for seed_token in seed_tokens]
+                    )
+
+                    image_list = [None]
+                    text_list = [text_portion]
+                    for split_image in split_images:
+                        image_list.append(split_image)
+                        text_list.append(None)
+
+                    image_list.append(None)
+                    text_list.append(self.image_end_text + question + " " + answer)
+                    yield {"images": image_list, "text": text_list}
+            except StopIteration:
+                break
+            except ValueError as e:
+                print(f"Error encountered: {e}. Skipping this datapoint.")
+                continue
+            except Exception as e:
+                print(f"Unexpected error encountered: {e}. Skipping this datapoint.")
+                continue
 
 
 class OBELICSDataset(IterableDataset):
     def __init__(self, group, image_start_text, image_end_text):
@@ -626,7 +688,7 @@ def __iter__(self):
                 image = Image.open(io.BytesIO(image_data))
                 if image.mode != "RGB":
                     image = image.convert("RGB")
-                split_images = split_images_fn(image, 224)
+                split_images = split_images_fn(image, IMAGE_SIZE)
 
                 final_texts[-1] = (
                     final_texts[-1].rstrip()
@@ -681,7 +743,7 @@ def __iter__(self):
             try:
                 image, metadata, key, url, text = next(self.dataset)
 
-                split_images = split_images_fn(image, 224)
+                split_images = split_images_fn(image, IMAGE_SIZE)
 
                 if text is None:
                     print("key 'text' not found in the sample, skipping this datapoint")
@@ -829,7 +891,7 @@ def __iter__(self):
                 text = sample_json["caption"]
 
                 image = torchvision.transforms.functional.resize(
-                    image, [224, 224], interpolation=InterpolationMode.BICUBIC
+                    image, [IMAGE_SIZE, IMAGE_SIZE], interpolation=InterpolationMode.BICUBIC
                 )
 
                 current_path = url
@@ -884,16 +946,16 @@ def __iter__(self):
                     top_left = [ref_exp[2], ref_exp[3]]
                     bottom_right = [ref_exp[4], ref_exp[5]]
-                    top_left[0] = int(top_left[0] * 224)
-                    top_left[1] = int(top_left[1] * 224)
-                    bottom_right[0] = int(bottom_right[0] * 224)
-                    bottom_right[1] = int(bottom_right[1] * 224)
+                    top_left[0] = int(top_left[0] * IMAGE_SIZE)
+                    top_left[1] = int(top_left[1] * IMAGE_SIZE)
+                    bottom_right[0] = int(bottom_right[0] * IMAGE_SIZE)
+                    bottom_right[1] = int(bottom_right[1] * IMAGE_SIZE)
 
                     top_left_bin = self.get_bin_number(
-                        224, 224, 32, top_left[0], top_left[1]
+                        IMAGE_SIZE, IMAGE_SIZE, 32, top_left[0], top_left[1]
                     )
                     bottom_right_bin = self.get_bin_number(
-                        224, 224, 32, bottom_right[0], bottom_right[1]
+                        IMAGE_SIZE, IMAGE_SIZE, 32, bottom_right[0], bottom_right[1]
                     )
 
                     referring_expression = text[start_text:end_text]
@@ -969,6 +1031,8 @@ def build_interleaved_multimodal_dataset(
         dataset = GritDatasetGeneration(group, image_start_text, image_end_text)
     elif dataset_type == "obelics":
         dataset = OBELICSDataset(group, image_start_text, image_end_text)
+    elif dataset_type == "llava":
+        dataset = LLaVADataset(path, group, image_start_text, image_end_text)
     else:
         raise ValueError(f"Dataset type {dataset_type} not recognized.")
 
@@ -1073,7 +1137,7 @@ def data_writer(data_queue, args, index):
     total_samples = 0
     total_images = 0
     while True:
-        print("The queue size is", data_queue.qsize())
+        # print("The queue size is", data_queue.qsize())
         try:
             sample = data_queue.get(timeout=100)
             total_samples += 1
@@ -1089,8 +1153,8 @@
             # time_taken = end_time - start_time
             # print(f"\nTime taken for 1000 samples: {time_taken} seconds")
             # start_time = time.time()  # reset start time
-            if total_samples > 1000:
-                break
+            # if total_samples > 1000:
+            #     break
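+            # Drain until data_queue.get() times out (100 s with no new
+            # samples); the commented-out cap above was a debugging limit.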
{index}") break @@ -1107,8 +1171,8 @@ def get_dataset_groups(start_ind: int, end_ind: int, groups: int): Iterator[Tuple[str, int, int]]: Each argument tuple """ group_size = (end_ind - start_ind) // groups - for group_start in range(start_ind, end_ind, group_size): - yield (group_start, group_start + group_size) + for group_start in range(start_ind, end_ind+1, group_size): + yield (group_start, group_start + group_size - 1) def main(args: Namespace) -> None: @@ -1180,11 +1244,11 @@ def parse_args() -> Namespace: ) parser.add_argument("--queue_size", type=int, default=5000) parser.add_argument("--split", type=str, default="train") - parser.add_argument("--num_groups", type=int, default=1) - parser.add_argument("--workers", type=int, default=1) # 44 # 80 - parser.add_argument("--num_writers", type=int, default=1) # 2 - parser.add_argument("--start_ind", type=int, default=19) - parser.add_argument("--end_ind", type=int, default=21) # 150 + parser.add_argument("--num_groups", type=int, default=22) + parser.add_argument("--workers", type=int, default=22) # 44 # 80 + parser.add_argument("--num_writers", type=int, default=26) # 2 + parser.add_argument("--start_ind", type=int, default=0) + parser.add_argument("--end_ind", type=int, default=62) # 150 parser.add_argument("--tokenizer_type", type=str, required=False, default=None) parser.add_argument("--vocab_file", type=str, required=False, default=None) parser.add_argument("--merge_file", type=str, required=False, default=None) @@ -1235,4 +1299,7 @@ def parse_args() -> Namespace: 20 groups 5e8, 22, 26 python megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py --path /p/fastdata/mmlaion/hummingbird/grit --dataset_type grit --compression zstd --concat_tokens 2048 --tokenizer_type HFTokenizer --vocab_file /p/project/ccstdl/gupta6/multimodal/20B_tokenizer.json --out_root /p/fastdata/mmlaion/hummingbird/hummingbird_dataset_final/grit_val + + +python megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py --path /p/fastdata/mmlaion/llava_v1_5_mix665k --dataset_type llava --compression zstd --concat_tokens 2048 --tokenizer_type HFTokenizer --vocab_file /p/project/ccstdl/gupta6/multimodal/20B_tokenizer.json --out_root /p/fastdata/mmlaion/hummingbird/hummingbird_dataset_final/test_llava """ diff --git a/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py b/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py index 9ec996c8d..df94baf35 100644 --- a/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py +++ b/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py @@ -26,7 +26,7 @@ # from megatron.data.streaming_dataset.interleaved_text_image.create_interleaved_dataset import simple_encoding, ListPIL, PickleEncoding -from megatron.data.streaming_dataset.interleaved_text_image.create_unified_interleaved_dataset import ListPIL +from megatron.data.streaming_dataset.interleaved_text_image.create_unified_interleaved_dataset import ListPIL, IMAGE_SIZE # _encodings['pickleencoding'] = PickleEncoding _encodings['listpil'] = ListPIL # _encodings['simple_encoding'] = simple_encoding @@ -191,10 +191,10 @@ def __getitem__(self, idx: int): images = np.stack(images) else: images = np.array([]) - vision_input = images.reshape(-1, 224, 224, 3) + vision_input = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3) is_vision_empty = vision_input.shape[0] == 0 if is_vision_empty: - vision_input = np.zeros((1, 224, 224, 3), dtype=np.uint8) + 
+            vision_input = np.zeros((1, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
         vision_input = torch.from_numpy(vision_input).to(torch.int64)
         vision_input = vision_input.unsqueeze(1)  # TODO: Fix for num_frames > 1
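
For reference, a small standalone sketch (not part of the patch) of how the revised get_dataset_groups partitions shard indices into inclusive (start, end) pairs, and how LLaVADataset.__init__ turns such a pair into a WebDataset brace pattern. The helper body is copied from the hunk above; the printed shard names assume the new argument defaults (start_ind=0, end_ind=62, num_groups=22).

from typing import Iterator, Tuple


def get_dataset_groups(start_ind: int, end_ind: int, groups: int) -> Iterator[Tuple[int, int]]:
    # Same body as the patched helper: inclusive (start, end) index pairs.
    group_size = (end_ind - start_ind) // groups
    for group_start in range(start_ind, end_ind + 1, group_size):
        yield (group_start, group_start + group_size - 1)


if __name__ == "__main__":
    # With start_ind=0, end_ind=62, num_groups=22: group_size is 2, so this
    # prints 32 patterns, "{00000..00001}.tar" through "{00062..00063}.tar".
    for start, end in get_dataset_groups(0, 62, 22):
        print(f"{{{str(start).zfill(5)}..{str(end).zfill(5)}}}.tar")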