Refactor logging (#234)
Replaces most print statements with proper logging. Deprecates the `verbose` argument in favor of the `python_log_level` argument that Composer already uses.
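
For reference, a minimal sketch of the pattern this commit adopts: each module gets a logger via `logging.getLogger(__name__)`, former print calls become `log.<level>(...)` calls, and verbosity is controlled by a single `python_log_level`-style setting rather than per-module `verbose` flags. The `configure_logging` helper and the example messages below are illustrative assumptions, not code taken from this diff.

import logging

# Module-level logger, as added at the top of each changed file.
log = logging.getLogger(__name__)

def configure_logging(python_log_level: str = 'info') -> None:
    # Hypothetical helper mirroring the Composer-style `python_log_level`
    # setting: one switch sets verbosity for the whole process.
    logging.basicConfig(
        level=python_log_level.upper(),
        format='%(asctime)s: [%(levelname)s|%(name)s]: %(message)s')

# Former print() statements now choose an appropriate severity and are
# filtered by the configured level instead of a `verbose` argument.
configure_logging('debug')
log.warning('Could not find results for benchmark: example_benchmark.')
log.debug('Froze layer: transformer.blocks.0')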
hanlint committed Sep 13, 2023
1 parent f03276d commit 0fdf43f
Showing 18 changed files with 111 additions and 172 deletions.
12 changes: 7 additions & 5 deletions llmfoundry/callbacks/eval_gauntlet_callback.py
@@ -3,6 +3,7 @@

"""Aggregate ICL evals into composite scores."""

import logging
import math
from enum import Enum
from typing import Optional
@@ -12,6 +13,8 @@

__all__ = ['EvalGauntlet']

log = logging.getLogger(__name__)


class Weighting(Enum):
EQUAL = 1
@@ -130,9 +133,8 @@ def eval_after_all(self, state: State, logger: Logger):
key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"

if key not in new_metrics:
print(
f"Warning: couldn't find results for benchmark: {benchmark}"
)
log.warning(
f'Could not find results for benchmark: {benchmark}.')
missing_metrics.append(key)
else:
score = new_metrics[key]
@@ -150,8 +152,8 @@ def eval_after_all(self, state: State, logger: Logger):
})

if len(missing_metrics) > 0:
print(
f"Removing category `{category['name']}` from gauntlet scores because benchmarks were missing: {missing_metrics}"
log.warning(
f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}"
)
del composite_scores[category['name']]
continue
7 changes: 5 additions & 2 deletions llmfoundry/callbacks/resumption_callbacks.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import List

from composer.core import Callback, State
@@ -11,6 +12,8 @@
'LayerFreezing',
]

log = logging.getLogger(__name__)


class GlobalLRScaling(Callback):
"""GlobalLRScaling.
@@ -38,7 +41,7 @@ def fit_start(self, state: State, logger: Logger):
group['weight_decay'] = group['lr'] * self.wd_pct
if 'initial_lr' in group:
group['initial_lr'] *= self.lr_scale
print(
log.info(
f"Set LR and WD to {group['lr']}, {group['weight_decay']}")

for scheduler in state.schedulers:
@@ -74,7 +77,7 @@ def fit_start(self, state: State, logger: Logger):
for name, p in state.model.named_parameters():
if p.requires_grad and name in self.layer_names:
p.requires_grad = False
print(f'Froze layer: {name}\nParam: {p}')
log.debug(f'Froze layer: {name}\nParam: {p}')
successful_freeze = True

if not successful_freeze:
1 change: 0 additions & 1 deletion llmfoundry/data/data.py
@@ -24,7 +24,6 @@ def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset,

def __iter__(self) -> Iterable[Dict[str, bytes]]:
for sample in self.hf_dataset:
# print(sample)
# convert to bytes to store in MDS binary format
yield {'text': sample['text'].encode('utf-8')}

4 changes: 2 additions & 2 deletions llmfoundry/data/finetuning/dataloader.py
@@ -323,7 +323,7 @@ def _build_hf_dataset_from_remote(
f'at {files_searched}'
) from e
else:
print(
log.debug(
f'Could not find {name}, looking for another extension')
continue

@@ -343,7 +343,7 @@ def _build_hf_dataset_from_remote(
dist.barrier()

cfg.dataset.hf_name = finetune_dir
print(cfg.dataset)
log.info(cfg.dataset)
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
34 changes: 13 additions & 21 deletions llmfoundry/data/finetuning/tasks.py
@@ -32,6 +32,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
"""

import importlib
import logging
import os
import warnings
from typing import Any, Callable, Dict, Optional, Union
@@ -41,6 +42,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from streaming import StreamingDataset
from transformers import PreTrainedTokenizerBase

log = logging.getLogger(__name__)

__all__ = ['dataset_constructor']


@@ -205,16 +208,14 @@ def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]:

def get_preprocessing_fn_from_str(self,
preprocessor: Optional[str],
dataset_name: Optional[str] = None,
verbose: bool = False):
dataset_name: Optional[str] = None):
"""Get a preprocessing function from a string.
String can be either a registered function or an import path.
Args:
preprocessor (Optional[str]): The name of the preprocessing function, or an import path.
dataset_name (Optional[str]): The dataset name to look up in the registry.
verbose (bool): Whether to print verbose messages or not.
Returns:
Callable: The preprocessing function or None if not found.
@@ -226,33 +227,24 @@ def get_preprocessing_fn_from_str(self,
if dataset_name is None:
return None
if dataset_name in self._task_preprocessing_registry:
if verbose:
print(
f'Re-formatting dataset with "{dataset_name}" preprocessing function.'
)
log.info(
f'Re-formatting dataset with "{dataset_name}" preprocessing function.'
)
return self._task_preprocessing_registry[dataset_name]
else:
if verbose:
print(
'No preprocessor was supplied and no preprocessing function ' +\
log.info('No preprocessor was supplied and no preprocessing function ' +\
f'is registered for dataset name "{dataset_name}". No additional ' +\
'preprocessing will be applied. If the dataset is already formatted ' +\
'correctly, you can ignore this message.'
)
'correctly, you can ignore this message.')
return None
if preprocessor in self._task_preprocessing_registry:
if verbose:
print(
f'Re-formatting dataset with "{preprocessor}" preprocessing function.'
)
log.info(
f'Re-formatting dataset with "{preprocessor}" preprocessing function.'
)
return self._task_preprocessing_registry[preprocessor]

try:
import_path, function_name = preprocessor.split(':', maxsplit=1)
if verbose:
print(
f'Importing preprocessing function via: `from {import_path} import {function_name}`'
)
module = importlib.import_module(import_path)
preprocessing_fn = getattr(module, function_name)
except Exception as e:
@@ -289,7 +281,7 @@ def build_from_hf(
proto_preprocessing_fn)
else:
preprocessing_fn = self.get_preprocessing_fn_from_str(
proto_preprocessing_fn, dataset_name, verbose=True)
proto_preprocessing_fn, dataset_name)

dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)

5 changes: 4 additions & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
@@ -3,6 +3,7 @@

"""Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""

import logging
import os
from typing import Mapping, Union

@@ -35,6 +36,8 @@

__all__ = ['ComposerHFCausalLM']

log = logging.getLogger(__name__)


class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
"""Configures a :class:`.HuggingFaceModel` around a Causal LM.
@@ -185,7 +188,7 @@ def __init__(self, om_model_config: Union[DictConfig,
f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
)

print(
log.debug(
f'Patching llama attention with {attention_patch_type} attention'
)
from transformers.models.llama.modeling_llama import \
18 changes: 0 additions & 18 deletions llmfoundry/models/layers/attention.py
@@ -418,7 +418,6 @@ def __init__(
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
verbose: int = 0,
device: Optional[str] = None,
):
super().__init__()
@@ -476,21 +475,8 @@ def __init__(
self.attn_fn = flash_attn_fn
elif self.attn_impl == 'triton':
self.attn_fn = triton_flash_attn_fn
if verbose:
warnings.warn(
'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\
'it uses more memory. When training larger models this can trigger ' +\
'alloc retries which hurts performance. If encountered, we recommend ' +\
'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
)
elif self.attn_impl == 'torch':
self.attn_fn = scaled_multihead_dot_product_attention
if torch.cuda.is_available() and verbose:
warnings.warn(
'Using `attn_impl: torch`. If your model does not use `alibi` or ' +\
'`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\
'we recommend using `attn_impl: triton`.'
)
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')

@@ -569,7 +555,6 @@ def __init__(
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
verbose: int = 0,
device: Optional[str] = None,
):
super().__init__(
@@ -583,7 +568,6 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
verbose=verbose,
device=device)


@@ -605,7 +589,6 @@ def __init__(
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
verbose: int = 0,
device: Optional[str] = None,
):
super().__init__(
@@ -619,7 +602,6 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
verbose=verbose,
device=device)


2 changes: 0 additions & 2 deletions llmfoundry/models/layers/blocks.py
@@ -24,7 +24,6 @@ def __init__(
ffn_config: Optional[Dict] = None,
resid_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
verbose: int = 0,
fc_type: str = 'torch',
device: Optional[str] = None,
**kwargs: Any,
@@ -70,7 +69,6 @@ def __init__(
self.attn = attn_class(d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
verbose=verbose,
device=device,
**attn_config_subset_for_attn_class)
self.norm_2 = None
9 changes: 7 additions & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -56,12 +56,12 @@ def __init__(
init_device: str = 'cpu',
logit_scale: Optional[Union[float, str]] = None,
no_bias: bool = False,
verbose: int = 0,
embedding_fraction: float = 1.0,
norm_type: str = 'low_precision_layernorm',
use_cache: bool = False,
init_config: Dict = init_config_defaults,
fc_type: str = 'torch',
verbose: Optional[int] = None,
**kwargs: Any,
):
"""The MPT configuration class.
@@ -135,12 +135,17 @@ def __init__(
self.init_device = init_device
self.logit_scale = logit_scale
self.no_bias = no_bias
self.verbose = verbose
self.embedding_fraction = embedding_fraction
self.norm_type = norm_type
self.use_cache = use_cache
self.init_config = init_config
self.fc_type = fc_type
if verbose is not None:
warnings.warn(
DeprecationWarning(
'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
))

if 'name' in kwargs:
del kwargs['name']
if 'loss_fn' in kwargs:
26 changes: 10 additions & 16 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -63,6 +63,10 @@
pass
# isort: on

import logging

log = logging.getLogger(__name__)


class MPTPreTrainedModel(PreTrainedModel):
config_class = MPTConfig
@@ -118,8 +122,8 @@ def __init__(self, config: MPTConfig):
self.norm_f = norm_class(config.d_model, device=config.init_device)

if config.init_device != 'meta':
print(
f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
log.info(
f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.'
)
self.apply(self.param_init_fn)

@@ -142,19 +146,11 @@ def __init__(self, config: MPTConfig):
for module in self.modules():
if hasattr(module, 'bias') and isinstance(
module.bias, nn.Parameter):
if config.verbose:
warnings.warn(
f'Removing bias ({module.bias}) from {module}.')
log.info(f'Removing bias ({module.bias}) from {module}.')
module.register_parameter('bias', None)

# Print verbose info
if config.verbose and config.verbose > 2:
print(self)
if 'verbose' not in self.config.init_config:
self.config.init_config['verbose'] = self.config.verbose
if self.config.init_config['verbose'] > 1:
init_fn_name = self.config.init_config['name']
warnings.warn(f'Using {init_fn_name} initialization.')
log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

def get_input_embeddings(self):
return self.wte
@@ -486,7 +482,7 @@ def __init__(self, config: MPTConfig):
raise ValueError(
'MPTForCausalLM only supports tied word embeddings')

print(f'Instantiating an MPTForCausalLM model from {__file__}')
log.info(f'Instantiating an MPTForCausalLM model from {__file__}')

self.transformer: MPTModel = MPTModel(config)

@@ -717,8 +713,6 @@ def __init__(
from flash_attn.losses.cross_entropy import \
CrossEntropyLoss as FusedCrossEntropyLoss

if hf_config.verbose > 1:
warnings.warn('Using Fused Cross Entropy Loss.')
self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
except:
raise ValueError(