fixed interleaved streaming loader and creator
kshitijkg committed Oct 18, 2023
1 parent 6999b7c commit 25c9aad
Showing 2 changed files with 166 additions and 44 deletions.
First changed file:
@@ -36,7 +36,7 @@

from streaming.base.format.mds.encodings import Encoding, _encodings

-class ImageEncoding:
+class ImageEncoding(Encoding):
def encode(self, images: List[Image.Image]) -> bytes:
bytes_arr = []
for image in images:
@@ -112,6 +112,25 @@ def decode(self, data: bytes) -> np.ndarray:

_encodings['pickleencoding'] = PickleEncoding
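For orientation: `streaming` lets you plug in custom column codecs by subclassing `Encoding` and registering the class in `_encodings`, which is the pattern this file uses for `pickleencoding` here and `simple_encoding` below. A minimal toy codec as a sketch (the `UInt8Flat` name is invented for illustration, not part of the commit):

```python
import numpy as np
from streaming.base.format.mds.encodings import Encoding, _encodings

class UInt8Flat(Encoding):
    """Toy codec: flat uint8 buffer, no shape information kept."""
    def encode(self, obj: np.ndarray) -> bytes:
        return np.ascontiguousarray(obj, dtype=np.uint8).tobytes()

    def decode(self, data: bytes) -> np.ndarray:
        return np.frombuffer(data, dtype=np.uint8)

# The registered key becomes a valid column type for MDSWriter columns.
_encodings['uint8flat'] = UInt8Flat
```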

class simple_encoding(Encoding):
def encode(self, data: List[Image.Image]) -> bytes:
# Read all images into numpy array
data = map(lambda x: np.array(x), data)
data = np.stack(list(data))
assert data.shape == (len(data), 256, 256, 3), f'Expected shape (N, 256, 256, 3), got {data.shape}'
for img in data:
assert img.dtype == np.uint8, f'Expected dtype np.uint8, got {img.dtype}'
return data.tobytes()

def decode(self, data: bytes) -> np.ndarray:
# convert bytes to numpy array
data = np.frombuffer(data, dtype=np.uint8)
# reshape to original shape
data = data.reshape(-1, 256, 256, 3)
return data

_encodings['simple_encoding'] = simple_encoding
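A quick round-trip sketch of how `simple_encoding` behaves, assuming the fixed 256×256 RGB uint8 layout its asserts enforce; because the shape is fixed, no header bytes are needed to recover it on decode:

```python
import numpy as np
from PIL import Image

codec = simple_encoding()
imgs = [Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8)) for _ in range(4)]

payload = codec.encode(imgs)      # exactly 4 * 256 * 256 * 3 bytes
restored = codec.decode(payload)  # shape recovered from the fixed layout

assert restored.shape == (4, 256, 256, 3) and restored.dtype == np.uint8
```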

class NoConcatDataset(IterableDataset):
"""An IterableDataset that returns text samples for MDSWriter.
@@ -174,6 +193,10 @@ def __init__(
self.after_image_extra_tokens = after_image_extra_tokens
self.bos_text = bos_text
self.eos_text = eos_text
self.pad_token_id = self.tokenizer("<|padding|>",
truncation=False,
padding=False,
add_special_tokens=False)['input_ids'][0]
self.should_wrap = not no_wrap

self.bos_tokens = self.tokenizer(self.bos_text,
@@ -194,7 +217,16 @@ def __init__(
warnings.warn(
f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token\
, instead we got {self.eos_tokens}. Quit if this was in error.')


self.image_start_token = self.tokenizer(self.image_start_text,
truncation=False,
padding=False,
add_special_tokens=False)['input_ids'][0]
self.image_end_token = self.tokenizer(self.image_end_text,
truncation=False,
padding=False,
add_special_tokens=False)['input_ids'][0]

eos_text_provided = self.eos_text != ''
bos_text_provided = self.bos_text != ''
test_text = self.tokenizer('')
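The `pad_token_id`, `image_start_token`, and `image_end_token` lookups above all rely on the text mapping to exactly one token; taking `['input_ids'][0]` silently truncates otherwise (the code only warns for BOS/EOS). A hedged sketch of a stricter helper, assuming a GPT-NeoX-style tokenizer where `<|padding|>` is a single special token:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # assumed tokenizer

def single_token_id(text: str) -> int:
    ids = tok(text, truncation=False, padding=False,
              add_special_tokens=False)["input_ids"]
    assert len(ids) == 1, f"{text!r} tokenizes to {len(ids)} tokens: {ids}"
    return ids[0]

pad_token_id = single_token_id("<|padding|>")  # one token in the NeoX vocab
```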
@@ -235,7 +267,7 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
self.text_buffer.append(self.eos_tokens)
self.image_buffer.extend(images)
self.image_buffer.append(None)

#We want to add text and image to our upcoming output (setup), and remove them from the buffer.
while True:

@@ -253,11 +285,11 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
if text != None and image == None:
# Fix saving the missing/original part of the text
if current_length + len(text) > self.max_length: # Too long, that's fine for text, just grab what we can
-current_length = self.max_length
text_append = text[:self.max_length-current_length] # Changes the actual list in the thing
-text[0] = text[self.max_length-current_length:]
+self.text_buffer[0] = text[self.max_length-current_length:]
+# We do NOT pop an image here because we haven't finished the current text
+# We also naturally do not pop text.
+current_length = self.max_length

else: # Not greater, remove entire text and entire image
text_append = self.text_buffer.pop(0)
@@ -273,32 +305,62 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:

else: # So this includes that EOS case...
curr_image.extend([None, self.image_buffer.pop(0), None])
-curr_text.extend([self.image_start_text, self.text_buffer.pop(0), self.image_end_text])
+curr_text.extend([self.image_start_token, self.text_buffer.pop(0), self.image_end_token])
current_length += self.image_seq_length + 2
else:
raise ValueError("Text and image are both none or both not None in the same sample")

if current_length == self.max_length:
np_text = np.array(curr_text)

# length is total number of non None tokens
text_length = len(np_text[np_text != None])
vision_length = len(np_text[np_text==None])
total_sample_length = text_length + vision_length*self.image_seq_length

if total_sample_length != self.max_length:
# Pad rest of the text tokens
np_text = np.pad(np_text, (0, self.max_length - total_sample_length), constant_values = self.pad_token_id)

text_ids = np_text[np_text != None]
-text_tokens = text_ids[:-1]
+text_tokens = text_ids
text_positions = torch.from_numpy(np.where(np_text != None)[0])

images = list(filter(lambda a: a != None, curr_image))
image_positions = torch.from_numpy(np.where(np_text == None)[0])
-labels = text_ids[1:]
+labels = np.roll(np_text, -1, axis = 0)
+labels[-1] = self.pad_token_id

text_labels = labels[np_text != None]
image_labels = labels[np_text == None]

-multimodal_position_ids = torch.nn.utils.rnn.pad_sequence([text_positions, image_positions], batch_first = True, padding_value = -1)
+# Replace None with pad token in labels
+text_labels = np.where(text_labels == None, self.pad_token_id, text_labels).astype(np.int64)
+image_labels = np.where(image_labels == None, self.pad_token_id, image_labels).astype(np.int64)

-print("text_id", text_tokens)
-print("text_positions", text_positions)
-print("image_positions", image_positions)
-print("multimodal_position_ids", multimodal_position_ids)
-print("labels", labels)
-print("images", images)
+multimodal_position_ids = torch.nn.utils.rnn.pad_sequence([text_positions, image_positions], batch_first = True, padding_value = -1) # TODO: Make this position pad id

labels = torch.nn.utils.rnn.pad_sequence([torch.from_numpy(text_labels), torch.from_numpy(image_labels)], batch_first = True, padding_value = -1)

# convert tensor to numpy array
labels = labels.numpy().tobytes()
text_tokens = text_tokens.astype(np.int64)
text_tokens = text_tokens.tobytes()
multimodal_position_ids = multimodal_position_ids.numpy().tobytes()

images = map(lambda x: np.array(x), images)
images = np.stack(list(images))
images = np.expand_dims(images, axis=1)

# print("text_id", text_tokens)
# print("text_positions", text_positions)
# print("image_positions", image_positions)
# print("multimodal_position_ids", multimodal_position_ids)
# print("labels", labels)
# print("images", images)

yield {
-'images': images,
+'images': images.tobytes(),
'tokens': text_tokens,
'multimodal_position_ids' : multimodal_position_ids,
'labels': labels
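To make the packing above concrete, here is a toy walk-through (not the author's code): `None` entries mark image slots, labels are the tokens shifted left by one via `np.roll`, and the per-modality position lists are padded to a rectangular tensor with `-1`:

```python
import numpy as np
import torch

PAD = 1  # stand-in pad token id
np_text = np.array([5, None, 7, 8, None, 9], dtype=object)  # None = image slot

labels = np.roll(np_text, -1, axis=0)
labels[-1] = PAD  # nothing follows the final position

text_positions = torch.from_numpy(np.where(np_text != None)[0])   # [0, 2, 3, 5]
image_positions = torch.from_numpy(np.where(np_text == None)[0])  # [1, 4]

text_labels = np.where(labels[np_text != None] == None, PAD,
                       labels[np_text != None]).astype(np.int64)  # [1, 8, 1, 1]
image_labels = labels[np_text == None].astype(np.int64)           # [7, 9]

multimodal_position_ids = torch.nn.utils.rnn.pad_sequence(
    [text_positions, image_positions], batch_first=True, padding_value=-1)
# tensor([[ 0,  2,  3,  5],
#         [ 1,  4, -1, -1]])
```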
@@ -370,11 +432,29 @@ def parse_args() -> Namespace:

class ImageCaptionDataset(IterableDataset):
def __init__(self, path):
-fpath = path + "/{00000..00001}.tar"
+fpath = path + "/{00000..41455}.tar"
self.dataset = wds.WebDataset(fpath).decode("pilrgb").rename(image="jpg;png;jpeg;webp", text="txt").to_tuple("image", "text")

# def __iter__(self):
# for image, text in self.dataset:
# sample = {
# "images": [None, image],
# "text": [text, None]
# }
# yield sample

def __iter__(self):
-for image, text in self.dataset:
+data_iter = iter(self.dataset)
+while True:
+try:
+image, text = next(data_iter)
+except StopIteration:
+# If StopIteration is raised, break from loop
+break
+except Exception as e:
+print(f"Error encountered: {e}. Skipping this datapoint.")
+continue

sample = {
"images": [None, image],
"text": [text, None]
@@ -390,6 +470,7 @@ def build_image_caption_dataset(
eos_text: str = '',
no_wrap: bool = False,
tokenizer: PreTrainedTokenizerBase = None,
vision_seq_length: int = 64,
) -> IterableDataset:
"""Build an IterableDataset over the HF C4 or pile source data.
@@ -435,7 +516,7 @@
dataset=dataset,
tokenizer=tokenizer,
max_length=max_length,
-image_seq_length=10,
+image_seq_length=vision_seq_length,
bos_text=bos_text,
eos_text=eos_text,
image_start_text='hello',
@@ -480,13 +561,11 @@ def main(args: Namespace) -> None:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
# we will enforce length, so suppress warnings about sequences too long for the model
tokenizer.model_max_length = int(1e30)
-columns = {'tokens': 'bytes', 'images': 'pickleencoding', 'multimodal_position_ids': 'ndarray', 'labels': 'ndarray'}
+columns = {'tokens': 'bytes', 'images': 'bytes', 'multimodal_position_ids': 'bytes', 'labels': 'bytes'}
else:
mode = ConcatMode.NO_CONCAT
tokenizer = None
-columns = {'text': 'str', 'images': 'pickleencoding'}
+columns = {'text': 'str', 'images': 'ndarray'}

-print('here')

# Write samples
print(f'Converting to MDS format...')
@@ -496,19 +575,26 @@
print(f'It will finish at a value below 100% if tokenizing')
with MDSWriter(columns=columns,
out=os.path.join(args.out_root),
-compression=args.compression, size_limit=5.12e+8) as out:
-for i in range(1):
-# Get samples
-dataset = build_image_caption_dataset(path='/p/fastdata/mmlaion/laion-400m/LAION-400m-webdataset/data',
-split=args.split,
-mode=mode,
-max_length=args.concat_tokens,
-bos_text=args.bos_text,
-eos_text=args.eos_text,
-no_wrap=args.no_wrap,
-tokenizer=tokenizer)
-for sample in tqdm(dataset):
-out.write(sample)
+compression=args.compression, size_limit=1e+10) as out:
# Get samples
dataset = build_image_caption_dataset(path='/p/fastdata/mmlaion/laion-400m/LAION-400m-webdataset/data',
split=args.split,
mode=mode,
max_length=args.concat_tokens,
bos_text=args.bos_text,
eos_text=args.eos_text,
no_wrap=args.no_wrap,
tokenizer=tokenizer)
total_samples = 0
total_images = 0
for sample in tqdm(dataset):
total_samples += 1
total_images += len(sample["images"])
print(total_samples, total_images)
# simple_encoder = simple_encoding()
out.write(sample)
if total_samples >= 145:
break


if __name__ == '__main__':
Second changed file:
@@ -7,17 +7,51 @@
from itertools import islice
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

from PIL import Image
import pickle
import numpy as np
import torch
import transformers
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
-from streaming import Stream, StreamingDataset
+from streaming import Stream, StreamingDataset, StreamingDataLoader
from torch.utils.data import DataLoader
-from transformers import PreTrainedTokenizerBase
+import argparse
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoImageProcessor, AutoModel

from streaming.base.format.mds.encodings import Encoding, _encodings

class PickleEncoding(Encoding):
def encode(self, data: List[Image.Image]) -> bytes:
return pickle.dumps(data)

def decode(self, data: bytes) -> np.ndarray:
data = pickle.loads(data)
# Convert PIL Images to numpy arrays
data = map(lambda x: np.array(x), data)
return np.stack(list(data))

_encodings['pickleencoding'] = PickleEncoding

class simple_encoding(Encoding):
def encode(self, data: List[Image.Image]) -> bytes:
# Read all images into numpy array
data = map(lambda x: np.array(x), data)
data = np.stack(list(data))
assert data.shape == (len(data), 256, 256, 3), f'Expected shape (N, 256, 256, 3), got {data.shape}'
return data.tobytes()

def decode(self, data: bytes) -> np.ndarray:
# convert bytes to numpy array
data = np.frombuffer(data, dtype=np.uint8)
# print(data.shape, data.reshape(-1, 256, 256, 3).shape)
# reshape to original shape
data = data.reshape(-1, 256, 256, 3)
return data

_encodings['simple_encoding'] = simple_encoding

def build_tokenizer(om_tokenizer_config: DictConfig) -> PreTrainedTokenizerBase:
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
@@ -157,6 +191,7 @@ def __init__(self,
)
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')

# How to tokenize a text sample to a token sample
def _tokenize(self, text_sample: Mapping):
@@ -191,9 +226,11 @@ def __getitem__(self, idx: int):
raise RuntimeError(
'StreamingTextDataset needs samples to have a `text` or `tokens` column'
)
-vision_input = torch.from_numpy(sample.get('vision_input', None).copy())
-multimodal_position_ids = torch.from_numpy(sample.get('multimodal_position_ids', None).copy())
-labels = torch.from_numpy(sample.get('labels', None).copy())
+vision_input = np.frombuffer(sample.get('images', None), dtype=np.uint8).copy().reshape(-1, 256, 256, 3)
+vision_input = self.processor(vision_input, return_tensors="pt")
+vision_input = vision_input["pixel_values"].to(torch.int64).unsqueeze(1) # TODO: Fix for num_frames > 1
+multimodal_position_ids = torch.from_numpy(np.frombuffer(sample.get('multimodal_position_ids', None), dtype=np.int64).copy()).reshape(2, -1)
+labels = torch.from_numpy(np.frombuffer(sample.get('labels', None), dtype=np.int64).copy()).reshape(2, -1)
return (token_sample, vision_input, multimodal_position_ids, labels)
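The reader side now mirrors the creator byte-for-byte: `images` were written with `.tobytes()` from a uint8 array of fixed 256×256×3 frames, and `multimodal_position_ids`/`labels` as int64 with a known leading dimension of 2, so `np.frombuffer` plus a fixed `reshape` recovers them. A minimal check of that contract, under the same assumed shapes:

```python
import numpy as np

def roundtrip_images(images_u8: np.ndarray) -> np.ndarray:
    assert images_u8.dtype == np.uint8 and images_u8.shape[1:] == (256, 256, 3)
    raw = images_u8.tobytes()  # what the dataset creator writes
    # .copy() detaches from the read-only buffer, as __getitem__ does above
    return np.frombuffer(raw, dtype=np.uint8).copy().reshape(-1, 256, 256, 3)

imgs = np.zeros((2, 256, 256, 3), dtype=np.uint8)
assert (roundtrip_images(imgs) == imgs).all()
```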


@@ -237,7 +274,6 @@ def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
if self.take_transpose:
# Apply transpose to each example in batch parallely using map function
examples = list(map(lambda x: x.transpose(0, 1), examples))

batch = torch.nn.utils.rnn.pad_sequence(examples, batch_first=True, padding_value=self.pad_token_id)

if self.take_transpose:
@@ -286,11 +322,11 @@ def build_interleaved_dataloader(

text_collate_fn = TextNeoXCollateWrapper(text_collate_fn)

-vision_collate_fn = PaddedCollateWrapper(pad_token_id=-1) # Each sample: (num_vision, H, W, C)
+vision_collate_fn = PaddedCollateWrapper(pad_token_id=-1) # Each sample: (timesteps, num_vision, H, W, C)

multimodal_position_ids_collate_fn = PaddedCollateWrapper(pad_token_id=-1, take_transpose=True) # Each sample: (num_modalities, max_seq_length)

-label_collate_fn = PaddedCollateWrapper(pad_token_id=-1) # Each sample: (num_modalities, max_seq_length)
+label_collate_fn = PaddedCollateWrapper(pad_token_id=-1, take_transpose=True) # Each sample: (num_modalities, max_seq_length)
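Worth spelling out why `take_transpose` matters for the position-id and label collators: `pad_sequence` pads along dim 0, so samples shaped `(num_modalities, seq_len)` with ragged `seq_len` must be transposed before padding and transposed back afterwards. A toy check of that trick:

```python
import torch

samples = [torch.ones(2, 5, dtype=torch.long),
           torch.ones(2, 3, dtype=torch.long)]            # ragged seq_len
transposed = [s.transpose(0, 1) for s in samples]         # (seq_len, 2)
batch = torch.nn.utils.rnn.pad_sequence(transposed, batch_first=True,
                                        padding_value=-1)  # (B, max_len, 2)
batch = batch.transpose(1, 2)                              # (B, 2, max_len)
assert batch.shape == (2, 2, 5) and (batch[1, :, 3:] == -1).all()
```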

collate_fn = MultimodalCollateWrapper(text_collator=text_collate_fn,
vision_collator=vision_collate_fn,
Expand All @@ -299,7 +335,7 @@ def build_interleaved_dataloader(
multimodal_position_ids_collator=multimodal_position_ids_collate_fn,
label_collator=label_collate_fn)

-return DataLoader(
+return StreamingDataLoader(
dataset,
collate_fn=collate_fn,
batch_size=device_batch_size,
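Presumably the switch from `DataLoader` to `StreamingDataLoader` is for resumption: streaming's loader checkpoints its position via `state_dict()`/`load_state_dict()`. A minimal sketch, assuming an MDS dataset already written to a local directory (the `/tmp/mds-out` path is illustrative):

```python
from streaming import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset(local='/tmp/mds-out', batch_size=4, shuffle=False)
loader = StreamingDataLoader(dataset, batch_size=4)

state = loader.state_dict()    # save alongside the model checkpoint
loader.load_state_dict(state)  # restore to resume mid-epoch after a restart
```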
