Merge branch 'main' into bert-v0
jacobfulano committed Aug 21, 2023
2 parents 5e8d68d + 6670068 commit 3b59aa4
Showing 9 changed files with 392 additions and 241 deletions.
2 changes: 1 addition & 1 deletion .github/mcp/mcp_pytest.py
@@ -109,7 +109,7 @@
        image=args.image,
        integrations=[git_integration],
        command=command,
-       scheduling={'max_duration:': args.timeout / 60 / 60},
+       scheduling={'max_duration': args.timeout / 60 / 60},
    )

    # Create run
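The one-line fix above drops a stray colon from the scheduling key ('max_duration:' -> 'max_duration'); the value itself converts args.timeout from seconds to hours. A minimal sketch of that conversion, assuming an illustrative 2-hour timeout:

# Illustrative values only; args.timeout is supplied by the CI script in practice.
timeout_seconds = 7200
scheduling = {'max_duration': timeout_seconds / 60 / 60}
print(scheduling)  # {'max_duration': 2.0}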
3 changes: 2 additions & 1 deletion llmfoundry/utils/__init__.py
@@ -7,7 +7,7 @@
                                           build_optimizer, build_scheduler,
                                           build_tokenizer)
    from llmfoundry.utils.config_utils import (calculate_batch_size_info,
-                                               log_config,
+                                               log_config, pop_config,
                                                update_batch_size_info)
except ImportError as e:
    raise ImportError(
@@ -25,4 +25,5 @@
    'calculate_batch_size_info',
    'update_batch_size_info',
    'log_config',
+    'pop_config',
]
57 changes: 19 additions & 38 deletions llmfoundry/utils/builders.py
@@ -88,48 +88,29 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]):
        raise ValueError(f'Not sure how to build algorithm: {name}')


-def build_optimizer(cfg: DictConfig, model: torch.nn.Module):
-    if cfg.name == 'decoupled_adamw':
-        return DecoupledAdamW(model.parameters(),
-                              lr=cfg.lr,
-                              betas=cfg.betas,
-                              eps=cfg.eps,
-                              weight_decay=cfg.weight_decay)
-    elif cfg.name == 'decoupled_lionw':
-        return DecoupledLionW(model.parameters(),
-                              lr=cfg.lr,
-                              betas=cfg.betas,
-                              weight_decay=cfg.weight_decay)
-    elif cfg.name == 'clip_lion':
-        return DecoupledClipLion(model.parameters(),
-                                 lr=cfg.lr,
-                                 betas=cfg.betas,
-                                 weight_decay=cfg.weight_decay,
-                                 outlier_threshold=cfg.outlier_threshold)
-    elif cfg.name == 'adalr_lion':
-        return DecoupledAdaLRLion(model.parameters(),
-                                  lr=cfg.lr,
-                                  betas=cfg.betas,
-                                  weight_decay=cfg.weight_decay,
-                                  outlier_threshold=cfg.outlier_threshold,
-                                  timeout=cfg.timeout,
-                                  lr_penalty=cfg.lr_penalty,
-                                  min_scale=cfg.min_scale)
+def build_optimizer(model: torch.nn.Module, name: str,
+                    optimizer_config: Dict[str, Any]):
+    if name == 'decoupled_adamw':
+        return DecoupledAdamW(model.parameters(), **optimizer_config)
+    elif name == 'decoupled_lionw':
+        return DecoupledLionW(model.parameters(), **optimizer_config)
+    elif name == 'clip_lion':
+        return DecoupledClipLion(model.parameters(), **optimizer_config)
+    elif name == 'adalr_lion':
+        return DecoupledAdaLRLion(model.parameters(), **optimizer_config)
    else:
-        raise ValueError(f'Not sure how to build optimizer: {cfg.name}')
+        raise ValueError(f'Not sure how to build optimizer: {name}')


-def build_scheduler(cfg: DictConfig):
-    if cfg.name == 'constant_with_warmup':
-        return ConstantWithWarmupScheduler(t_warmup=cfg.t_warmup)
-    elif cfg.name == 'cosine_with_warmup':
-        return CosineAnnealingWithWarmupScheduler(t_warmup=cfg.t_warmup,
-                                                  alpha_f=cfg.alpha_f)
-    elif cfg.name == 'linear_decay_with_warmup':
-        return LinearWithWarmupScheduler(t_warmup=cfg.t_warmup,
-                                         alpha_f=cfg.alpha_f)
+def build_scheduler(name: str, scheduler_config: Dict[str, Any]):
+    if name == 'constant_with_warmup':
+        return ConstantWithWarmupScheduler(**scheduler_config)
+    elif name == 'cosine_with_warmup':
+        return CosineAnnealingWithWarmupScheduler(**scheduler_config)
+    elif name == 'linear_decay_with_warmup':
+        return LinearWithWarmupScheduler(**scheduler_config)
    else:
-        raise ValueError(f'Not sure how to build scheduler: {cfg.name}')
+        raise ValueError(f'Not sure how to build scheduler: {name}')


def build_tokenizer(om_tokenizer_config: DictConfig) -> PreTrainedTokenizerBase:
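For reference, a minimal sketch of calling the refactored builders with their new signatures: the name selects the class, and the remaining config is forwarded as keyword arguments. The stand-in model and the hyperparameter values below are illustrative assumptions, not taken from this commit.

import torch

from llmfoundry.utils.builders import build_optimizer, build_scheduler

model = torch.nn.Linear(8, 8)  # stand-in module for illustration

# The name picks the optimizer class; everything else is passed through as **kwargs.
optimizer = build_optimizer(model, 'decoupled_adamw', {
    'lr': 1.0e-4,
    'betas': [0.9, 0.95],
    'eps': 1.0e-8,
    'weight_decay': 0.0,
})

# Same pattern for schedulers; Composer accepts time strings such as '100ba'.
scheduler = build_scheduler('cosine_with_warmup', {
    't_warmup': '100ba',
    'alpha_f': 0.1,
})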
18 changes: 14 additions & 4 deletions llmfoundry/utils/config_utils.py
@@ -7,7 +7,7 @@
from typing import Any, Dict, Optional, Union

from composer.utils import dist
-from omegaconf import DictConfig
+from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

from llmfoundry.models.utils import init_empty_weights
@@ -16,14 +16,24 @@
def pop_config(cfg: DictConfig,
               key: str,
               must_exist: bool = True,
-               default_value: Any = None) -> Any:
+               default_value: Any = None,
+               convert: bool = False) -> Any:
    """Pop a value from the main config file and return it.

    If the key does not exist, return the default_value or raise a RuntimeError
-    depending on the must_exist flag.
+    depending on the must_exist flag. If the convert flag is set to True, then
+    we will convert the value to a python object using OmegaConf.to_container.
    """
    value = cfg.pop(key, None)
-    if value is not None:
+    if value is not None and convert:
+        if not isinstance(value, DictConfig) and not isinstance(
+                value, ListConfig):
+            raise ValueError(
+                f'The key: {key} has a value: {value} that cannot be \
+                    converted to a dict or list. Please check your yaml.'
+            )
+        return om.to_container(value)
+    elif value is not None:
        return value
    elif must_exist:
        raise NameError(
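A short sketch of the new convert flag on pop_config, using a small hand-built config; the keys below are illustrative, not from the repo.

from omegaconf import OmegaConf as om

from llmfoundry.utils.config_utils import pop_config

cfg = om.create({'fsdp_config': {'sharding_strategy': 'FULL_SHARD'}, 'seed': 17})

# convert=True returns a plain Python dict/list via OmegaConf.to_container.
fsdp_config = pop_config(cfg, 'fsdp_config', must_exist=False, convert=True)
assert isinstance(fsdp_config, dict)

# Scalars are popped without convert; passing convert=True for a scalar raises ValueError.
seed = pop_config(cfg, 'seed', must_exist=True)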
152 changes: 152 additions & 0 deletions mcli/mcli-llama2-finetune.yaml
@@ -0,0 +1,152 @@
integrations:
- integration_type: git_repo
  git_repo: mosaicml/llm-foundry
  git_commit: 148c0793907a6afa48a892620e637ef5f90cdaf1 # TODO: repin this after next release
  pip_install: -e .[gpu]
  ssh_clone: false # Should be true if using a private repo

command: |
  cd llm-foundry/scripts
  composer train/train.py /mnt/config/parameters.yaml
image: mosaicml/llm-foundry:1.13.1_cu117-latest
name: llama2-finetune

compute:
  # Note: Finetuning the 70b model requires at least 16x80GB GPUs
  gpus: 8 # Number of GPUs to use
  ## These configurations are optional
  # cluster: TODO # Name of the cluster to use for this run
  # gpu_type: a100_80gb # Type of GPU to use. We use a100_80gb in our experiments

# The below is injected as a YAML file: /mnt/config/parameters.yaml
parameters:
  tokenizer_name: meta-llama/Llama-2-7b-hf
  max_seq_len: 4096
  global_seed: 17

  # Run Name
  run_name: # If left blank, will be read from env var $RUN_NAME

  # IMPORTANT: Uncomment if using the 70b model
  # max_split_size_mb: 512

  # Model
  model:
    name: hf_causal_lm
    init_device: mixed
    pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
    pretrained: true
    # Note: you must have set the HUGGING_FACE_HUB_TOKEN environment variable and have access to the llama2 models
    use_auth_token: true
    attention_patch_type: triton

  # Tokenizer
  tokenizer:
    name: ${tokenizer_name}
    kwargs:
      model_max_length: ${max_seq_len}

  # Dataloaders
  train_loader:
    name: finetuning
    dataset:
      hf_name: mosaicml/dolly_hhrlhf
      split: train
      max_seq_len: ${max_seq_len}
      allow_pad_trimming: false
      decoder_only_format: true
      shuffle: true
      # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...`
      # # to profile this run's optimal packing_ratio as it depends on GPU count,
      # # batch size, sequence length
      # packing_ratio:
    drop_last: true
    num_workers: 8
    pin_memory: false
    prefetch_factor: 2
    persistent_workers: true
    timeout: 0

  eval_loader:
    name: finetuning
    dataset:
      hf_name: mosaicml/dolly_hhrlhf
      split: test
      max_seq_len: ${max_seq_len}
      allow_pad_trimming: false
      decoder_only_format: true
      # packing_ratio:
      shuffle: false
    drop_last: true
    num_workers: 8
    pin_memory: false
    prefetch_factor: 2
    persistent_workers: true
    timeout: 0

  # Optimization
  scheduler:
    name: cosine_with_warmup
    t_warmup: 100ba
    alpha_f: 0.1

  # Note: You may want to change learning rate, betas, weight decay
  optimizer:
    name: decoupled_lionw
    lr: 5.0e-7
    betas:
    - 0.9
    - 0.95
    weight_decay: 0.0

  algorithms:
    gradient_clipping:
      clipping_type: norm
      clipping_threshold: 1.0

  max_duration: 1ep
  eval_first: false
  eval_interval: 1ep
  eval_subset_num_batches: -1
  global_train_batch_size: 64

  # System
  seed: ${global_seed}
  device_eval_batch_size: 8
  device_train_microbatch_size: auto
  precision: amp_bf16

  # FSDP
  fsdp_config:
    sharding_strategy: FULL_SHARD
    mixed_precision: PURE
    activation_checkpointing: true
    activation_checkpointing_reentrant: false
    activation_cpu_offload: false
    limit_all_gathers: true
    verbose: false

  # Logging
  progress_bar: false
  log_to_console: true
  console_log_interval: 1ba

  callbacks:
    speed_monitor:
      window_size: 10
    lr_monitor: {}
    memory_monitor: {}
    runtime_estimator: {}

  # loggers:
  #   wandb: {}

  # Checkpoint to local filesystem or remote object store
  # save_interval: 2000ba
  # save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
  # save_folder: ./{run_name}/checkpoints
  # save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints

  # Load from local filesystem or remote object store
  # load_path: ./gpt-1b/checkpoints/latest-rank{rank}.pt
  # load_path: s3://my-bucket/my-folder/gpt-1b/checkpoints/latest-rank{rank}.pt
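To tie the new YAML to the utilities touched in this commit: the command block runs composer train/train.py against the injected /mnt/config/parameters.yaml, and the optimizer and scheduler sections are shaped for the refactored builders. Below is a rough sketch of that flow, not the actual scripts/train.py, which does considerably more.

# Illustrative sketch only; scripts/train.py is the real entry point.
import sys

from omegaconf import OmegaConf as om

from llmfoundry.utils import build_scheduler, pop_config

cfg = om.load(sys.argv[1])  # e.g. /mnt/config/parameters.yaml

# convert=True yields plain dicts, matching the builders' Dict[str, Any] signatures.
scheduler_cfg = pop_config(cfg, 'scheduler', must_exist=True, convert=True)
scheduler = build_scheduler(scheduler_cfg.pop('name'), scheduler_cfg)

optimizer_cfg = pop_config(cfg, 'optimizer', must_exist=True, convert=True)
# build_optimizer(model, optimizer_cfg.pop('name'), optimizer_cfg) would follow
# once the model described by the `model:` section has been constructed.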
