Commit b38c503

merge

bmosaicml committed Jun 27, 2023
2 parents 4b31ff0 + 38361a6
Showing 19 changed files with 2,329 additions and 8,772 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/codeql-analysis.yml
@@ -14,6 +14,9 @@ name: 'CodeQL'
on:
  push:
    branches: [main]
+  pull_request:
+    # The branches below must be a subset of the branches above
+    branches: [main]
  schedule:
    - cron: '0 9 * * 1' # Every Monday at 09:00 (9:00 AM)

8 changes: 8 additions & 0 deletions .github/workflows/docker.yaml
@@ -17,6 +17,14 @@ jobs:
          base_image: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04

    steps:
+      - name: Maximize Build Space on Worker
+        uses: easimon/maximize-build-space@v4
+        with:
+          overprovision-lvm: true
+          remove-dotnet: true
+          remove-android: true
+          remove-haskell: true
+
      - name: Checkout
        uses: actions/checkout@v3

10 changes: 7 additions & 3 deletions README.md
@@ -41,15 +41,18 @@ You'll find in this repo:

# MPT

-MPT-7B is a GPT-style model, and the first in the MosaicML Foundation Series of models. Trained on 1T tokens of a MosaicML-curated dataset, MPT-7B is open-source, commercially usable, and equivalent to LLaMa 7B on evaluation metrics. The MPT architecture contains all the latest techniques on LLM modeling -- Flash Attention for efficiency, Alibi for context length extrapolation, and stability improvements to mitigate loss spikes. The base model and several variants, including a 64K context length fine-tuned model (!!) are all available:
+Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models:


| Model              | Context Length | Download                                           | Demo                                                              | Commercial use? |
|--------------------|----------------|----------------------------------------------------|-------------------------------------------------------------------|-----------------|
+| MPT-30B            | 8192           | https://huggingface.co/mosaicml/mpt-30b            |                                                                   | Yes             |
+| MPT-30B-Instruct   | 8192           | https://huggingface.co/mosaicml/mpt-30b-instruct   |                                                                   | Yes             |
+| MPT-30B-Chat       | 8192           | https://huggingface.co/mosaicml/mpt-30b-chat       | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat)       | No              |
| MPT-7B             | 2048           | https://huggingface.co/mosaicml/mpt-7b             |                                                                   | Yes             |
-| MPT-7B-Instruct    | 2048           | https://huggingface.co/mosaicml/mpt-7b-instruct    | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-instruct)    | Yes             |
+| MPT-7B-Instruct    | 2048           | https://huggingface.co/mosaicml/mpt-7b-instruct    |                                                                   | Yes             |
| MPT-7B-Chat        | 2048           | https://huggingface.co/mosaicml/mpt-7b-chat        | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat)        | No              |
-| MPT-7B-StoryWriter | 65536          | https://huggingface.co/mosaicml/mpt-7b-storywriter | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-storywriter) | Yes             |
+| MPT-7B-StoryWriter | 65536          | https://huggingface.co/mosaicml/mpt-7b-storywriter |                                                                   | Yes             |
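
For intuition, the ALiBi technique mentioned above can be sketched in a few lines: rather than learned position embeddings, each attention head adds a fixed linear penalty to its attention logits in proportion to the query-key distance, which is what lets a model extrapolate past its training context length. A toy sketch follows (illustrative only, not the repo's implementation; `n_heads` and `seq_len` are arbitrary):

```python
import torch

def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    """Toy ALiBi bias of shape (n_heads, seq_len, seq_len)."""
    # Per-head slopes; the ALiBi paper uses a geometric sequence such as 2^(-8i/n_heads).
    slopes = torch.tensor([2**(-8 * (i + 1) / n_heads) for i in range(n_heads)])
    positions = torch.arange(seq_len)
    # distance[q, k] = k - q, clamped so only current/past keys are penalized.
    distance = (positions[None, :] - positions[:, None]).clamp(max=0)
    # Added to attention logits before softmax; more distant keys get larger penalties.
    return slopes[:, None, None] * distance
```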

To try out these models locally, [follow the instructions](https://github.com/mosaicml/llm-foundry/tree/main/scripts/inference#interactive-generation-with-modelgenerate) in `scripts/inference/README.md` to prompt HF models using our [hf_generate.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py) or [hf_chat.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_chat.py) scripts.
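
For a quick sense of what those scripts do, here is a minimal generation sketch using the Hugging Face `transformers` API (a sketch only: the prompt is arbitrary, and per the MPT model cards the checkpoints ship custom modeling code, so `trust_remote_code=True` is required):

```python
import torch
import transformers

name = 'mosaicml/mpt-7b'
model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # MPT uses custom modeling code hosted on the Hub
)
# MPT-7B was trained with the EleutherAI/gpt-neox-20b tokenizer.
tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

inputs = tokenizer('Here is a recipe for vegan banana bread:\n', return_tensors='pt')
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```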

@@ -71,6 +74,7 @@ Tutorial videos from the community:
Something missing? Contribute with a PR!

# Latest News
+* [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b)
* [Blog: Introducing MPT-7B](https://www.mosaicml.com/blog/mpt-7b)
* [Blog: Benchmarking LLMs on H100](https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1)
* [Blog: Blazingly Fast LLM Evaluation](https://www.mosaicml.com/blog/llm-evaluation-for-icl)
53 changes: 25 additions & 28 deletions llmfoundry/callbacks/model_gauntlet_callback.py
@@ -6,58 +6,54 @@

"""Monitor gradients during training."""

-from enum import Enum
import math
import re
+from enum import Enum
from typing import Optional

import torch
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist

-__all__ = ['MoedlGauntlet']
+__all__ = ['ModelGauntlet']


class Weighting(Enum):
    EQUAL = 1
    SAMPLE_SZ = 2
    LOG_SAMPLE_SZ = 3


-class MoedlGauntlet(Callback):
-
-    def __init__(
-        self,
-        logger_keys: dict,
-        tasks: dict,
-        weighting: Weighting = Weighting.EQUAL,
-        subtract_random_baseline: bool = True,
-        rescale_accuracy: bool = True,
-        benchmark_sizes: Optional[dict] = None
-    ):
+class ModelGauntlet(Callback):
+
+    def __init__(self,
+                 logger_keys: dict,
+                 tasks: dict,
+                 weighting: Weighting = Weighting.EQUAL,
+                 subtract_random_baseline: bool = True,
+                 rescale_accuracy: bool = True,
+                 benchmark_sizes: Optional[dict] = None):
        self.tasks = tasks
        self.weighting = Weighting[weighting]
        self.subtract_random_baseline = subtract_random_baseline
        self.rescale_accuracy = rescale_accuracy
        self.logger_keys = logger_keys
        for category in self.tasks:

            for benchmark in category['benchmarks']:
                bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
                cumulative_samples = max(
-                    sum(count for name,count in benchmark_sizes.items() if name.startswith(bench_name)),
-                    1)
+                    sum(count for name, count in benchmark_sizes.items()
+                        if name.startswith(bench_name)), 1)

                if self.weighting == Weighting.EQUAL:
                    weight = 1
                elif self.weighting == Weighting.SAMPLE_SZ:
                    weight = cumulative_samples
                elif self.weighting == Weighting.LOG_SAMPLE_SZ:
-                    weight = max(
-                        math.log(cumulative_samples, 2),
-                        1
-                    )
+                    weight = max(math.log(cumulative_samples, 2), 1)

                benchmark['weighting'] = weight

    def compute_averages(self, logger_data):

        results = {}
@@ -106,10 +102,10 @@ def eval_end(self, state: State, logger: Logger):
                score = new_metrics[matching_key[0]]

                if self.subtract_random_baseline:
-                    score -= benchmark['scorecard']['random_baseline']
+                    score -= benchmark['random_baseline']

                if self.rescale_accuracy and self.subtract_random_baseline:
-                    score /= 1.0 - benchmark['scorecard']['random_baseline']
+                    score /= 1.0 - benchmark['random_baseline']

                composite_scores[category['name']].append({
                    'name': benchmark['name'],
@@ -123,10 +119,11 @@ def eval_end(self, state: State, logger: Logger):
                for k in composite_scores[category['name']])

        composite_scores = {
-            f'metrics/icl_taxonomy/{k}': v for k, v in composite_scores.items()
+            f'metrics/model_gauntlet/{k}': v
+            for k, v in composite_scores.items()
        }

-        composite_scores['metrics/icl_taxonomy/average'] = sum(
+        composite_scores['metrics/model_gauntlet/average'] = sum(
            composite_scores.values()) / len(composite_scores.values())
        logger.log_metrics(composite_scores)

25 changes: 20 additions & 5 deletions llmfoundry/callbacks/monolithic_ckpt_callback.py
@@ -8,7 +8,8 @@

import torch
from composer.core import Callback, State
-from composer.core.state import fsdp_state_dict_type_context
+from composer.core.state import (fsdp_get_optim_state_dict,
+                                 fsdp_state_dict_type_context)
from composer.loggers import Logger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.utils import (dist, format_name_with_dist_and_time, parse_uri,
@@ -79,13 +80,27 @@ def _save_checkpoint(self, state: State, logger: Logger):
            'state': state.state_dict(),
            'rng': reproducibility.get_rng_state()
        }
-        if not self.keep_optimizers:
-            state_dict['state'].pop('optimizers')
+        # Remove sharded model and optimizer state dicts
+        state_dict['state'].pop('optimizers')
+        state_dict['state'].pop('model')
+
+        # Add in unsharded model params.
        with fsdp_state_dict_type_context(state.model,
                                          state_dict_type='full'):
            state_dict['state']['model'] = state.model.state_dict()
-        if dist.get_global_rank() == 0:
-            torch.save(state_dict, save_path)
+
+        # Add in unsharded optimizer state dict.
+        if self.keep_optimizers:
+            optimizer = state.optimizers[0]
+            state_dict['state']['optimizers'] = {
+                type(optimizer).__qualname__:
+                    fsdp_get_optim_state_dict(state.model,
+                                              optimizer,
+                                              state_dict_type='full')
+            }
+        if dist.get_global_rank() == 0:
+            torch.save(state_dict, save_path)

        if self.upload_to_object_store and self.remote_ud is not None and dist.get_global_rank(
        ) == 0:
            remote_file_name = str(Path(save_dir) / Path(filename))
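As a rough illustration of what the callback above writes (a hedged sketch; the checkpoint filename is hypothetical): rank 0 saves a single file whose `state` entry holds the unsharded model weights, with an `optimizers` entry present only when `keep_optimizers=True`:

```python
import torch

# Hypothetical path to a monolithic checkpoint produced by the callback above.
ckpt = torch.load('mono-checkpoint.pt', map_location='cpu')
model_state = ckpt['state']['model']  # full, unsharded parameter tensors
print(f'{len(model_state)} parameter/buffer entries saved')
if 'optimizers' in ckpt['state']:  # present only with keep_optimizers=True
    print('optimizer states:', list(ckpt['state']['optimizers']))
```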
54 changes: 51 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
@@ -2,10 +2,12 @@
# SPDX-License-Identifier: Apache-2.0

import logging
+import os
+import tempfile
from typing import Union

import torch
-from composer.utils import dist
+from composer.utils import dist, get_file, parse_uri
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -38,7 +40,9 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer,
    ---
    *** HuggingFace dataset config fields ***
        cfg.dataset.hf_name (str, optional): The name of the HuggingFace dataset
-            to use.
+            to use. Can also be a remote http(s) directory or object store bucket
+            containing the file {split}.jsonl in the format (prompt, response),
+            in which case the builder will create a HuggingFace dataset.
        cfg.dataset.hf_kwargs (DictConfig, optional): Additional kwargs to
            pass to `datasets.load_dataset`, which can be used to load
            a dataset from local files.
@@ -145,7 +149,51 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer,
        )

    else:
-        dataset = dataset_constructor.build_from_hf(cfg.dataset, tokenizer)
+        backend, _, _ = parse_uri(cfg.dataset.hf_name)
+        if backend not in ['', None]:
+            if cfg.dataset.get('split') is None:
+                raise ValueError(
+                    'When using a HuggingFace dataset from a URL, you must set the ' + \
+                    '`split` key in the dataset config.'
+                )
+            supported_extensions = ['jsonl', 'csv', 'parquet']
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                for extension in supported_extensions:
+                    name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
+                    destination = str(
+                        os.path.abspath(
+                            f'{tmp_dir}/{cfg.dataset.split}.{extension}'))
+                    try:
+                        with dist.run_local_rank_zero_first():
+                            get_file(name, destination, overwrite=True)
+                    except FileNotFoundError as e:
+                        if extension == supported_extensions[-1]:
+                            raise FileNotFoundError(
+                                f'Could not find a {cfg.dataset.split} file with any of ' + \
+                                f'the supported extensions: {supported_extensions}\n' + \
+                                f'at {cfg.dataset.hf_name}/{cfg.dataset.split}'
+                            ) from e
+                        else:
+                            print(
+                                f'Could not find {name}, looking for another extension'
+                            )
+                        continue
+                    # 'json' causes special behavior in the dataset constructor
+                    cfg.dataset.hf_name = extension if extension != 'jsonl' else 'json'
+                    kwargs = cfg.dataset.get('hf_kwargs', {})
+                    kwargs['data_files'] = destination
+                    cfg.dataset['hf_kwargs'] = kwargs
+                    print(cfg.dataset)
+                    dataset = dataset_constructor.build_from_hf(
+                        cfg.dataset,
+                        tokenizer=tokenizer,
+                    )
+                    break
+        else:
+            dataset = dataset_constructor.build_from_hf(
+                cfg.dataset,
+                tokenizer=tokenizer,
+            )

    collate_fn, dataloader_batch_size = _build_collate_fn(
        cfg.dataset, tokenizer, device_batch_size)
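To illustrate the remote-path support described in the docstring above (a sketch: the bucket and prefix are hypothetical, and the surrounding dataset keys follow the repo's finetuning YAML conventions), `hf_name` can point at an object store prefix containing `{split}.jsonl`:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'name': 'finetuning',
    'dataset': {
        # Hypothetical bucket; the loader downloads train.jsonl (or .csv/.parquet)
        # to a temp dir and builds a local HuggingFace dataset from it.
        'hf_name': 's3://my-bucket/my-finetuning-data',
        'split': 'train',  # required when hf_name is a remote URL
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
    },
    'drop_last': True,
    'num_workers': 8,
})
# dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size=8)
```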
96 changes: 46 additions & 50 deletions llmfoundry/utils/builders.py
@@ -201,57 +201,53 @@ def _validate_cfg(icl_cfg):
            icl_cfg.batch_size = default_batch_size

    for icl_cfg in icl_tasks:
-        try:
-            _validate_cfg(icl_cfg)
-            for num_fewshot in list(icl_cfg.num_fewshot):
-                if tokenizer.pad_token_id is None:
-                    # Current workaround to support GPT2 tokenizer with `pad_token_id = None`
-                    pad_tok_id = tokenizer.eos_token_id
-                else:
-                    pad_tok_id = tokenizer.pad_token_id
-                label = f'{icl_cfg.label}/{num_fewshot}-shot'
-                metric_names = list(icl_cfg.metric_names)
-                # TODO: fix Composer bug when copying local paths and destination exists
-                destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl'
-                if dist.get_local_rank() == 0 and os.path.exists(
-                        destination_path):
-                    os.remove(destination_path)
-                dist.barrier()
-                dataloaders = get_icl_task_dataloader(
-                    icl_cfg.icl_task_type,
-                    icl_cfg.dataset_uri,
-                    tokenizer,
-                    batch_size=icl_cfg.batch_size,
-                    max_seq_len=icl_cfg.max_seq_len,
-                    pad_tok_id=pad_tok_id,
-                    num_fewshot=num_fewshot,
-                    prompt_string=icl_cfg.prompt_string,
-                    example_delimiter=icl_cfg.example_delimiter,
-                    continuation_delimiter=icl_cfg.continuation_delimiter,
-                    destination_path=destination_path,
-                    has_categories=icl_cfg.get('has_categories', False),
-                )
-                if hasattr(icl_cfg, 'has_categories'
-                          ) and icl_cfg.has_categories and isinstance(
-                              dataloaders, dict):
-                    for category in dataloaders.keys():
-                        logger_keys.extend([
-                            f'metrics/{label}/{category}/{m}'
-                            for m in metric_names
-                        ])
-                        evaluators.append(
-                            Evaluator(label=f'{label}/{category}',
-                                      dataloader=dataloaders[category],
-                                      metric_names=metric_names),)
-                else:
-                    logger_keys.extend(
-                        [f'metrics/{label}/{m}' for m in metric_names])
-                    evaluators.append(
-                        Evaluator(label=label,
-                                  dataloader=dataloaders,
-                                  metric_names=metric_names),)
-        except Exception as e:
-            print(f'Got exception: {str(e)} while building ICL task: {icl_cfg}')
-            raise e
+        _validate_cfg(icl_cfg)
+        for num_fewshot in list(icl_cfg.num_fewshot):
+            if tokenizer.pad_token_id is None:
+                # Current workaround to support GPT2 tokenizer with `pad_token_id = None`
+                pad_tok_id = tokenizer.eos_token_id
+            else:
+                pad_tok_id = tokenizer.pad_token_id
+            label = f'{icl_cfg.label}/{num_fewshot}-shot'
+            metric_names = list(icl_cfg.metric_names)
+            # TODO: fix Composer bug when copying local paths and destination exists
+            destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl'
+            if dist.get_local_rank() == 0 and os.path.exists(destination_path):
+                os.remove(destination_path)
+            dist.barrier()
+
+            dataloaders = get_icl_task_dataloader(
+                icl_cfg.icl_task_type,
+                icl_cfg.dataset_uri,
+                tokenizer,
+                batch_size=icl_cfg.batch_size,
+                max_seq_len=icl_cfg.max_seq_len,
+                pad_tok_id=pad_tok_id,
+                num_fewshot=num_fewshot,
+                prompt_string=icl_cfg.prompt_string,
+                example_delimiter=icl_cfg.example_delimiter,
+                continuation_delimiter=icl_cfg.continuation_delimiter,
+                destination_path=destination_path,
+                has_categories=icl_cfg.get('has_categories', False),
+            )
+            if hasattr(
+                    icl_cfg,
+                    'has_categories') and icl_cfg.has_categories and isinstance(
+                        dataloaders, dict):
+                for category in dataloaders.keys():
+                    logger_keys.extend([
+                        f'metrics/{label}/{category}/{m}' for m in metric_names
+                    ])
+                    evaluators.append(
+                        Evaluator(label=f'{label}/{category}',
+                                  dataloader=dataloaders[category],
+                                  metric_names=metric_names),)
+            else:
+                logger_keys.extend(
+                    [f'metrics/{label}/{m}' for m in metric_names])
+                evaluators.append(
+                    Evaluator(label=label,
+                              dataloader=dataloaders,
+                              metric_names=metric_names),)

    return evaluators, logger_keys