Refactor build_tokenizer to use kwargs syntax and specify name (#532)
Refactor build_tokenizer in train/train.py and eval/eval.py to use **kwargs syntax, so that configs are plain dictionaries, classes and functions state their arguments explicitly in their signatures, and the build_* helpers use **kwargs to initialize classes.
j316chuck committed Sep 1, 2023
1 parent 68448b2 commit 186dd19
Showing 12 changed files with 101 additions and 61 deletions.
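
In practice the change moves each call site from passing an OmegaConf node into build_tokenizer to passing an explicit tokenizer name plus a plain kwargs dict. A minimal before/after sketch of the pattern (the tokenizer name and the model_max_length value are illustrative, not taken from any particular config):

    from omegaconf import OmegaConf as om

    from llmfoundry.utils.builders import build_tokenizer

    # Old style: the builder received a DictConfig and unpacked it internally.
    tokenizer_cfg = om.create({
        'name': 'EleutherAI/gpt-neox-20b',
        'kwargs': {'model_max_length': 2048},
    })
    # tokenizer = build_tokenizer(tokenizer_cfg)  # signature before this commit

    # New style: resolve the config to a plain dict first, then pass the name
    # and kwargs explicitly.
    resolved = om.to_container(tokenizer_cfg, resolve=True)
    tokenizer = build_tokenizer(tokenizer_name=resolved['name'],
                                tokenizer_kwargs=resolved.get('kwargs', {}))
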
11 changes: 4 additions & 7 deletions llmfoundry/data/denoising.py
@@ -864,13 +864,10 @@ def _format_tokens_for_decoder_only(
     cfg = om.create(cfg)
     device_batch_size = 2
 
-    tokenizer_cfg = {
-        'name': 'EleutherAI/gpt-neox-20b' if decoder_only else 't5-base',
-        'kwargs': {}
-    }
-    tokenizer_cfg['kwargs'] = {'model_max_length': cfg.dataset.max_seq_len}
-    tokenizer_cfg = om.create(tokenizer_cfg)
-    tokenizer = build_tokenizer(tokenizer_cfg)
+    tokenizer_name = 'EleutherAI/gpt-neox-20b' if decoder_only else 't5-base'
+    tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
+    tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
+                                tokenizer_kwargs=tokenizer_kwargs)
 
     loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
     assert isinstance(loader.dataset, StreamingTextDataset)

7 changes: 3 additions & 4 deletions llmfoundry/data/finetuning/dataloader.py
@@ -413,10 +413,9 @@ def _build_collate_fn(dataset_cfg: DictConfig,
         'timeout': 0
     })
 
-    tokenizer_cfg = {'name': 'EleutherAI/gpt-neox-20b', 'kwargs': {}}
-    tokenizer_cfg['kwargs'] = {'model_max_length': cfg.dataset.max_seq_len}
-    tokenizer_cfg = om.create(tokenizer_cfg)
-    tokenizer = build_tokenizer(tokenizer_cfg)
+    tokenizer_name = 'EleutherAI/gpt-neox-20b'
+    tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
+    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     device_batch_size = 2
     dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)

9 changes: 7 additions & 2 deletions llmfoundry/data/packing.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-from typing import Callable, Dict, List, Literal, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
@@ -360,7 +360,12 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
     # build tokenizer
     if 'tokenizer' not in cfg:
         raise ValueError('config must define tokenizer')
-    tokenizer = build_tokenizer(cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer_name = tokenizer_cfg['name']
+    tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
+    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     # Turn off packing for the dataloader (we want raw, pre-packed examples)
     dataloader_cfg.dataset.packing_ratio = None

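The resolve=True passed to om.to_container matters here because tokenizer kwargs in these configs typically interpolate other fields (for example model_max_length referencing the top-level max_seq_len); resolving materializes those references before they are handed to the tokenizer as plain keyword arguments. A small sketch of that behaviour (values illustrative):

    from omegaconf import OmegaConf as om

    cfg = om.create({
        'max_seq_len': 2048,
        'tokenizer': {
            'name': 'EleutherAI/gpt-neox-20b',
            'kwargs': {'model_max_length': '${max_seq_len}'},
        },
    })

    tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True)
    # The interpolation has been replaced by the concrete value:
    assert tokenizer_cfg == {
        'name': 'EleutherAI/gpt-neox-20b',
        'kwargs': {'model_max_length': 2048},
    }
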
7 changes: 3 additions & 4 deletions llmfoundry/data/text_data.py
@@ -327,10 +327,9 @@ def build_text_dataloader(
     cfg = om.create(cfg)
     device_batch_size = 2
 
-    tokenizer_cfg = {'name': args.tokenizer, 'kwargs': {}}
-    tokenizer_cfg['kwargs'] = {'model_max_length': args.max_seq_len}
-    tokenizer_cfg = om.create(tokenizer_cfg)
-    tokenizer = build_tokenizer(tokenizer_cfg)
+    tokenizer_name = args.tokenizer
+    tokenizer_kwargs = {'model_max_length': args.max_seq_len}
+    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
     tokenizer = loader.dataset.tokenizer  # type: ignore

13 changes: 5 additions & 8 deletions llmfoundry/utils/builders.py
@@ -34,7 +34,7 @@
 def build_icl_data_and_gauntlet(
         icl_tasks_config: Union[str, ListConfig],
         eval_gauntlet_config: Optional[Union[str, DictConfig]],
-        tokenizer: AutoTokenizer,
+        tokenizer: PreTrainedTokenizerBase,
         device_eval_batch_size: int,
         icl_seq_len: int,
         icl_subset_num_batches: Optional[int] = None
@@ -153,15 +153,12 @@ def build_scheduler(name: str, scheduler_config: Dict[str, Any]):
     raise ValueError(f'Not sure how to build scheduler: {name}')
 
 
-def build_tokenizer(om_tokenizer_config: DictConfig) -> PreTrainedTokenizerBase:
+def build_tokenizer(
+        tokenizer_name: str,
+        tokenizer_kwargs: Dict[str, Any]) -> PreTrainedTokenizerBase:
     os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
-    resolved_om_tokenizer_config = om.to_container(om_tokenizer_config,
-                                                   resolve=True)
-    tokenizer_kwargs = resolved_om_tokenizer_config.get(  # type: ignore
-        'kwargs', {})
-    tokenizer_name = resolved_om_tokenizer_config['name']  # type: ignore
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,
                                               **tokenizer_kwargs)
 
@@ -178,7 +175,7 @@ def build_tokenizer(om_tokenizer_config: DictConfig) -> PreTrainedTokenizerBase:
 
 def build_icl_evaluators(
         icl_tasks: Union[str, ListConfig],
-        tokenizer: AutoTokenizer,
+        tokenizer: PreTrainedTokenizerBase,
         default_max_seq_len: int,
         default_batch_size: int,
         destination_dir: Optional[str] = None,

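With the new signature, callers that are not driven by a YAML config no longer need to construct a DictConfig at all; the name and kwargs can be passed directly, and tokenizer_kwargs is forwarded to AutoTokenizer.from_pretrained as the hunk above shows. A short usage sketch (values illustrative):

    from llmfoundry.utils.builders import build_tokenizer

    # Positional or keyword form both work with the new signature.
    tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b',
                                {'model_max_length': 2048})
    assert tokenizer.model_max_length == 2048
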
2 changes: 1 addition & 1 deletion llmfoundry/utils/config_utils.py
@@ -29,7 +29,7 @@ def pop_config(cfg: DictConfig,
         if not isinstance(value, DictConfig) and not isinstance(
                 value, ListConfig):
             raise ValueError(
-                f'The key: {key} has a value: {value} that cannot be \
+                f'The key {key} has a value of type {type(value)} that cannot be \
                     converted to a dict or list. Please check your yaml.'
             )
         return om.to_container(value)

7 changes: 6 additions & 1 deletion scripts/eval/eval.py
@@ -100,7 +100,12 @@ def evaluate_model(model_cfg: DictConfig, dist_timeout: Union[float, int],
                    icl_subset_num_batches: Optional[int]):
     print(f'Evaluating model: {model_cfg.model_name}', flush=True)
     # Build tokenizer and model
-    tokenizer = build_tokenizer(model_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(model_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer_name = tokenizer_cfg['name']
+    tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
+    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
         icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size,

9 changes: 7 additions & 2 deletions scripts/train/train.py
@@ -217,7 +217,10 @@ def main(cfg: DictConfig) -> Trainer:
 
     # Mandatory model training configs
     model_config: DictConfig = pop_config(cfg, 'model', must_exist=True)
-    tokenizer_config: DictConfig = pop_config(cfg, 'tokenizer', must_exist=True)
+    tokenizer_config: Dict[str, Any] = pop_config(cfg,
+                                                  'tokenizer',
+                                                  must_exist=True,
+                                                  convert=True)
     optimizer_config: Dict[str, Any] = pop_config(cfg,
                                                   'optimizer',
                                                   must_exist=True,
@@ -416,7 +419,9 @@ def main(cfg: DictConfig) -> Trainer:
     logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
 
     # Build tokenizer
-    tokenizer = build_tokenizer(tokenizer_config)
+    tokenizer_name = tokenizer_config['name']
+    tokenizer_kwargs = tokenizer_config.get('kwargs', {})
+    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     # Scheduler
     scheduler_name: str = scheduler_config.pop('name')

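On the train.py side, the tokenizer section is now popped with convert=True, so by the time it reaches build_tokenizer it is already a plain dict rather than a DictConfig (pop_config's conversion path goes through om.to_container, as the config_utils.py hunk above shows). Roughly, for a typical run config (YAML shape and values illustrative):

    # tokenizer:
    #   name: EleutherAI/gpt-neox-20b
    #   kwargs:
    #     model_max_length: ${max_seq_len}
    #
    # After pop_config(cfg, 'tokenizer', must_exist=True, convert=True) the
    # section arrives as a plain dict, so splitting it is ordinary dict access:
    tokenizer_config = {
        'name': 'EleutherAI/gpt-neox-20b',
        'kwargs': {'model_max_length': 2048},
    }
    tokenizer_name = tokenizer_config['name']
    tokenizer_kwargs = tokenizer_config.get('kwargs', {})  # 'kwargs' may be omitted
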
32 changes: 10 additions & 22 deletions tests/test_dataloader.py
@@ -105,10 +105,9 @@ def test_correct_padding(tokenizer_name: str,
     })
 
     tokenizer = build_tokenizer(
-        om.create({
-            'name': tokenizer_name,
-            'kwargs': {}
-        }))
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={},
+    )
 
     # Dataloaders
     eval_loader = build_text_dataloader(
@@ -202,12 +201,8 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool,
         expected_keys += ['sequence_id']
 
     tokenizer = build_tokenizer(
-        om.create({
-            'name': tokenizer_name,
-            'kwargs': {
-                'model_max_length': max_seq_len
-            }
-        }))
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len})
 
     loader = build_text_denoising_dataloader(cfg, tokenizer,
                                              device_batch_size)
@@ -258,12 +253,8 @@ def test_finetuning_dataloader(decoder_only_format: bool,
     cfg = om.create(cfg)
 
     tokenizer = build_tokenizer(
-        om.create({
-            'name': tokenizer_name,
-            'kwargs': {
-                'model_max_length': max_seq_len
-            }
-        }))
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len})
 
     device_batch_size = 2
 
@@ -332,12 +323,9 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
     cfg = om.create(cfg)
 
     tokenizer = build_tokenizer(
-        om.create({
-            'name': tokenizer_name,
-            'kwargs': {
-                'model_max_length': max_seq_len
-            }
-        }))
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
 
     expected_keys = ['input_ids', 'attention_mask', 'labels']
     expected_keys += ['bidirectional_mask']

7 changes: 6 additions & 1 deletion tests/test_hf_config.py
@@ -71,7 +71,12 @@ def test_hf_config_override(
     test_cfg.precision = 'fp16'
     test_cfg.model.attn_config = {'attn_impl': 'torch', 'alibi': True}
 
-    tokenizer = build_tokenizer(test_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(test_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer_name = tokenizer_cfg['name']
+    tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
+    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
     model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
                                                          tokenizer)
 

9 changes: 8 additions & 1 deletion tests/test_hf_mpt_gen.py
@@ -1,6 +1,8 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Any, Dict
+
 import pytest
 from composer.core.precision import get_precision_context
 from composer.utils import get_device, reproducibility
@@ -44,7 +46,12 @@ def test_init_hfhub_mpt(device: str, attn_impl: str):
     })
 
     # build tokenizer
-    tokenizer = build_tokenizer(test_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(test_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer_name = tokenizer_cfg['name']
+    tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
+    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     # build model
     model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,

49 changes: 41 additions & 8 deletions tests/test_model.py
@@ -74,7 +74,11 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
     test_cfg.device_eval_batch_size = 2
     test_cfg.device_train_microbatch_size = 2
 
-    tokenizer = build_tokenizer(test_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(test_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer = build_tokenizer(test_cfg.tokenizer.name,
+                                tokenizer_cfg.get('kwargs', {}))
 
     model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
                                                          tokenizer)
@@ -221,7 +225,11 @@ def test_full_forward_and_backward_gpt2_small(prefixlm: bool,
     else:
         neo_cfg.model.name = 'hf_causal_lm'
 
-    tokenizer = build_tokenizer(neo_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(neo_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer = build_tokenizer(neo_cfg.tokenizer.name,
+                                tokenizer_cfg.get('kwargs', {}))
 
     model = COMPOSER_MODEL_REGISTRY[neo_cfg.model.name](neo_cfg.model,
                                                         tokenizer).to(device)
@@ -264,7 +272,11 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
     t5_cfg.device = device
     t5_cfg.max_seq_len = 16
 
-    tokenizer = build_tokenizer(t5_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(t5_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer = build_tokenizer(t5_cfg.tokenizer.name,
+                                tokenizer_cfg.get('kwargs', {}))
 
     model = COMPOSER_MODEL_REGISTRY[t5_cfg.model.name](t5_cfg.model,
                                                        tokenizer).to(device)
@@ -316,7 +328,11 @@ def test_determinism(attn_impl: str, precision: torch.dtype):
     test_cfg.model.init_device = 'cuda:0'
     test_cfg.device = 'cuda:0'
 
-    tokenizer = build_tokenizer(test_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(test_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer = build_tokenizer(test_cfg.tokenizer.name,
+                                tokenizer_cfg.get('kwargs', {}))
 
     model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
                                                            tokenizer)
@@ -381,7 +397,11 @@ def test_loss_fn():
 
     reproducibility.seed_all(test_cfg.get('global_seed', 42))
 
-    tokenizer = build_tokenizer(test_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(test_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer = build_tokenizer(test_cfg.tokenizer.name,
+                                tokenizer_cfg.get('kwargs', {}))
 
     model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
                                                            tokenizer)
@@ -440,7 +460,11 @@ def test_opt_wrapping(prefixlm: bool):
     }
     config = DictConfig(conf)
 
-    tokenizer = build_tokenizer(config.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(config.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer = build_tokenizer(config.tokenizer.name,
+                                tokenizer_cfg.get('kwargs', {}))
 
     if prefixlm:
         model = ComposerHFPrefixLM(config.model, tokenizer)
@@ -1388,7 +1412,12 @@ def test_hf_init(tmp_path: pathlib.Path,
     model = AutoModelForCausalLM.from_pretrained(save_path,
                                                  trust_remote_code=True)
 
-    tokenizer = build_tokenizer(test_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(test_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer = build_tokenizer(test_cfg.tokenizer.name,
+                                tokenizer_cfg.get('kwargs', {}))
+
     optimizer = DecoupledAdamW(model.parameters(),
                                lr=1e-5,
                                betas=tuple([0.9, 0.99]))
@@ -1435,7 +1464,11 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
     )
     test_cfg.device = torch.cuda.current_device()
 
-    tokenizer = build_tokenizer(test_cfg.tokenizer)
+    tokenizer_cfg: Dict[str,
+                        Any] = om.to_container(test_cfg.tokenizer,
+                                               resolve=True)  # type: ignore
+    tokenizer = build_tokenizer(test_cfg.tokenizer.name,
+                                tokenizer_cfg.get('kwargs', {}))
 
     mpt = MPTForCausalLM(hf_config)
 

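The same resolve-then-split boilerplate now repeats in each test above. If that repetition keeps growing it could be folded into a small helper; a hypothetical sketch, not part of this commit:

    from typing import Any, Dict

    from omegaconf import DictConfig, OmegaConf as om

    from llmfoundry.utils.builders import build_tokenizer


    def build_tokenizer_from_cfg(tokenizer_cfg: DictConfig):
        """Hypothetical helper: resolve a tokenizer DictConfig and build it."""
        resolved: Dict[str, Any] = om.to_container(tokenizer_cfg, resolve=True)  # type: ignore
        return build_tokenizer(resolved['name'], resolved.get('kwargs', {}))
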
