Commit 0e1c698: Added LLaVA

kshitijkg committed Apr 4, 2024
1 parent 10475e5
Showing 2 changed files with 89 additions and 22 deletions.
@@ -35,6 +35,8 @@

from megatron.tokenizer.tokenizer import build_tokenizer

+IMAGE_SIZE = 336
+
IMAGE_UNDERSTANDING_TEXT_VARIANTS = [
("Describe this image", " "),
("The caption of the image", "is "),
@@ -533,6 +535,66 @@ def __iter__(self):
sample = {"images": [None], "text": [doc]}
yield sample

+class LLaVADataset(IterableDataset):
+    def __init__(self, path, group, image_start_text, image_end_text):
+        start, end = group
+        fpath = f"{path}/{{{str(start).zfill(5)}..{str(end).zfill(5)}}}.tar"
+        self.dataset = iter(
+            wds.WebDataset(fpath)
+            .decode("pilrgb")
+            .to_tuple("jpg;png;jpeg;webp", "json", "__key__", "__url__")
+        )
+
+        self.seed_parquet_folder = "/p/fastdata/mmlaion/hummingbird/temp_llava_seed"
+        self.image_start_text = image_start_text
+        self.image_end_text = image_end_text
+        self.current_parquet_path = None
+        self.current_loaded_parquet = None
+
+    def __iter__(self):
+        while True:
+            try:
+                image, metadata, key, url = next(self.dataset)
+                id = str(metadata["id"])
+                basename = os.path.basename(url)
+                parquet_shard = int(os.path.splitext(basename)[0])
+                parquet_path = f"{self.seed_parquet_folder}/{str(parquet_shard).zfill(5)}.parquet"
+                if self.current_parquet_path != parquet_path:
+                    self.current_parquet_path = parquet_path
+                    self.current_loaded_parquet = pd.read_parquet(parquet_path)
+                    self.current_loaded_parquet.set_index("id", inplace=True)
+                seed_tokens = np.frombuffer(
+                    self.current_loaded_parquet.loc[[id]]["seed_tokens"].iloc[0],
+                    dtype=np.int64,
+                )
+                split_images = split_images_fn(image, IMAGE_SIZE)
+                conversations = metadata["conversations"]
+
+                for i in range(0, len(conversations), 2):
+                    question = conversations[i]["value"]
+                    answer = conversations[i + 1]["value"]
+                    if "<image>" in question:
+                        question = question.replace("<image>", "")
+
+                    text_portion = self.image_start_text + "".join(
+                        [f"<|seed_{seed_token}|>" for seed_token in seed_tokens]
+                    )
+
+                    image_list = [None]
+                    text_list = [text_portion]
+                    for split_image in split_images:
+                        image_list.append(split_image)
+                        text_list.append(None)
+
+                    image_list.append(None)
+                    text_list.append(self.image_end_text + question + " " + answer)
+                    yield {"images": image_list, "text": text_list}
+            except StopIteration:
+                break
+            except ValueError as e:
+                print(f"Error encountered: {e}. Skipping this datapoint.")
+                continue
+            except Exception as e:
+                print(f"Unexpected Error encountered: {e}. Skipping this datapoint.")
+                continue

class OBELICSDataset(IterableDataset):
def __init__(self, group, image_start_text, image_end_text):
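For orientation, every dataset class in this file yields a dict of parallel "images" and "text" lists of equal length, where None marks the slots filled by the other modality. A minimal sketch of one LLaVA sample for a conversation pair whose image was split into two tiles (the tile/question/answer names are illustrative, not from the source):

sample = {
    "images": [None, tile_0, tile_1, None],
    "text": [
        image_start_text + "<|seed_7|><|seed_42|>",  # SEED tokens for the full image
        None,                                        # slot for tile_0
        None,                                        # slot for tile_1
        image_end_text + question + " " + answer,    # the conversation turn
    ],
}
assert len(sample["images"]) == len(sample["text"])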
@@ -626,7 +688,7 @@ def __iter__(self):
image = Image.open(io.BytesIO(image_data))
if image.mode != "RGB":
image = image.convert("RGB")
-split_images = split_images_fn(image, 224)
+split_images = split_images_fn(image, IMAGE_SIZE)

final_texts[-1] = (
final_texts[-1].rstrip()
@@ -681,7 +743,7 @@ def __iter__(self):
try:
image, metadata, key, url, text = next(self.dataset)

-split_images = split_images_fn(image, 224)
+split_images = split_images_fn(image, IMAGE_SIZE)

if text is None:
print("key 'text' not found in the sample, skipping this datapoint")
@@ -829,7 +891,7 @@ def __iter__(self):
text = sample_json["caption"]

image = torchvision.transforms.functional.resize(
-    image, [224, 224], interpolation=InterpolationMode.BICUBIC
+    image, [IMAGE_SIZE, IMAGE_SIZE], interpolation=InterpolationMode.BICUBIC
)

current_path = url
@@ -884,16 +946,16 @@ def __iter__(self):
top_left = [ref_exp[2], ref_exp[3]]
bottom_right = [ref_exp[4], ref_exp[5]]

-top_left[0] = int(top_left[0] * 224)
-top_left[1] = int(top_left[1] * 224)
-bottom_right[0] = int(bottom_right[0] * 224)
-bottom_right[1] = int(bottom_right[1] * 224)
+top_left[0] = int(top_left[0] * IMAGE_SIZE)
+top_left[1] = int(top_left[1] * IMAGE_SIZE)
+bottom_right[0] = int(bottom_right[0] * IMAGE_SIZE)
+bottom_right[1] = int(bottom_right[1] * IMAGE_SIZE)

top_left_bin = self.get_bin_number(
-    224, 224, 32, top_left[0], top_left[1]
+    IMAGE_SIZE, IMAGE_SIZE, 32, top_left[0], top_left[1]
)
bottom_right_bin = self.get_bin_number(
-    224, 224, 32, bottom_right[0], bottom_right[1]
+    IMAGE_SIZE, IMAGE_SIZE, 32, bottom_right[0], bottom_right[1]
)

referring_expression = text[start_text:end_text]
@@ -969,6 +1031,8 @@ def build_interleaved_multimodal_dataset(
dataset = GritDatasetGeneration(group, image_start_text, image_end_text)
elif dataset_type == "obelics":
dataset = OBELICSDataset(group, image_start_text, image_end_text)
elif dataset_type == "llava":
dataset = LLaVADataset(path, group, image_start_text, image_end_text)
else:
raise ValueError(f"Dataset {dataset_type} not recognized.")

@@ -1073,7 +1137,7 @@ def data_writer(data_queue, args, index):
total_samples = 0
total_images = 0
while True:
print("The queue size is", data_queue.qsize())
# print("The queue size is", data_queue.qsize())
try:
sample = data_queue.get(timeout=100)
total_samples += 1
@@ -1089,8 +1153,8 @@ def data_writer(data_queue, args, index):
# time_taken = end_time - start_time
# print(f"\nTime taken for 1000 samples: {time_taken} seconds")
# start_time = time.time() # reset start time
-if total_samples > 1000:
-    break
+# if total_samples > 1000:
+#     break
except multiprocessing.queues.Empty:
print(f"\rNo more data to write. Exiting. {index}")
break
@@ -1107,8 +1171,8 @@ def get_dataset_groups(start_ind: int, end_ind: int, groups: int):
Iterator[Tuple[str, int, int]]: Each argument tuple
"""
group_size = (end_ind - start_ind) // groups
-for group_start in range(start_ind, end_ind, group_size):
-    yield (group_start, group_start + group_size)
+for group_start in range(start_ind, end_ind + 1, group_size):
+    yield (group_start, group_start + group_size - 1)


def main(args: Namespace) -> None:
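The get_dataset_groups change is a real fix: the {start..end} shard template built in the dataset constructors expands inclusively on both ends, so the old half-open pairs re-read each boundary shard in two groups. With start_ind=0, end_ind=6, groups=3 (group_size = 2):

old: (0, 2), (2, 4), (4, 6)         -> shards 2 and 4 each land in two groups
new: (0, 1), (2, 3), (4, 5), (6, 7) -> disjoint ranges, each shard read once
                                       (the final group can run one shard past end_ind)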
@@ -1180,11 +1244,11 @@ def parse_args() -> Namespace:
)
parser.add_argument("--queue_size", type=int, default=5000)
parser.add_argument("--split", type=str, default="train")
parser.add_argument("--num_groups", type=int, default=1)
parser.add_argument("--workers", type=int, default=1) # 44 # 80
parser.add_argument("--num_writers", type=int, default=1) # 2
parser.add_argument("--start_ind", type=int, default=19)
parser.add_argument("--end_ind", type=int, default=21) # 150
parser.add_argument("--num_groups", type=int, default=22)
parser.add_argument("--workers", type=int, default=22) # 44 # 80
parser.add_argument("--num_writers", type=int, default=26) # 2
parser.add_argument("--start_ind", type=int, default=0)
parser.add_argument("--end_ind", type=int, default=62) # 150
parser.add_argument("--tokenizer_type", type=str, required=False, default=None)
parser.add_argument("--vocab_file", type=str, required=False, default=None)
parser.add_argument("--merge_file", type=str, required=False, default=None)
@@ -1235,4 +1299,7 @@ def parse_args() -> Namespace:
20 groups
5e8, 22, 26
python megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py --path /p/fastdata/mmlaion/hummingbird/grit --dataset_type grit --compression zstd --concat_tokens 2048 --tokenizer_type HFTokenizer --vocab_file /p/project/ccstdl/gupta6/multimodal/20B_tokenizer.json --out_root /p/fastdata/mmlaion/hummingbird/hummingbird_dataset_final/grit_val
+python megatron/data/streaming_dataset/interleaved_text_image/create_unified_interleaved_dataset.py --path /p/fastdata/mmlaion/llava_v1_5_mix665k --dataset_type llava --compression zstd --concat_tokens 2048 --tokenizer_type HFTokenizer --vocab_file /p/project/ccstdl/gupta6/multimodal/20B_tokenizer.json --out_root /p/fastdata/mmlaion/hummingbird/hummingbird_dataset_final/test_llava
"""
Second changed file:
@@ -26,7 +26,7 @@

# from megatron.data.streaming_dataset.interleaved_text_image.create_interleaved_dataset import simple_encoding, ListPIL, PickleEncoding

-from megatron.data.streaming_dataset.interleaved_text_image.create_unified_interleaved_dataset import ListPIL
+from megatron.data.streaming_dataset.interleaved_text_image.create_unified_interleaved_dataset import ListPIL, IMAGE_SIZE
# _encodings['pickleencoding'] = PickleEncoding
_encodings['listpil'] = ListPIL
# _encodings['simple_encoding'] = simple_encoding
@@ -191,10 +191,10 @@ def __getitem__(self, idx: int):
images = np.stack(images)
else:
images = np.array([])
-vision_input = images.reshape(-1, 224, 224, 3)
+vision_input = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3)
is_vision_empty = vision_input.shape[0] == 0
if is_vision_empty:
-vision_input = np.zeros((1, 224, 224, 3), dtype=np.uint8)
+vision_input = np.zeros((1, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)

vision_input = torch.from_numpy(vision_input).to(torch.int64)
vision_input = vision_input.unsqueeze(1) # TODO: Fix for num_frames > 1
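This loader flattens however many tiles a sample carries into an (N, IMAGE_SIZE, IMAGE_SIZE, 3) array and substitutes a single zero (black) frame for text-only samples, so downstream collation always receives a vision tensor. A shape check of the text-only path (a sketch, not source code):

import numpy as np

IMAGE_SIZE = 336
images = np.array([])                                         # text-only sample
vision_input = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3)  # shape (0, 336, 336, 3)
if vision_input.shape[0] == 0:
    vision_input = np.zeros((1, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)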
