Merge branch 'main' into safe-load
irenedea committed Dec 18, 2023
2 parents 2261ae4 + 06b9a1f commit 7a06abd
Showing 16 changed files with 288 additions and 70 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/smoketest.yaml
@@ -0,0 +1,41 @@
name: Smoketest
on:
push:
branches:
- main
- release/*
pull_request:
branches:
- main
- release/*
workflow_dispatch:
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
defaults:
run:
working-directory: .
jobs:
smoketest:
runs-on: ubuntu-20.04
timeout-minutes: 10
strategy:
matrix:
python_version:
- "3.9"
- "3.10"
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
- name: Setup
run: |
set -ex
python -m pip install --upgrade 'pip<23' wheel
python -m pip install --upgrade .
python -m pip install pytest==7.2.1 pytest_codeblocks==0.16.1
- name: Run checks
run: |
pytest tests/test_smoketest.py
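The new workflow installs the package on Python 3.9 and 3.10 and then runs only tests/test_smoketest.py, whose contents are not part of this diff. Purely as an illustration of what a smoke test at this level usually checks, a minimal sketch might look like the following (hypothetical, not the actual test file):

```python
# Hypothetical sketch of a smoke test; the real tests/test_smoketest.py is not shown in this diff.
import importlib


def test_package_imports():
    # The cheapest possible check: the package and a submodule import cleanly
    # in a fresh environment, which is what the CI job above verifies.
    for module_name in ('llmfoundry', 'llmfoundry.utils'):
        assert importlib.import_module(module_name) is not None
```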
20 changes: 20 additions & 0 deletions llmfoundry/__init__.py
@@ -4,6 +4,26 @@
import torch

try:
import warnings

# bitsandbytes is a very noisy library. A lot of it is print statements that we can't easily suppress,
# but we can at least suppress a bunch of spurious warnings.
warnings.filterwarnings('ignore',
category=UserWarning,
module='bitsandbytes')

import logging

from llmfoundry.utils.logging_utils import SpecificWarningFilter

# Filter out Hugging Face warning for not using a pinned revision of the model
hf_dynamic_modules_logger = logging.getLogger(
'transformers.dynamic_module_utils')
new_files_warning_filter = SpecificWarningFilter(
'A new version of the following files was downloaded from')

hf_dynamic_modules_logger.addFilter(new_files_warning_filter)

# Before importing any transformers models, we need to disable transformers flash attention if
# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
# gated import otherwise.
16 changes: 15 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
@@ -44,6 +44,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from streaming import StreamingDataset
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.logging_utils import SpecificWarningFilter

log = logging.getLogger(__name__)

__all__ = ['dataset_constructor']
@@ -245,7 +247,7 @@ def wrapper(func: Callable) -> Callable:

def print_registered_tasks(self) -> None:
tasks = sorted(self._task_preprocessing_registry.keys())
print('\n'.join(tasks))
log.info('\n'.join(tasks))

def get_preprocessing_fn_from_dict(
self,
@@ -363,6 +365,15 @@ def build_from_hf(
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass

hf_tokenization_logger = logging.getLogger(
'transformers.tokenization_utils_base')
sequence_length_warning_filter = SpecificWarningFilter(
'Token indices sequence length is longer than the specified maximum sequence length'
)

# We will trim examples later in the collate_fn, so we want to silence this warning from Hugging Face
hf_tokenization_logger.addFilter(sequence_length_warning_filter)

error: Optional[Exception] = None
filtered_dataset = None
try:
@@ -468,6 +479,9 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
log.error('Error during data prep')
raise error
log.debug('All ranks finished data prep')

hf_tokenization_logger.removeFilter(sequence_length_warning_filter)

assert filtered_dataset is not None
return filtered_dataset

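The filter added above is attached before tokenization and detached once data prep finishes. The same attach/detach pair can also be wrapped in a context manager so the filter is removed even if tokenization raises; this is only an illustrative alternative, not what the commit does:

```python
# Illustrative alternative to the explicit addFilter/removeFilter calls above:
# temporarily attach a SpecificWarningFilter for the duration of a block.
import logging
from contextlib import contextmanager
from typing import Iterator

from llmfoundry.utils.logging_utils import SpecificWarningFilter


@contextmanager
def suppress_warning(logger_name: str, message: str) -> Iterator[None]:
    logger = logging.getLogger(logger_name)
    warning_filter = SpecificWarningFilter(message)
    logger.addFilter(warning_filter)
    try:
        yield
    finally:
        # Detach the filter even if the wrapped code raises.
        logger.removeFilter(warning_filter)


# Example:
# with suppress_warning('transformers.tokenization_utils_base',
#                       'Token indices sequence length is longer'):
#     ...tokenize the dataset...
```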
28 changes: 18 additions & 10 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -20,7 +20,7 @@
from composer.utils import dist
from omegaconf import DictConfig
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM,
from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
PreTrainedTokenizerBase)

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
@@ -102,20 +102,27 @@ def __init__(self, om_model_config: Union[DictConfig,
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ 'Please install flash_attn==2.3.2.')

requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
attn_implementation=requested_attention_implementation,
use_cache=
False, # Necessary due to https://github.com/huggingface/transformers/issues/28056
)

# This is not how you are supposed to set this, but transformers currently only
# supports enabling flash attention 2 when using the from_pretrained API.
# We need to support it for both from_pretrained and from_config, so we have to
# set the private attribute here. This will just skip all of transformers'
# validation logic that it is ok to use flash attention 2, so we check
# whether it is installed above, and whether the chosen config supports it here.
# https://github.com/huggingface/transformers/issues/26878
config._flash_attn_2_enabled = use_flash_attention_2
# This is not ideal, however Hugging Face's _autoset_attn_implementation function
# forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
# the model and then casting it back to fp32, we are monkeypatching their check.
# https://github.com/huggingface/transformers/issues/28052
def _autoset_attn_implementation_monkeypatch(
cls, config, *args, **kwargs): # type: ignore
config._attn_implementation = requested_attention_implementation
return config

PreTrainedModel._autoset_attn_implementation = classmethod(
_autoset_attn_implementation_monkeypatch)

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
@@ -184,7 +191,8 @@ def __init__(self, om_model_config: Union[DictConfig,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
load_in_8bit=load_in_8bit,
config=config)
config=config,
)
else:
model = AutoModelForCausalLM.from_config(
config,
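This change replaces the old private-attribute hack (config._flash_attn_2_enabled) with the attn_implementation argument plus a monkeypatch of PreTrainedModel._autoset_attn_implementation, so the fp16/bf16 check is skipped for both from_pretrained and from_config. Condensed into a standalone sketch, with example values filled in for names that come from the surrounding model config:

```python
# Condensed sketch of the attention-implementation handling added in this file.
# 'gpt2' and use_flash_attention_2 below are example values, not part of the diff.
from transformers import AutoConfig, PreTrainedModel

use_flash_attention_2 = False
requested_attention_implementation = ('flash_attention_2'
                                      if use_flash_attention_2 else 'eager')

config = AutoConfig.from_pretrained(
    'gpt2',
    attn_implementation=requested_attention_implementation,
    use_cache=False,  # https://github.com/huggingface/transformers/issues/28056
)


def _autoset_attn_implementation_monkeypatch(cls, config, *args, **kwargs):
    # Skip Hugging Face's check that flash attention is only enabled in fp16/bf16.
    config._attn_implementation = requested_attention_implementation
    return config


PreTrainedModel._autoset_attn_implementation = classmethod(
    _autoset_attn_implementation_monkeypatch)
```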
14 changes: 8 additions & 6 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -23,9 +23,11 @@
'OpenAICausalLMEvalWrapper',
'OpenAIChatAPIEvalWrapper',
]
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from openai.types.completion_choice import Logprobs

if TYPE_CHECKING:
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from openai.types.completion_choice import Logprobs

MAX_RETRIES = 10

@@ -99,7 +101,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
'role':
'system',
'content':
model_cfg.get('sytsem_role_prompt',
model_cfg.get('system_role_prompt',
'Please complete the following text: ')
}, {
'role': 'user',
@@ -201,7 +203,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):

return torch.stack(output_logits_batch).to(batch['input_ids'].device)

def process_result(self, completion: Optional[ChatCompletion]):
def process_result(self, completion: Optional['ChatCompletion']):
if completion is None:
raise ValueError("Couldn't generate model output")

@@ -234,7 +236,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
logprobs=5,
temperature=0.0)

def process_result(self, completion: Optional[Completion]):
def process_result(self, completion: Optional['Completion']):
if completion is None:
raise ValueError("Couldn't generate model output")

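Moving the openai type imports under TYPE_CHECKING means the wrapper module imports even when the openai package is absent; the annotations that reference those classes are quoted so they are never evaluated at runtime. The pattern in isolation, as a generic sketch rather than the wrapper's full code:

```python
# Generic sketch of the TYPE_CHECKING pattern used above: optional or heavy
# dependencies are imported only for static analysis, and annotations that
# mention them are quoted so nothing is evaluated at runtime.
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen by mypy/pyright, skipped when the module is actually imported.
    from openai.types.completion import Completion


def process_result(completion: Optional['Completion']) -> str:
    if completion is None:
        raise ValueError("Couldn't generate model output")
    return completion.choices[0].text
```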
45 changes: 43 additions & 2 deletions llmfoundry/models/layers/ffn.py
@@ -4,7 +4,9 @@
"""MPT Blocks used for the MPT Model."""

import logging
from typing import Any, Optional, Union
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
@@ -18,6 +20,36 @@

log = logging.getLogger(__name__)

_FFN_ACT_FN_DEFAULT = {
'name': 'gelu',
'approximate': 'none',
}


def resolve_ffn_act_fn(
config: Optional[dict] = None,) -> Callable[[torch.Tensor], torch.Tensor]:
"""Resolve the activation function for the feed-forward network.
Args:
config (Optional[dict]): The configuration dictionary for the activation function.
The dict config must specify the 'name' of a torch.nn.functional activation
function. All other key-value pairs are bound to the function as a partial.
Returns:
Callable[[torch.Tensor], torch.Tensor]: The activation function.
"""
if config is None:
config = _FFN_ACT_FN_DEFAULT
config = deepcopy(config)
name = config.pop('name')
if not hasattr(torch.nn.functional, name):
raise ValueError(f'Unrecognised activation function name ({name}).')
act = getattr(torch.nn.functional, name)
return partial(act, **config)


_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)


def resolve_ffn_hidden_size(
d_model: int,
@@ -55,6 +87,7 @@ def __init__(
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
ffn_hidden_size: Optional[int] = None,
act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
device: Optional[str] = None,
bias: bool = True,
):
@@ -72,7 +105,7 @@ def __init__(
ffn_hidden_size,
**self.fc_kwargs,
)
self.act = nn.GELU(approximate='none')
self.act = act_fn
self.down_proj = FC_CLASS_REGISTRY[fc_type](
ffn_hidden_size,
d_model,
@@ -92,6 +125,7 @@ def __init__(
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
ffn_hidden_size: Optional[int] = None,
act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
device: Optional[str] = None,
bias: bool = True,
):
@@ -100,6 +134,7 @@ def __init__(
expansion_ratio=expansion_ratio,
fc_type=fc_type,
ffn_hidden_size=ffn_hidden_size,
act_fn=act_fn,
device=device,
bias=bias,
)
@@ -128,6 +163,7 @@ def build_ffn(
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
ffn_hidden_size: Optional[int] = None,
ffn_act_fn: Optional[dict] = None,
device: Optional[str] = None,
bias: bool = True,
**kwargs: Any,
@@ -142,6 +178,7 @@ def build_ffn(
d_model=d_model,
expansion_ratio=expansion_ratio,
fc_type=fc_type,
act_fn=resolve_ffn_act_fn(ffn_act_fn),
ffn_hidden_size=ffn_hidden_size,
device=device,
bias=bias,
@@ -150,6 +187,10 @@
assert te is not None
ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
ffn_hidden_size)
if ffn_act_fn is not None:
raise ValueError(
f'Transformer Engine block does not support custom activation functions.'
)
return te.LayerNormMLP(
hidden_size=d_model,
ffn_hidden_size=ffn_hidden_size,
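resolve_ffn_act_fn is fully defined in this diff, so its behaviour can be shown directly: the 'name' key selects a torch.nn.functional callable and any remaining keys are bound as a partial. A short usage sketch:

```python
# Usage sketch for resolve_ffn_act_fn as defined above.
import torch

from llmfoundry.models.layers.ffn import resolve_ffn_act_fn

# No config: falls back to _FFN_ACT_FN_DEFAULT, i.e. F.gelu with approximate='none'.
default_act = resolve_ffn_act_fn()

# Extra keys are bound onto the torch.nn.functional callable as a partial.
tanh_gelu = resolve_ffn_act_fn({'name': 'gelu', 'approximate': 'tanh'})

x = torch.randn(2, 8)
y = tanh_gelu(x)  # same as torch.nn.functional.gelu(x, approximate='tanh')

# An unknown 'name' raises ValueError('Unrecognised activation function name (...)').
```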
4 changes: 4 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -295,6 +295,10 @@ def _validate_config(self) -> None:
self.ffn_config['fc_type'] = self.fc_type
elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
self.ffn_config['bias'] = not self.no_bias
if 'ffn_act_fn' in self.ffn_config.keys():
raise ValueError(
f'Transformer Engine block does not support custom activation functions.'
)
if not self.use_pad_tok_in_ffn:
try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
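The new check mirrors the one in build_ffn: a custom ffn_act_fn is only valid for the non-Transformer-Engine FFN blocks. As an illustration (these config dicts are examples; 'mptmlp' is assumed here to be the default ffn_type):

```python
# Illustrative ffn_config values for the validation above.
# 'mptmlp' is assumed to be the default (non-Transformer-Engine) ffn_type.
ffn_config_ok = {
    'ffn_type': 'mptmlp',
    'ffn_act_fn': {'name': 'silu'},
}

# A te_ln_mlp block combined with ffn_act_fn now fails _validate_config with a ValueError.
ffn_config_bad = {
    'ffn_type': 'te_ln_mlp',
    'ffn_act_fn': {'name': 'silu'},
}
```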
21 changes: 21 additions & 0 deletions llmfoundry/utils/logging_utils.py
@@ -0,0 +1,21 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging


class SpecificWarningFilter(logging.Filter):

def __init__(self, message_to_suppress: str):
"""Filter out a specific warning message based on its content.
This can be useful for filtering out specific warning messages from third party packages.
Args:
message_to_suppress (str): The warning message to suppress.
"""
super().__init__()
self.message_to_suppress = message_to_suppress

def filter(self, record: logging.LogRecord) -> bool:
return self.message_to_suppress not in record.getMessage()
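For reference, this is how the filter is used elsewhere in the commit: construct it with the offending substring, attach it to the relevant logger, and detach it once the noisy phase is over. A minimal usage sketch:

```python
# Minimal usage sketch for SpecificWarningFilter, mirroring its use in this commit.
import logging

from llmfoundry.utils.logging_utils import SpecificWarningFilter

hf_tokenization_logger = logging.getLogger('transformers.tokenization_utils_base')
seq_len_filter = SpecificWarningFilter(
    'Token indices sequence length is longer than the specified maximum sequence length')

hf_tokenization_logger.addFilter(seq_len_filter)  # records containing the substring are dropped
# ... tokenize examples that will be trimmed later in the collate_fn ...
hf_tokenization_logger.removeFilter(seq_len_filter)  # restore normal logging
```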