diff --git a/megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py b/megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py
index dc2fa24fc..3fb19701f 100644
--- a/megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py
+++ b/megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py
@@ -36,7 +36,7 @@
 from streaming.base.format.mds.encodings import Encoding, _encodings
 
 
-class ImageEncoding:
+class ImageEncoding(Encoding):
     def encode(self, images: List[Image.Image]) -> bytes:
         bytes_arr = []
         for image in images:
@@ -112,6 +112,25 @@ def decode(self, data: bytes) -> np.ndarray:
 
 _encodings['pickleencoding'] = PickleEncoding
 
+class simple_encoding(Encoding):
+    def encode(self, data: List[Image.Image]) -> bytes:
+        # Read all images into numpy array
+        data = map(lambda x: np.array(x), data)
+        data = np.stack(list(data))
+        assert data.shape == (len(data), 256, 256, 3), f'Expected shape (N, 256, 256, 3), got {data.shape}'
+        for img in data:
+            assert img.dtype == np.uint8, f'Expected dtype np.uint8, got {img.dtype}'
+        return data.tobytes()
+
+    def decode(self, data: bytes) -> np.ndarray:
+        # convert bytes to numpy array
+        data = np.frombuffer(data, dtype=np.uint8)
+        # reshape to original shape
+        data = data.reshape(-1, 256, 256, 3)
+        return data
+
+_encodings['simple_encoding'] = simple_encoding
+
 class NoConcatDataset(IterableDataset):
     """An IterableDataset that returns text samples for MDSWriter.
@@ -174,6 +193,10 @@ def __init__(
         self.after_image_extra_tokens = after_image_extra_tokens
         self.bos_text = bos_text
         self.eos_text = eos_text
+        self.pad_token_id = self.tokenizer("<|padding|>",
+                                           truncation=False,
+                                           padding=False,
+                                           add_special_tokens=False)['input_ids'][0]
         self.should_wrap = not no_wrap
 
         self.bos_tokens = self.tokenizer(self.bos_text,
@@ -194,7 +217,16 @@ def __init__(
             warnings.warn(
                 f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token\
                 , instead we got {self.eos_tokens}. Quit if this was in error.')
-
+
+        self.image_start_token = self.tokenizer(self.image_start_text,
+                                                truncation=False,
+                                                padding=False,
+                                                add_special_tokens=False)['input_ids'][0]
+        self.image_end_token = self.tokenizer(self.image_end_text,
+                                              truncation=False,
+                                              padding=False,
+                                              add_special_tokens=False)['input_ids'][0]
+
         eos_text_provided = self.eos_text != ''
         bos_text_provided = self.bos_text != ''
         test_text = self.tokenizer('')
@@ -235,7 +267,7 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
             self.text_buffer.append(self.eos_tokens)
             self.image_buffer.extend(images)
             self.image_buffer.append(None)
-
+
             #We want to add text and image to our upcoming output (setup), and remove them from the buffer.
             while True:
@@ -253,11 +285,11 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
                 if text != None and image == None: # Fix saving the missing/original part of the text
                     if current_length + len(text) > self.max_length: # Too long, that's fine for text, just grab what we can
-                        current_length = self.max_length
                         text_append = text[:self.max_length-current_length]
                         # Changes the actual list in the thing
-                        text[0] = text[self.max_length-current_length:]
+                        self.text_buffer[0] = text[self.max_length-current_length:]
                         # We do NOT pop an image here because we haven't finished the current text
                         # We also naturally do not pop text.
+                        current_length = self.max_length
                     else:
                         # Not greater, remove entire text and entire image
                         text_append = self.text_buffer.pop(0)
@@ -273,32 +305,62 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
                     else: # So this includes that EOS case...
                         curr_image.extend([None, self.image_buffer.pop(0), None])
-                        curr_text.extend([self.image_start_text, self.text_buffer.pop(0), self.image_end_text])
+                        curr_text.extend([self.image_start_token, self.text_buffer.pop(0), self.image_end_token])
                         current_length += self.image_seq_length + 2
                 else:
                     raise ValueError("Text and image are both none or both not None in the same sample")
 
                 if current_length == self.max_length:
                     np_text = np.array(curr_text)
+
+                    # length is total number of non None tokens
+                    text_length = len(np_text[np_text != None])
+                    vision_length = len(np_text[np_text==None])
+                    total_sample_length = text_length + vision_length*self.image_seq_length
+
+                    if total_sample_length != self.max_length:
+                        # Pad rest of the text tokens
+                        np_text = np.pad(np_text, (0, self.max_length - total_sample_length), constant_values = self.pad_token_id)
+
                     text_ids = np_text[np_text != None]
-                    text_tokens = text_ids[:-1]
+                    text_tokens = text_ids
                     text_positions = torch.from_numpy(np.where(np_text != None)[0])
                     images = list(filter(lambda a: a != None, curr_image))
                     image_positions = torch.from_numpy(np.where(np_text == None)[0])
-                    labels = text_ids[1:]
+                    labels = np.roll(np_text, -1, axis = 0)
+                    labels[-1] = self.pad_token_id
+
+                    text_labels = labels[np_text != None]
+                    image_labels = labels[np_text == None]
 
-                    multimodal_position_ids = torch.nn.utils.rnn.pad_sequence([text_positions, image_positions], batch_first = True, padding_value = -1)
+                    # Replace None with pad token in labels
+                    text_labels = np.where(text_labels == None, self.pad_token_id, text_labels).astype(np.int64)
+                    image_labels = np.where(image_labels == None, self.pad_token_id, image_labels).astype(np.int64)
 
-                    print("text_id", text_tokens)
-                    print("text_positions", text_positions)
-                    print("image_positions", image_positions)
-                    print("multimodal_position_ids", multimodal_position_ids)
-                    print("labels", labels)
-                    print("images", images)
+                    multimodal_position_ids = torch.nn.utils.rnn.pad_sequence([text_positions, image_positions], batch_first = True, padding_value = -1) # TODO: Make this position pad id
+
+                    labels = torch.nn.utils.rnn.pad_sequence([torch.from_numpy(text_labels), torch.from_numpy(image_labels)], batch_first = True, padding_value = -1)
+
+                    # convert tensor to numpy array
+                    labels = labels.numpy().tobytes()
+                    text_tokens = text_tokens.astype(np.int64)
+                    text_tokens = text_tokens.tobytes()
+                    multimodal_position_ids = multimodal_position_ids.numpy().tobytes()
+
+                    images = map(lambda x: np.array(x), images)
+                    images = np.stack(list(images))
+                    images = np.expand_dims(images, axis=1)
+
+                    # print("text_id", text_tokens)
+                    # print("text_positions", text_positions)
+                    # print("image_positions", image_positions)
+                    # print("multimodal_position_ids", multimodal_position_ids)
+                    # print("labels", labels)
+                    # print("images", images)
                     yield {
-                        'images': images,
+                        'images': images.tobytes(),
                         'tokens': text_tokens,
                         'multimodal_position_ids' : multimodal_position_ids,
                         'labels': labels
@@ -370,11 +432,29 @@ def parse_args() -> Namespace:
 
 class ImageCaptionDataset(IterableDataset):
     def __init__(self, path):
-        fpath = path + "/{00000..00001}.tar"
+        fpath = path + "/{00000..41455}.tar"
         self.dataset = wds.WebDataset(fpath).decode("pilrgb").rename(image="jpg;png;jpeg;webp", text="txt").to_tuple("image", "text")
 
+    # def __iter__(self):
+    #     for image, text in self.dataset:
+    #         sample = {
+    #             "images": [None, image],
+    #             "text": [text, None]
+    #         }
+    #         yield sample
+
     def __iter__(self):
-        for image, text in self.dataset:
+        data_iter = iter(self.dataset)
+        while True:
+            try:
+                image, text = next(data_iter)
+            except StopIteration:
+                # If StopIteration is raised, break from loop
+                break
+            except Exception as e:
+                print(f"Error encountered: {e}. Skipping this datapoint.")
+                continue
+
             sample = {
                 "images": [None, image],
                 "text": [text, None]
@@ -390,6 +470,7 @@ def build_image_caption_dataset(
     eos_text: str = '',
     no_wrap: bool = False,
     tokenizer: PreTrainedTokenizerBase = None,
+    vision_seq_length: int = 64,
 ) -> IterableDataset:
     """Build an IterableDataset over the HF C4 or pile source data.
@@ -435,7 +516,7 @@ def build_image_caption_dataset(
         dataset=dataset,
         tokenizer=tokenizer,
         max_length=max_length,
-        image_seq_length=10,
+        image_seq_length=vision_seq_length,
         bos_text=bos_text,
         eos_text=eos_text,
         image_start_text='hello',
@@ -480,13 +561,11 @@ def main(args: Namespace) -> None:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        # we will enforce length, so suppress warnings about sequences too long for the model
        tokenizer.model_max_length = int(1e30)
-        columns = {'tokens': 'bytes', 'images': 'pickleencoding', 'multimodal_position_ids': 'ndarray', 'labels': 'ndarray'}
+        columns = {'tokens': 'bytes', 'images': 'bytes', 'multimodal_position_ids': 'bytes', 'labels': 'bytes'}
     else:
         mode = ConcatMode.NO_CONCAT
         tokenizer = None
-        columns = {'text': 'str', 'images': 'pickleencoding'}
-
-        print('here')
+        columns = {'text': 'str', 'images': 'ndarray'}
 
     # Write samples
     print(f'Converting to MDS format...')
@@ -496,19 +575,26 @@ def main(args: Namespace) -> None:
     print(f'It will finish at a value below 100% if tokenizing')
     with MDSWriter(columns=columns,
                    out=os.path.join(args.out_root),
-                   compression=args.compression, size_limit=5.12e+8) as out:
-        for i in range(1):
-            # Get samples
-            dataset = build_image_caption_dataset(path='/p/fastdata/mmlaion/laion-400m/LAION-400m-webdataset/data',
-                                                  split=args.split,
-                                                  mode=mode,
-                                                  max_length=args.concat_tokens,
-                                                  bos_text=args.bos_text,
-                                                  eos_text=args.eos_text,
-                                                  no_wrap=args.no_wrap,
-                                                  tokenizer=tokenizer)
-            for sample in tqdm(dataset):
-                out.write(sample)
+                   compression=args.compression, size_limit=1e+10) as out:
+        # Get samples
+        dataset = build_image_caption_dataset(path='/p/fastdata/mmlaion/laion-400m/LAION-400m-webdataset/data',
+                                              split=args.split,
+                                              mode=mode,
+                                              max_length=args.concat_tokens,
+                                              bos_text=args.bos_text,
+                                              eos_text=args.eos_text,
+                                              no_wrap=args.no_wrap,
+                                              tokenizer=tokenizer)
+        total_samples = 0
+        total_images = 0
+        for sample in tqdm(dataset):
+            total_samples += 1
+            total_images += len(sample["images"])
+            print(total_samples, total_images)
+            # simple_encoder = simple_encoding()
+            out.write(sample)
+            if total_samples >= 145:
+                break
 
 
 if __name__ == '__main__':
diff --git a/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py b/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py
index fd87a9184..423e56992 100644
--- a/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py
+++ b/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py
@@ -7,17 +7,51 @@
 from itertools import islice
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
 
+from PIL import Image
+import pickle
 import numpy as np
 import torch
 import transformers
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
-from streaming import Stream, StreamingDataset
+from streaming import Stream, StreamingDataset, StreamingDataLoader
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizerBase
 
 import argparse
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoImageProcessor, AutoModel
+from streaming.base.format.mds.encodings import Encoding, _encodings
+
+class PickleEncoding(Encoding):
+    def encode(self, data: List[Image.Image]) -> bytes:
+        return pickle.dumps(data)
+
+    def decode(self, data: bytes) -> np.ndarray:
+        data = pickle.loads(data)
+        # Convert PIL Images to numpy arrays
+        data = map(lambda x: np.array(x), data)
+        return np.stack(list(data))
+
+_encodings['pickleencoding'] = PickleEncoding
+
+class simple_encoding(Encoding):
+    def encode(self, data: List[Image.Image]) -> bytes:
+        # Read all images into numpy array
+        data = map(lambda x: np.array(x), data)
+        data = np.stack(list(data))
+        assert data.shape == (len(data), 256, 256, 3), f'Expected shape (N, 256, 256, 3), got {data.shape}'
+        return data.tobytes()
+
+    def decode(self, data: bytes) -> np.ndarray:
+        # convert bytes to numpy array
+        data = np.frombuffer(data, dtype=np.uint8)
+        # print(data.shape, data.reshape(-1, 256, 256, 3).shape)
+        # reshape to original shape
+        data = data.reshape(-1, 256, 256, 3)
+        return data
+
+_encodings['simple_encoding'] = simple_encoding
 
 def build_tokenizer(om_tokenizer_config: DictConfig) -> PreTrainedTokenizerBase:
     os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
@@ -157,6 +191,7 @@ def __init__(self,
         )
         self.tokenizer = tokenizer
         self.max_seq_length = max_seq_length
+        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
 
     # How to tokenize a text sample to a token sample
     def _tokenize(self, text_sample: Mapping):
@@ -191,9 +226,11 @@ def __getitem__(self, idx: int):
             raise RuntimeError(
                 'StreamingTextDataset needs samples to have a `text` or `tokens` column'
             )
-        vision_input = torch.from_numpy(sample.get('vision_input', None).copy())
-        multimodal_position_ids = torch.from_numpy(sample.get('multimodal_position_ids', None).copy())
-        labels = torch.from_numpy(sample.get('labels', None).copy())
+        vision_input = np.frombuffer(sample.get('images', None), dtype=np.uint8).copy().reshape(-1, 256, 256, 3)
+        vision_input = self.processor(vision_input, return_tensors="pt")
+        vision_input = vision_input["pixel_values"].to(torch.int64).unsqueeze(1) # TODO: Fix for num_frames > 1
+        multimodal_position_ids = torch.from_numpy(np.frombuffer(sample.get('multimodal_position_ids', None), dtype=np.int64).copy()).reshape(2, -1)
+        labels = torch.from_numpy(np.frombuffer(sample.get('labels', None), dtype=np.int64).copy()).reshape(2, -1)
 
         return (token_sample, vision_input, multimodal_position_ids, labels)
@@ -237,7 +274,6 @@ def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
         if self.take_transpose:
             # Apply transpose to each example in batch parallely using map function
             examples = list(map(lambda x: x.transpose(0, 1), examples))
-
         batch = torch.nn.utils.rnn.pad_sequence(examples, batch_first=True, padding_value=self.pad_token_id)
 
         if self.take_transpose:
@@ -286,11 +322,11 @@ def build_interleaved_dataloader(
     text_collate_fn = TextNeoXCollateWrapper(text_collate_fn)
 
-    vision_collate_fn = PaddedCollateWrapper(pad_token_id=-1) # Each sample: (num_vision, H, W, C)
+    vision_collate_fn = PaddedCollateWrapper(pad_token_id=-1) # Each sample: (timesteps, num_vision, H, W, C)
 
     multimodal_position_ids_collate_fn = PaddedCollateWrapper(pad_token_id=-1, take_transpose=True) # Each sample: (num_modalities, max_seq_length)
 
-    label_collate_fn = PaddedCollateWrapper(pad_token_id=-1) # Each sample: (num_modalities, max_seq_length)
+    label_collate_fn = PaddedCollateWrapper(pad_token_id=-1, take_transpose=True) # Each sample: (num_modalities, max_seq_length)
 
     collate_fn = MultimodalCollateWrapper(text_collator=text_collate_fn,
                                           vision_collator=vision_collate_fn,
@@ -299,7 +335,7 @@
                                           multimodal_position_ids_collator=multimodal_position_ids_collate_fn,
                                           label_collator=label_collate_fn)
 
-    return DataLoader(
+    return StreamingDataLoader(
         dataset,
         collate_fn=collate_fn,
         batch_size=device_batch_size,
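Reviewer sketch (not part of the patch): since 'tokens', 'labels', 'multimodal_position_ids', and 'images' are now written as raw bytes, a quick round-trip check of one written sample can catch dtype/shape mismatches between the writer and the dataloader. This is a minimal sketch assuming the layouts used above (int64 token, label, and position buffers, labels and position ids reshaped to (2, -1), uint8 images of 256x256x3); decode_sample is a hypothetical helper, not part of either file.

import numpy as np

def decode_sample(sample: dict):
    # Mirrors the frombuffer/reshape logic in dataloader.py's __getitem__.
    tokens = np.frombuffer(sample['tokens'], dtype=np.int64)
    labels = np.frombuffer(sample['labels'], dtype=np.int64).reshape(2, -1)
    position_ids = np.frombuffer(sample['multimodal_position_ids'], dtype=np.int64).reshape(2, -1)
    images = np.frombuffer(sample['images'], dtype=np.uint8).reshape(-1, 256, 256, 3)
    return tokens, labels, position_ids, images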