Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add LLaST(WIP) #837

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
args: ["--exclude=xtuner/model/transformers_models/*"]
args: ["--exclude=xtuner/model/transformers_models/*,xtuner/evaluation/metrics/sacrebleu.py"]
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
Expand Down
3 changes: 2 additions & 1 deletion xtuner/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
load_intern_repo_untokenized_dataset)
from .internvl_dataset import InternVL_V1_5_Dataset
from .json_dataset import load_json_file
from .llast import LLaSTDataset
from .llava import LLaVADataset
from .modelscope import process_ms_dataset
from .moss_sft import MOSSSFTDataset
Expand All @@ -25,5 +26,5 @@
'load_intern_repo_tokenized_dataset',
'load_intern_repo_untokenized_dataset', 'build_packed_dataset',
'RefCOCOJsonDataset', 'RefCOCOJsonEvalDataset', 'InvRefCOCOJsonDataset',
'load_json_file', 'InternVL_V1_5_Dataset'
'load_json_file', 'InternVL_V1_5_Dataset', 'LLaSTDataset'
]
5 changes: 4 additions & 1 deletion xtuner/dataset/collate_fns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .default_collate_fn import default_collate_fn
from .llast_collate_fn import llast_audiomask_mel_collate_fn
from .mmlu_collate_fn import mmlu_collate_fn

__all__ = ['default_collate_fn', 'mmlu_collate_fn']
__all__ = [
'default_collate_fn', 'mmlu_collate_fn', 'llast_audiomask_mel_collate_fn'
]
60 changes: 60 additions & 0 deletions xtuner/dataset/collate_fns/llast_collate_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) LLaST. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.utils import (DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX,
LLAST_AUDIO_PADDING_TOKEN_INDEX)


def llast_audiomask_mel_collate_fn(
        instances: Sequence[Dict],
        pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
        return_hf_format: bool = False) -> Dict[str, torch.Tensor]:
    """Collate tokenized samples (optionally with audio) into a batch.

    ``input_ids`` and ``labels`` of every sample are converted to tensors
    and batched. When at least one sample provides a non-``None``
    ``audio_tokens`` entry, the batch additionally carries the batched
    ``audio_tokens`` and the stacked ``audio_lens``. Text-only batches are
    collated without touching any audio key.

    Args:
        instances: Samples to collate. Each one must provide ``input_ids``
            and ``labels``. In an audio batch each one must also provide
            ``audio_tokens`` and ``audio_lens``.
            NOTE(review): ``audio_tokens`` is assumed to already be a
            ``torch.Tensor`` (it is passed to ``pad_sequence`` /
            ``torch.stack`` as-is) -- confirm against the dataset.
        pad_index: Value used to pad ``input_ids``; positions equal to it
            are masked out in ``attention_mask``.
        return_hf_format: If True, return the flat batch dict. Otherwise
            return ``{'data': batch, 'data_samples': instances}``.

    Returns:
        The batch dict with ``input_ids``, ``attention_mask``, ``labels``
        and, for audio batches, ``audio_tokens`` and ``audio_lens``;
        wrapped together with the raw samples unless ``return_hf_format``
        is set.

    Raises:
        ValueError: If ``instances`` is empty.
    """
    if not instances:
        raise ValueError('Cannot collate an empty batch of instances.')

    has_audio = any(inst.get('audio_tokens') is not None for inst in instances)
    needs_padding = len(instances) > 1

    def _batch(tensors, padding_value):
        # A single sample never needs padding, so it is stacked directly;
        # several samples are right-padded to the longest one.
        if needs_padding:
            return pad_sequence(
                tensors, batch_first=True, padding_value=padding_value)
        return torch.stack(tensors)

    input_ids = _batch(
        [torch.tensor(example['input_ids']) for example in instances],
        pad_index)
    labels = _batch(
        [torch.tensor(example['labels']) for example in instances],
        IGNORE_INDEX)

    data_dict = {
        'input_ids': input_ids,
        'attention_mask': input_ids.ne(pad_index),
        'labels': labels
    }

    # Audio keys are only read on the audio path, so text-only batches do
    # not fail on a missing ``audio_tokens`` / ``audio_lens`` entry.
    if has_audio:
        data_dict['audio_tokens'] = _batch(
            [example['audio_tokens'] for example in instances],
            LLAST_AUDIO_PADDING_TOKEN_INDEX)
        data_dict['audio_lens'] = torch.stack(
            [torch.tensor(example['audio_lens']) for example in instances])

    if return_hf_format:
        return data_dict
    return {'data': data_dict, 'data_samples': instances}
