Skip to content

Commit

Permalink
Adding pyright to pre-commit (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcui19 committed Aug 2, 2023
1 parent 05c6055 commit 9250e84
Show file tree
Hide file tree
Showing 63 changed files with 840 additions and 647 deletions.
1 change: 0 additions & 1 deletion .github/workflows/code-quality.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ jobs:
strategy:
matrix:
python_version:
- '3.8'
- '3.9'
- '3.10'
pip_deps:
Expand Down
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ repos:
entry: yamllint
language: python
types: [file, yaml]
- repo: local
hooks:
- id: pyright
name: pyright
entry: pyright
language: node
types: [python]
pass_filenames: false
args: [--warnings]
additional_dependencies: ["[email protected]"]
- repo: https://github.com/trufflesecurity/trufflehog.git
rev: v3.40.0
hooks:
Expand Down
6 changes: 4 additions & 2 deletions llmfoundry/callbacks/fdiff_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@


class FDiffMetrics(Callback):
"""Rate of chage of metrics.
"""Rate of change of metrics.
tracks and plots the rate of change of metrics effectively taking the
numerical derivative of the metrics
"""

def __init__(self, diff_train_metrics=False, diff_eval_metrics=True):
def __init__(self,
diff_train_metrics: bool = False,
diff_eval_metrics: bool = True):
self.diff_train_metrics = diff_train_metrics
self.diff_eval_metrics = diff_eval_metrics

Expand Down
13 changes: 10 additions & 3 deletions llmfoundry/callbacks/generate_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

"""Periodically log generations to wandb from a set of prompts."""
from typing import List, Union, cast
from typing import Any, List, Union, cast

import torch
import wandb
Expand All @@ -16,7 +16,8 @@

class Generate(Callback):

def __init__(self, prompts: List[str], batch_log_interval: int, **kwargs):
def __init__(self, prompts: List[str], batch_log_interval: int,
**kwargs: Any):
"""Periodically log generations to wandb from a set of prompts.
In the main view for a run, there will be a table that will show the _last_ logged generations.
Expand Down Expand Up @@ -57,6 +58,11 @@ def generate(self, state: State, logger: Logger):
tokenizer = cast(Tokenizer, state.model.tokenizer)
device = state.device

if not hasattr(model.model, 'generate'):
raise ValueError(
f'Cannot generate from model {model.model.__class__.__name__} because it does not have a `generate` method'
)

# stash the original original value of padding_side because generation requires left padding
original_padding_side = tokenizer.padding_side
tokenizer.padding_side = 'left'
Expand All @@ -74,9 +80,10 @@ def generate(self, state: State, logger: Logger):
dummy_input = device.tensor_to_device(dummy_input)
with get_precision_context(state.precision):
with torch.no_grad():
assert isinstance(model.model, torch.nn.Module)
_ = model.model(input_ids=dummy_input)

output_token_ids = model.model.generate(
output_token_ids = model.model.generate( # type: ignore
input_ids=tokenized_input['input_ids'],
attention_mask=tokenized_input['attention_mask'],
synced_gpus=True,
Expand Down
28 changes: 20 additions & 8 deletions llmfoundry/callbacks/model_gauntlet_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class ModelGauntlet(Callback):
weighting (Weighting): The weighting scheme used to balance different tasks within each category.
Either assign them all equal weight, assign them weight proportional
to the dataset size, or assign them weight proportional to the log2 of the dataset size.
substract_random_baseline (bool): Flag determining whether to subtract random baseline accuracy
Options are 'EQUAL', 'SAMPLE_SZ', and 'LOG_SAMPLE_SZ'.
subtract_random_baseline (bool): Flag determining whether to subtract random baseline accuracy
from the performance on each individual benchmark before aggregating.
rescale_accuracy (bool): Flag determining whether to rescale the accuracy on each benchmark
by (1-random_baseline_accuracy) before aggregating. Using this ensures that all benchmarks max out at 1.0.
Expand All @@ -45,7 +46,7 @@ class ModelGauntlet(Callback):
def __init__(self,
logger_keys: dict,
categories: dict,
weighting: Weighting = Weighting.EQUAL,
weighting: str = 'EQUAL',
subtract_random_baseline: bool = True,
rescale_accuracy: bool = True,
benchmark_sizes: Optional[dict] = None):
Expand All @@ -69,27 +70,38 @@ def __init__(self,

for benchmark in category['benchmarks']:
bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
cumulative_samples = max(
sum(count for name, count in benchmark_sizes.items()
if name.startswith(bench_name)), 1)

if self.weighting != Weighting.EQUAL:
assert benchmark_sizes is not None
cumulative_samples = max(
sum(count for name, count in benchmark_sizes.items()
if name.startswith(bench_name)), 1)
else:
cumulative_samples = -1 # pyright

weight = None
if self.weighting == Weighting.EQUAL:
weight = 1
elif self.weighting == Weighting.SAMPLE_SZ:
weight = cumulative_samples
elif self.weighting == Weighting.LOG_SAMPLE_SZ:
weight = max(math.log(cumulative_samples, 2), 1)

assert weight is not None
benchmark['weighting'] = weight

def compute_averages(self, logger_data):
def compute_averages(self, logger_data: Logger):

results = {}
pat = re.compile(
'metrics/(.*?)/(\d+)-shot(/.*?)?/InContextLearning(.*)')
'metrics/(.*?)/(\d+)-shot(/.*?)?/InContextLearning(.*)' # type: ignore
)
for key in self.logger_keys:
match = pat.match(key)
val = logger_data.data[key][0][1].item()

# TODO(bmosaicml) This needs to be factored for this callback to work as a normal callback
# and therefore for the typing to be fixed
val = logger_data.data[key][0][1].item() # type: ignore

if match:
eval_name = match.group(1)
Expand Down
4 changes: 3 additions & 1 deletion llmfoundry/callbacks/monolithic_ckpt_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
) if self.upload_to_object_store else contextlib.nullcontext(
enter_result=save_dir)
with dir_context_mgr as temp_save_dir:
save_path = str(Path(temp_save_dir) / Path(filename))
save_path = str(
Path(temp_save_dir) / # type: ignore
Path(filename))
dirname = os.path.dirname(save_path)
if dirname:
os.makedirs(dirname, exist_ok=True)
Expand Down
1 change: 1 addition & 0 deletions llmfoundry/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
'eos_text' if eos_text_provided else 'bos_text')
warnings.warn(
f'The provided tokenizer adds special tokens, but you also specified {message}. This may result '
+
'in duplicated special tokens. Please be sure this is what you intend.'
)

Expand Down
26 changes: 14 additions & 12 deletions llmfoundry/data/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.text_data import StreamingTextDataset
Expand All @@ -26,16 +26,15 @@
# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

# Required signature of any `prefix_function` (see below)
PREFIX_FUNCTION = Callable[[float, Optional[float], Tokenizer], Sequence[int]]
PREFIX_FUNCTION = Callable[[float, Optional[float], PreTrainedTokenizerBase],
Sequence[int]]


def ul2_prefix_function(
mask_ratio: float,
mean_length: Optional[float],
tokenizer: Tokenizer,
tokenizer: PreTrainedTokenizerBase,
) -> Sequence[int]:
"""Generates prefixes based on UL2 paper.
Expand Down Expand Up @@ -132,7 +131,7 @@ class MixtureOfDenoisersCollator:

def __init__(
self,
tokenizer: Tokenizer,
tokenizer: PreTrainedTokenizerBase,
max_seq_length: int,
decoder_only_format: bool = False,
span_mean_lengths_and_ratios: Optional[List] = None,
Expand Down Expand Up @@ -352,7 +351,7 @@ def __call__(self, examples: List[Dict[str,

def build_text_denoising_dataloader(
cfg: DictConfig,
tokenizer: Tokenizer,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
) -> DataLoader[Dict]:
"""Constructor function for a Mixture of Denoisers dataloader.
Expand Down Expand Up @@ -527,7 +526,7 @@ def noise_token_sequence(
prefix_tokens: Optional[Sequence[int]],
max_raw_length: int,
max_seq_length: int,
tokenizer: Tokenizer,
tokenizer: PreTrainedTokenizerBase,
sentinel_token_ids: np.ndarray,
decoder_only_format: bool,
context_eos: bool,
Expand Down Expand Up @@ -678,7 +677,8 @@ def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray:
"""
span_markers = np.less(np.arange(total_tokens - 1), num_spans -
1)[np.random.permutation(total_tokens - 1)]
span_start_indicator = np.concatenate([[0], span_markers])
span_start_indicator = np.concatenate([[0],
span_markers]) # type: ignore
span_id = np.cumsum(span_start_indicator).reshape(-1, 1)
spans = np.arange(num_spans).reshape(1, -1)
span_lengths = np.sum(span_id == spans, axis=0)
Expand Down Expand Up @@ -715,12 +715,13 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],

# Ensure there's an end-of-sentence token at the end
if ensure_eos and (noised_tokens[-1] != eos_token_id):
noised_tokens = np.concatenate([noised_tokens, [eos_token_id]])
noised_tokens = np.concatenate([noised_tokens,
[eos_token_id]]) # type: ignore

return noised_tokens

# Masking at previous token
prev_token_mask = np.concatenate([[0], mask[:-1]])
prev_token_mask = np.concatenate([[0], mask[:-1]]) # type: ignore

# Decompose mask into start-of-span mask and non-start-of-span mask
start_of_noise_span_token = np.logical_and(mask,
Expand All @@ -739,7 +740,8 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],

# Ensure there's an end-of-sentence token at the end
if ensure_eos and (noised_tokens[-1] != eos_token_id):
noised_tokens = np.concatenate([noised_tokens, [eos_token_id]])
noised_tokens = np.concatenate([noised_tokens,
[eos_token_id]]) # type: ignore
return noised_tokens


Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def _process_and_batch_encoder_decoder(
return batch


def ensure_list(x: Union[List, torch.Tensor]):
def ensure_list(x: Union[List, torch.Tensor]) -> List:
if isinstance(x, torch.Tensor):
x = list(x.flatten())
assert isinstance(x, list)
Expand Down
16 changes: 9 additions & 7 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

import logging
import os
from typing import Union

import torch
from composer.utils import dist, get_file, parse_uri
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
Expand All @@ -20,10 +19,9 @@
# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer,
def build_finetuning_dataloader(cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataLoader:
"""Builds a finetuning dataloader for training or evaluating.
Expand Down Expand Up @@ -115,6 +113,7 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer,
if tokenizer.pad_token is None: # type: ignore
tokenizer.pad_token = tokenizer.eos_token

dataset = None # for pyright
if cfg.dataset.get('remote') is not None:
dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
Expand Down Expand Up @@ -166,6 +165,7 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer,
collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)

assert dataset is not None
return DataLoader(
dataset,
collate_fn=collate_fn,
Expand Down Expand Up @@ -235,7 +235,8 @@ def _validate_config(dataset_cfg: DictConfig):
)


def _build_hf_dataset_from_remote(cfg: DictConfig, tokenizer: Tokenizer):
def _build_hf_dataset_from_remote(cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase):
"""Builds a dataset from a remote object store.
This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
Expand Down Expand Up @@ -313,7 +314,8 @@ def _build_hf_dataset_from_remote(cfg: DictConfig, tokenizer: Tokenizer):
return dataset


def _build_collate_fn(dataset_cfg: DictConfig, tokenizer: Tokenizer,
def _build_collate_fn(dataset_cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int):
collate_fn = Seq2SeqFinetuningCollator(
tokenizer=tokenizer,
Expand Down
13 changes: 6 additions & 7 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,13 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import datasets as hf_datasets
from omegaconf import DictConfig
from streaming import StreamingDataset
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PreTrainedTokenizerBase

__all__ = ['dataset_constructor']

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def _tokenize_formatted_example(example: Dict[str, Any], tokenizer: Tokenizer):
def _tokenize_formatted_example(example: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase):
if ('prompt' not in example) or ('response' not in example):
raise KeyError(
'Unable to tokenize example because it has not been properly formatted. ' +\
Expand Down Expand Up @@ -86,7 +85,7 @@ class StreamingFinetuningDataset(StreamingDataset):

def __init__(self,
local: str,
tokenizer: Tokenizer,
tokenizer: PreTrainedTokenizerBase,
remote: Optional[str] = None,
split: Optional[str] = None,
shuffle: bool = False,
Expand Down Expand Up @@ -162,7 +161,7 @@ def print_registered_tasks(self):
tasks = sorted(self._task_preprocessing_registry.keys())
print('\n'.join(tasks))

def get_preprocessing_fn_from_dict(self, mapping: dict):
def get_preprocessing_fn_from_dict(self, mapping: Union[Dict, DictConfig]):
"""Get a preprocessing function from a dictionary.
The dictionary maps column names in the dataset to "prompt" and "response".
Expand Down Expand Up @@ -256,7 +255,7 @@ def get_preprocessing_fn_from_str(self,
return preprocessing_fn

def build_from_hf(self, cfg: DictConfig, max_seq_len: int,
tokenizer: Tokenizer):
tokenizer: PreTrainedTokenizerBase):
"""Load a HuggingFace Datasets, preprocess, and tokenize.
Note: This function will drop examples where the prompt is longer than the max_seq_len
Expand Down
5 changes: 4 additions & 1 deletion llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import numpy as np
import torch
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase


class BinPackWrapper:
Expand Down Expand Up @@ -312,7 +314,8 @@ def parse_args() -> Namespace:
raise ValueError('`num_packing_ratios` must be a positive integer.')
return args

def build_dataloader(cfg, tokenizer, device_batch_size):
def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int):
if cfg.name == 'text':
return build_text_dataloader(cfg, tokenizer, device_batch_size)
elif cfg.name == 'text_denoising':
Expand Down
Loading

0 comments on commit 9250e84

Please sign in to comment.