Skip to content

Commit

Permalink
Fix merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
Linden Li committed Dec 5, 2023
2 parents 44c1500 + 61cd110 commit ca173a1
Show file tree
Hide file tree
Showing 136 changed files with 74,900 additions and 1,181 deletions.
8 changes: 8 additions & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Require admin approval to modify all files in the root of the repository
# This includes setup.py, the README, and the CODEOWNERS file itself!
/* @mosaicml/composer-team-admins

# Require admin approval to change the CI build configuration
# All CI Changes should be reviewed for security
/.ci/ @mosaicml/composer-team-admins
/.github/ @mosaicml/composer-team-admins
11 changes: 7 additions & 4 deletions .github/mcp/mcp_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@
type=int,
default=1800,
help='Timeout for run (in seconds)')
parser.add_argument('--deps_group',
type=str,
help='Dependency group to install')
args = parser.parse_args()

name = args.name
Expand Down Expand Up @@ -89,7 +92,7 @@
clear_tmp_path_flag = '-o tmp_path_retention_policy=none'
command += f'''
pip install --upgrade --user .[all]
pip install --upgrade --user .[{args.deps_group}]
export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}' {clear_tmp_path_flag}"
Expand Down Expand Up @@ -127,7 +130,7 @@
print(line, end='')

print('[GHA] Run completed. Waiting for run to finish...')
run = wait_for_run_status(run, status='completed')
run = wait_for_run_status(run, status=RunStatus.COMPLETED)

# Fail if command exited with non-zero exit code or timed out
assert run.status == RunStatus.COMPLETED
# Fail if command exited with non-zero exit code or timed out (didn't reach COMPLETED)
assert run.status == RunStatus.COMPLETED, f'Run did not complete: {run.status} ({run.reason})'
14 changes: 6 additions & 8 deletions .github/workflows/docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,17 @@ jobs:
GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7)
echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV}
if [ "${{ github.event_name }}" == "push" ]; then
echo "Triggered by push event."
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
elif [ "${{ github.event_name }}" == "pull_request" ]; then
if [ "${{ github.event_name }}" == "pull_request" ]; then
echo "Triggered by pull_request event."
STAGING_REPO="mosaicml/ci-staging"
IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}"
IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache"
else
echo "Triggered by unknown event: ${{ github.event_name }}"
exit 1
# Triggered by push or workflow_dispatch event
echo "Triggered by ${{ github.event_name }} event, releasing to prod"
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
fi
echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV}
Expand Down
8 changes: 2 additions & 6 deletions .github/workflows/pr-cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ jobs:
strategy:
matrix:
include:
- name: 'cpu-latest'
container: mosaicml/pytorch:latest_cpu # mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
markers: 'not gpu'
pytest_command: 'coverage run -m pytest'
- name: 'cpu-2.0.1'
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
- name: 'cpu-1.13.1'
container: mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
markers: 'not gpu'
pytest_command: 'coverage run -m pytest'
- name: 'cpu-2.1.0'
Expand Down
13 changes: 6 additions & 7 deletions .github/workflows/pr-gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,22 @@ jobs:
uses: ./.github/workflows/pytest-gpu.yaml
strategy:
matrix:
# TODO: After the PR with the flash attention 2 images goes in, add the new unit test suite
include:
- name: 'gpu-latest'
container: mosaicml/pytorch:latest # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
- name: 'gpu-2.0.1'
container: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
- name: 'gpu-1.13.1'
container: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
deps_group: 'all'
- name: 'gpu-2.1.0'
container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
deps_group: 'all'
- name: 'gpu-2.1.0-flash2'
container: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
deps_group: 'all-flash2'
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
Expand All @@ -45,5 +43,6 @@ jobs:
pytest-command: ${{ matrix.pytest_command }}
pytest-markers: ${{ matrix.markers }}
python-version: 3.9
deps-group: ${{ matrix.deps_group }}
secrets:
mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }}
6 changes: 5 additions & 1 deletion .github/workflows/pytest-gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ on:
required: false
type: string
default: 3.9
deps-group:
required: true
type: string
secrets:
mcloud-api-key:
required: true
Expand Down Expand Up @@ -77,4 +80,5 @@ jobs:
--image '${{ inputs.container }}' \
--pytest_markers '${{ inputs.pytest-markers }}' \
--pytest_command '${{ inputs.pytest-command }}' \
--timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS}
--timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS} \
--deps_group ${{ inputs.deps-group }}
2 changes: 1 addition & 1 deletion llmfoundry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@
'TiktokenTokenizerWrapper',
]

__version__ = '0.3.0'
__version__ = '0.4.0'
32 changes: 12 additions & 20 deletions llmfoundry/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.data.text_data import build_text_dataloader

LOADER_NAME_TO_FUNCTION = {
'text': build_text_dataloader,
'text_denoising': build_text_denoising_dataloader,
'finetuning': build_finetuning_dataloader,
}


def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataSpec:
Expand All @@ -22,23 +28,9 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size (int): The size of the batches (number of examples)
that the dataloader will produce.
"""
if cfg.name == 'text':
return build_text_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(
cfg,
tokenizer,
device_batch_size,
)
else:
raise ValueError(f'Not sure how to build dataloader with config: {cfg}')
if cfg.name not in LOADER_NAME_TO_FUNCTION:
allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys())
raise ValueError(f'Expected dataloader name to be one of {allowed}' +
f' but found name "{cfg.name}" in config: {cfg}')

return LOADER_NAME_TO_FUNCTION[cfg.name](cfg, tokenizer, device_batch_size)
134 changes: 82 additions & 52 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,46 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

__all__ = ['dataset_constructor']

_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
_ALLOWED_PROMPT_KEYS = {'prompt'}


def _tokenize_formatted_example(
example: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
if ('prompt' not in example) or ('response' not in example):
"""Tokenize a formatted example and validate expected keys."""
example_keys = set(example.keys())
prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)

if len(prompt_keys) != 1:
raise KeyError(
f'Unable to tokenize example because {len(prompt_keys)} of the allowed prompt keys ' +\
f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_PROMPT_KEYS=}'
)

if len(response_keys) != 1:
raise KeyError(
'Unable to tokenize example because it has not been properly formatted. ' +\
'"prompt" and "response" are required keys but at least one was missing ' +\
f'from {example=}.'
f'Unable to tokenize example because {len(response_keys)} of the allowed response keys ' +\
f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_RESPONSE_KEYS=}'
)
if not isinstance(example['prompt'], str):

prompt_key = prompt_keys.pop()
response_key = response_keys.pop()
prompt = example[prompt_key]
response = example[response_key]

if not isinstance(prompt, str):
raise TypeError(
f'Unable to tokenize example because "prompt" was not a string. {example=}'
f'Unable to tokenize example because {prompt_key} was not a string. {example=}'
)
if not isinstance(example['response'], str):

if not isinstance(response, str):
raise TypeError(
f'Unable to tokenize example because "response" was not a string. {example=}'
f'Unable to tokenize example because {response_key} was not a string. {example=}'
)
return tokenizer(text=example['prompt'], text_target=example['response'])

return tokenizer(text=prompt, text_target=response)


class StreamingFinetuningDataset(StreamingDataset):
Expand Down Expand Up @@ -159,7 +180,6 @@ def __init__(self,
f'local directory {local} does not contain split {split}'
)

# Build Dataset
super().__init__(
local=local,
remote=remote,
Expand Down Expand Up @@ -345,51 +365,57 @@ def build_from_hf(
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass

dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)

def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
num_cpus_to_use = max(1, detected_cpus_with_margin)

columns_to_remove = list(dataset[0].keys())
tokenized_dataset = dataset.map(
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
num_proc=num_cpus_to_use,
desc='Tokenizing dataset',
)

pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)
error: Optional[Exception] = None
filtered_dataset = None
try:
dataset = hf_datasets.load_dataset(dataset_name,
split=split,
**kwargs)

def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
num_cpus_to_use = max(1, detected_cpus_with_margin)

columns_to_remove = list(dataset[0].keys())
tokenized_dataset = dataset.map(
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
num_proc=num_cpus_to_use,
desc='Tokenizing dataset',
)

examples_removed = len(tokenized_dataset) - len(filtered_dataset)
if examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+
'the prompt or response was empty, or the response was all padding tokens.'
pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)

examples_removed = len(tokenized_dataset) - len(filtered_dataset)
if examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+
'the prompt or response was empty, or the response was all padding tokens.'
)
except Exception as e:
error = e
# Now local rank 0 indicates to the other ranks that it is done
if dist.get_local_rank() == 0:
log.debug('Local rank 0 finished data prep')
Expand All @@ -403,7 +429,11 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
if dist.get_local_rank() == 0:
os.remove(signal_file_path)

if error is not None:
log.error('Error during data prep')
raise error
log.debug('All ranks finished data prep')
assert filtered_dataset is not None
return filtered_dataset

def build_from_streaming(self, *args: Any,
Expand Down
3 changes: 2 additions & 1 deletion llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import torch
from composer.utils import using_torch_2
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

Expand Down Expand Up @@ -347,7 +348,7 @@ def profile_packing(
dataloader_cfg.dataset.packing_ratio = None
dataloader_cfg.drop_last = False
dataloader_cfg.num_workers = 0
dataloader_cfg.prefetch_factor = None
dataloader_cfg.prefetch_factor = None if using_torch_2() else 2
dataloader_cfg.persistent_workers = False

# Determine the packing_ratio values we'll try
Expand Down
Loading

0 comments on commit ca173a1

Please sign in to comment.