19 changes: 15 additions & 4 deletions xtuner/dataset/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def add_template_to_dataset(dataset, template_map_fn, map_num_proc):


def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
input_ids_with_output, remove_unused_columns,
map_num_proc):
with_audio_token, input_ids_with_output,
remove_unused_columns, map_num_proc):
assert (tokenizer is not None) and (max_length is not None), \
f'({tokenizer}, {max_length})'
if isinstance(tokenizer, dict) or isinstance(
Expand All @@ -78,6 +78,7 @@ def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
tokenizer=tokenizer,
max_length=max_length,
with_image_token=with_image_token,
with_audio_token=with_audio_token,
input_ids_with_output=input_ids_with_output),
remove_columns=list(dataset.column_names)
if remove_unused_columns else None,
Expand Down Expand Up @@ -112,6 +113,7 @@ def process(dataset,
use_varlen_attn=False,
input_ids_with_output=True,
with_image_token=False,
with_audio_token=False,
map_num_proc=32):
"""Post-process the dataset loaded from the Hugging Face Hub, or a local
dataset.
Expand Down Expand Up @@ -153,6 +155,9 @@ def process(dataset,
with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
IMAGE_TOKEN_INDEX. Typically set it to True during the training
of VLM.
with_audio_token: Whether to convert DEFAULT_AUDIO_TOKEN to
LLAST_AUDIO_TOKEN_INDEX. Typically set it to True during the
training of SLM.
map_num_proc: Max number of processes when mapping the dataset.
"""
if use_varlen_attn:
Expand Down Expand Up @@ -197,7 +202,8 @@ def process(dataset,

if do_dataset_tokenization:
dataset = tokenize_dataset(dataset, tokenizer, max_length,
with_image_token, input_ids_with_output,
with_image_token, with_audio_token,
input_ids_with_output,
remove_unused_columns, map_num_proc)

if input_ids_with_output:
Expand All @@ -213,7 +219,7 @@ def process(dataset,
shuffle_before_pack, map_num_proc)

# add 'length'
dataset = dataset.map(get_lengths, num_proc=map_num_proc)
dataset = dataset.map(get_lengths, num_proc=1)
setattr(dataset, 'length', dataset['length'])

return dataset
Expand All @@ -234,6 +240,7 @@ def process_hf_dataset(dataset,
use_varlen_attn=False,
input_ids_with_output=True,
with_image_token=False,
with_audio_token=False,
map_num_proc=32):
"""Post-process the dataset loaded from the Hugging Face Hub, or a local
dataset.
Expand Down Expand Up @@ -275,6 +282,9 @@ def process_hf_dataset(dataset,
with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
IMAGE_TOKEN_INDEX. Typically set it to True during the training
of VLM.
with_audio_token: Whether to convert DEFAULT_AUDIO_TOKEN to
LLAST_AUDIO_TOKEN_INDEX. Typically set it to True during the
training of SLM.
map_num_proc: Max number of processes when mapping the dataset.
"""
kwargs = dict(
Expand All @@ -293,6 +303,7 @@ def process_hf_dataset(dataset,
use_varlen_attn=use_varlen_attn,
input_ids_with_output=input_ids_with_output,
with_image_token=with_image_token,
with_audio_token=with_audio_token,
map_num_proc=map_num_proc)
if not (dist.is_available() and dist.is_initialized()):
return process(**kwargs)
Expand Down
Loading
Loading