diff --git a/.gitignore b/.gitignore index d041a25c22..1dd80a8b6c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ my-copy-c4*/ my-copy-arxiv*/ *.jsonl* +!tests/eval/local_data/*.jsonl # WandB wandb/ diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 922f738e9a..012147ec20 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -19,49 +19,39 @@ hf_dynamic_modules_logger.addFilter(new_files_warning_filter) -from llmfoundry import algorithms, callbacks, loggers, optim, registry, utils -from llmfoundry.data import (ConcatTokensDataset, NoConcatDataset, - Seq2SeqFinetuningCollator, - build_finetuning_dataloader) -from llmfoundry.models.hf import ComposerHFCausalLM, ComposerHFT5 -from llmfoundry.models.layers.attention import ( - MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, - flash_attn_fn, scaled_multihead_dot_product_attention) -from llmfoundry.models.layers.blocks import MPTBlock -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn +from llmfoundry import (algorithms, callbacks, cli, data, eval, interfaces, + loggers, metrics, models, optim, tokenizers, utils) +from llmfoundry.data import StreamingFinetuningDataset, StreamingTextDataset +from llmfoundry.eval import InContextLearningDataset, InContextLearningMetric +from llmfoundry.models.hf import ComposerHFCausalLM from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel) -from llmfoundry.tokenizers import TiktokenTokenizerWrapper +from llmfoundry.optim import DecoupledLionW __all__ = [ - 'build_finetuning_dataloader', - 'Seq2SeqFinetuningCollator', - 'MPTBlock', - 'FFN_CLASS_REGISTRY', - 'MPTMLP', - 'build_ffn', + 'StreamingFinetuningDataset', + 'StreamingTextDataset', + 'InContextLearningDataset', + 'InContextLearningMetric', + 'ComposerHFCausalLM', 'MPTConfig', 'MPTPreTrainedModel', 'MPTModel', 'MPTForCausalLM', 'ComposerMPTCausalLM', - 'ComposerHFCausalLM', - 'ComposerHFT5', - 'scaled_multihead_dot_product_attention', - 'flash_attn_fn', - 'MultiheadAttention', - 'NoConcatDataset', - 'ConcatTokensDataset', - 'attn_bias_shape', - 'build_attn_bias', - 'build_alibi_bias', - 'optim', - 'utils', - 'loggers', + 'DecoupledLionW', 'algorithms', 'callbacks', - 'TiktokenTokenizerWrapper', - 'registry', + 'cli', + 'data', + 'eval', + 'interfaces', + 'loggers', + 'metrics', + 'models', + 'optim', + 'tokenizers', + 'utils', ] -__version__ = '0.7.0' +__version__ = '0.8.0.dev0' diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index a976c08060..c261a2086b 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -27,6 +27,8 @@ log = logging.getLogger(__name__) +__all__ = ['AsyncEval'] + REQUIRED_PARAMS_FOR_EVAL = { 'device_eval_batch_size', 'icl_tasks', # only required for eval, may not be specified in pure training diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 37faa14fdd..f00d68f760 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -20,6 +20,8 @@ log = logging.getLogger(__name__) +__all__ = ['CurriculumLearning'] + @experimental_class('CurriculumLearning callback') class CurriculumLearning(CallbackWithConfig): diff --git a/llmfoundry/callbacks/fdiff_callback.py b/llmfoundry/callbacks/fdiff_callback.py index 1237f32e22..2afcc94452 100644 --- a/llmfoundry/callbacks/fdiff_callback.py +++ b/llmfoundry/callbacks/fdiff_callback.py @@ -8,6 +8,8 @@ from composer.core import Callback, State from composer.loggers import Logger +__all__ = ['FDiffMetrics'] + class FDiffMetrics(Callback): """Rate of change of metrics. diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index baa72a7f66..62d54d1e6a 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -12,8 +12,9 @@ import time from multiprocessing.context import SpawnProcess from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +import numpy as np import torch import torch.nn as nn from composer.core import Callback, Event, State, Time, TimeUnit @@ -35,6 +36,8 @@ log = logging.getLogger(__name__) +__all__ = ['HuggingFaceCheckpointer'] + _LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE) @@ -158,8 +161,6 @@ def __init__( if mlflow_logging_config is None: mlflow_logging_config = {} if self.mlflow_registered_model_name is not None: - import numpy as np - # Both the metadata and the task are needed in order for mlflow # and databricks optimized model serving to work passed_metadata = mlflow_logging_config.get('metadata', {}) @@ -169,18 +170,17 @@ def __init__( default_input_example = { 'prompt': np.array(['What is Machine Learning?']) } - is_chat = mlflow_logging_config['task'].endswith( - 'chat') or mlflow_logging_config['metadata'].get( - 'task', '').endswith('chat') + is_chat = mlflow_logging_config['task'].endswith('chat') or ( + mlflow_logging_config['metadata'] is not None and + mlflow_logging_config['metadata'].get('task', + '').endswith('chat')) if is_chat: default_input_example = { - 'messages': - np.array([{ - 'role': 'user', - 'content': 'What is Machine Learning?' - }]) + 'messages': [{ + 'role': 'user', + 'content': 'What is Machine Learning?' + }] } - mlflow_logging_config.setdefault('example_no_conversion', True) mlflow_logging_config.setdefault('input_example', default_input_example) @@ -258,6 +258,16 @@ def _is_last_batch(self, state: State): return True assert state.max_duration is not None # for pyright + + epoch_complete = state.dataloader_len == state.timestamp.batch_in_epoch + second_to_last_epoch = state.max_duration.unit == TimeUnit.EPOCH and ( + state.timestamp.epoch == state.max_duration.value - 1) + # If the save interval is specified as exactly the same number of batches as the total duration, + # but the max duration is specified in epochs, we need a special case to identify we are on the last batch + # and should write the mlflow checkpoint. This should occur on the last batch of the final epoch. + if self.save_interval.unit == TimeUnit.BATCH and second_to_last_epoch and epoch_complete: + return True + # If the save interval is specified as 1dur, and the max duration is in epoch units # we need a special case to identify we are on the last batch and should write the mlflow checkpoint if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH: @@ -273,6 +283,23 @@ def _all_child_processes_done(self) -> bool: dist.all_reduce(x, reduce_operation='MAX') return x.item() == 0 + def transform_model_and_tokenizer( + self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase + ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: + """Transform the model and tokenizer before saving. + + This allows a subclass to modify the model and tokenizer before saving. The base class implementation will + make no modifications. + + Args: + model (PreTrainedModel): The model to be transformed. + tokenizer (PreTrainedTokenizerBase): The tokenizer to be transformed. + + Returns: + Tuple[PreTrainedModel, PreTrainedTokenizerBase]: The transformed model and tokenizer. + """ + return model, tokenizer + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -405,6 +432,10 @@ def dtensor_to_tensor_hook( new_model_instance.load_state_dict(state_dict, assign=True) del state_dict + # Transform the model and tokenizer before saving + new_model_instance, original_tokenizer = self.transform_model_and_tokenizer( + new_model_instance, original_tokenizer) + log.debug('Saving Hugging Face checkpoint to disk') new_model_instance.save_pretrained(temp_save_dir) if original_tokenizer is not None: diff --git a/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py b/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py index fc906e0d87..89ee37cf0c 100644 --- a/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py +++ b/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py @@ -9,6 +9,8 @@ from composer.loggers import Logger from composer.utils import dist +__all__ = ['MegaBlocksMoE_TokPerExpert'] + class MegaBlocksMoE_TokPerExpert(Callback): """Log tokens per expert for MegaBlocks MoE. @@ -44,7 +46,7 @@ class MegaBlocksMoE_TokPerExpert(Callback): Args: log_interval (int, optional): The interval on which to log (Default: 10). - log_every_layer (bool, optional): Enable logging ever layer's statisictics (True) or log + log_every_layer (bool, optional): Enable logging ever layer's statistics (True) or log only aggregate statistics (Default: False). all_reduce_stats (bool, optional): Enable aggregating statistics across gpus (True) or log statistics for GPU 0 (Default: False). diff --git a/llmfoundry/callbacks/monolithic_ckpt_callback.py b/llmfoundry/callbacks/monolithic_ckpt_callback.py index aaa68763f5..395a13111c 100644 --- a/llmfoundry/callbacks/monolithic_ckpt_callback.py +++ b/llmfoundry/callbacks/monolithic_ckpt_callback.py @@ -15,6 +15,8 @@ from composer.utils import (dist, format_name_with_dist_and_time, parse_uri, reproducibility) +__all__ = ['MonolithicCheckpointSaver'] + class MonolithicCheckpointSaver(Callback): """Save a monolithic checkpoint every N batches. diff --git a/llmfoundry/callbacks/resumption_callbacks.py b/llmfoundry/callbacks/resumption_callbacks.py index 751accc922..f910114a88 100644 --- a/llmfoundry/callbacks/resumption_callbacks.py +++ b/llmfoundry/callbacks/resumption_callbacks.py @@ -7,6 +7,8 @@ from composer.core import Callback, State from composer.loggers import Logger +from llmfoundry.utils.warnings import experimental_class + __all__ = [ 'GlobalLRScaling', 'LayerFreezing', @@ -15,6 +17,7 @@ log = logging.getLogger(__name__) +@experimental_class('GlobalLRScaling') class GlobalLRScaling(Callback): """GlobalLRScaling. @@ -52,6 +55,7 @@ def fit_start(self, state: State, logger: Logger) -> None: ] +@experimental_class('LayerFreezing') class LayerFreezing(Callback): """LayerFreezing. diff --git a/llmfoundry/callbacks/scheduled_gc_callback.py b/llmfoundry/callbacks/scheduled_gc_callback.py index 216fa2adb4..b210c99b16 100644 --- a/llmfoundry/callbacks/scheduled_gc_callback.py +++ b/llmfoundry/callbacks/scheduled_gc_callback.py @@ -8,6 +8,8 @@ from composer.core import Callback, State from composer.loggers import Logger +__all__ = ['ScheduledGarbageCollector'] + def gc_cuda(): """Garbage collect Torch (CUDA) memory.""" diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py index 45d1f8237f..027ea7b07a 100644 --- a/llmfoundry/data/__init__.py +++ b/llmfoundry/data/__init__.py @@ -4,9 +4,14 @@ from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset from llmfoundry.data.dataloader import build_dataloader from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator, + StreamingFinetuningDataset, build_finetuning_dataloader) -from llmfoundry.data.text_data import (StreamingTextDataset, - build_text_dataloader) +from llmfoundry.data.packing import (BinPackCollator, auto_packing_ratio, + profile_packing) +from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, + StreamingTextDataset, + build_text_dataloader, + get_tokens_per_batch_func) from llmfoundry.registry import dataloaders dataloaders.register('text', func=build_text_dataloader) @@ -15,9 +20,15 @@ __all__ = [ 'Seq2SeqFinetuningCollator', 'build_finetuning_dataloader', + 'StreamingFinetuningDataset', 'StreamingTextDataset', 'build_text_dataloader', 'NoConcatDataset', 'ConcatTokensDataset', 'build_dataloader', + 'BinPackCollator', + 'auto_packing_ratio', + 'profile_packing', + 'ConcatenatedSequenceCollatorWrapper', + 'get_tokens_per_batch_func', ] diff --git a/llmfoundry/data/data.py b/llmfoundry/data/data.py index 92e4538d73..c7b018c5fb 100644 --- a/llmfoundry/data/data.py +++ b/llmfoundry/data/data.py @@ -11,6 +11,11 @@ from torch.utils.data import IterableDataset from transformers import PreTrainedTokenizerBase +__all__ = [ + 'ConcatTokensDataset', + 'NoConcatDataset', +] + class NoConcatDataset(IterableDataset): """An IterableDataset that returns text samples for MDSWriter. diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py index a98526001a..61471420f8 100644 --- a/llmfoundry/data/dataloader.py +++ b/llmfoundry/data/dataloader.py @@ -10,6 +10,10 @@ from llmfoundry import registry from llmfoundry.utils.registry_utils import construct_from_registry +__all__ = [ + 'build_dataloader', +] + def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> DataSpec: diff --git a/llmfoundry/data/finetuning/__init__.py b/llmfoundry/data/finetuning/__init__.py index 9d10a17cfa..3b5c277199 100644 --- a/llmfoundry/data/finetuning/__init__.py +++ b/llmfoundry/data/finetuning/__init__.py @@ -3,5 +3,16 @@ from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader +from llmfoundry.data.finetuning.tasks import (StreamingFinetuningDataset, + dataset_constructor, + is_valid_ift_example, + tokenize_formatted_example) -__all__ = ['Seq2SeqFinetuningCollator', 'build_finetuning_dataloader'] +__all__ = [ + 'Seq2SeqFinetuningCollator', + 'build_finetuning_dataloader', + 'dataset_constructor', + 'tokenize_formatted_example', + 'is_valid_ift_example', + 'StreamingFinetuningDataset', +] diff --git a/llmfoundry/data/finetuning/collator.py b/llmfoundry/data/finetuning/collator.py index 6e3babd657..7d592483f1 100644 --- a/llmfoundry/data/finetuning/collator.py +++ b/llmfoundry/data/finetuning/collator.py @@ -10,6 +10,10 @@ log = logging.getLogger(__name__) +__all__ = [ + 'Seq2SeqFinetuningCollator', +] + # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 1d8711d280..e72ee29719 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -23,6 +23,10 @@ log = logging.getLogger(__name__) +__all__ = [ + 'build_finetuning_dataloader', +] + # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 4906cea151..e6a3afb188 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -42,6 +42,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: Tuple, Union, cast) import datasets as hf_datasets +import datasets.exceptions as hf_exceptions import huggingface_hub as hf_hub import numpy as np from composer.utils import dist @@ -61,6 +62,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: InvalidPromptTypeError, InvalidResponseTypeError, InvalidRoleError, + MisconfiguredHfDatasetError, NotEnoughChatDataError, TooManyKeysInExampleError, UnableToProcessPromptResponseError, @@ -70,7 +72,12 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: log = logging.getLogger(__name__) -__all__ = ['dataset_constructor'] +__all__ = [ + 'dataset_constructor', + 'tokenize_formatted_example', + 'is_valid_ift_example', + 'StreamingFinetuningDataset', +] _ALLOWED_RESPONSE_KEYS = {'response', 'completion'} _ALLOWED_PROMPT_KEYS = {'prompt'} @@ -82,7 +89,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath( os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir, '.downloaded_finetuning')) -SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] +SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet'] PromptResponseDict = Mapping[str, str] ChatFormattedDict = Mapping[str, List[Dict[str, str]]] @@ -838,6 +845,10 @@ def dataset_mapper(example: Dict): if dist.get_local_rank() == 0: os.remove(signal_file_path) + if isinstance(error, hf_exceptions.DatasetGenerationError): + log.error('Huggingface DatasetGenerationError during data prep.') + raise MisconfiguredHfDatasetError(dataset_name=dataset_name, + split=split) if error is not None: log.error('Error during data prep') raise error diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 9696f967ca..3d525def47 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -13,6 +13,12 @@ log = logging.getLogger(__name__) +__all__ = [ + 'BinPackCollator', + 'auto_packing_ratio', + 'profile_packing', +] + class BinPackCollator: """Utility collator for packing to reduce padding.""" diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index fc31b890b0..a59098323b 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -22,6 +22,13 @@ log = logging.getLogger(__name__) +__all__ = [ + 'StreamingTextDataset', + 'build_text_dataloader', + 'ConcatenatedSequenceCollatorWrapper', + 'get_tokens_per_batch_func', +] + class StreamingTextDataset(StreamingDataset): """Generic text dataset using MosaicML's StreamingDataset. diff --git a/llmfoundry/eval/__init__.py b/llmfoundry/eval/__init__.py new file mode 100644 index 0000000000..f2425296b6 --- /dev/null +++ b/llmfoundry/eval/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from llmfoundry.eval.datasets.in_context_learning_evaluation import ( + InContextLearningCodeEvalDataset, InContextLearningDataset, + InContextLearningGenerationTaskWithAnswersDataset, + InContextLearningLMTaskDataset, InContextLearningMultipleChoiceTaskDataset, + InContextLearningSchemaTaskDataset, get_icl_task_dataloader) +from llmfoundry.eval.metrics.nlp import ( + InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningLMExpectedCalibrationError, + InContextLearningMCExpectedCalibrationError, InContextLearningMetric, + InContextLearningMultipleChoiceAccuracy) + +__all__ = [ + 'InContextLearningDataset', + 'InContextLearningLMTaskDataset', + 'InContextLearningMultipleChoiceTaskDataset', + 'InContextLearningSchemaTaskDataset', + 'InContextLearningCodeEvalDataset', + 'InContextLearningGenerationTaskWithAnswersDataset', + 'get_icl_task_dataloader', + 'InContextLearningMetric', + 'InContextLearningLMAccuracy', + 'InContextLearningMultipleChoiceAccuracy', + 'InContextLearningGenerationExactMatchAccuracy', + 'InContextLearningCodeEvalAccuracy', + 'InContextLearningLMExpectedCalibrationError', + 'InContextLearningMCExpectedCalibrationError', +] diff --git a/llmfoundry/eval/datasets/__init__.py b/llmfoundry/eval/datasets/__init__.py new file mode 100644 index 0000000000..0be9882b0c --- /dev/null +++ b/llmfoundry/eval/datasets/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Natively supported in-context learning evaluation datasets.""" + +from llmfoundry.eval.datasets.in_context_learning_evaluation import ( + InContextLearningCodeEvalDataset, InContextLearningDataset, + InContextLearningGenerationTaskWithAnswersDataset, + InContextLearningLMTaskDataset, InContextLearningMultipleChoiceTaskDataset, + InContextLearningSchemaTaskDataset, get_icl_task_dataloader) + +# isort: off +from llmfoundry.eval.datasets.utils import ( + MultiTokenEOSCriteria, convert_tokens_to_tensors, get_continuation_span, + get_fewshot_sample_idxs, make_padded_input, stop_sequences_criteria, + strip_data, tokenizer_needs_prefix_space, trim_context) +# isort: on + +__all__ = [ + 'InContextLearningDataset', + 'InContextLearningGenerationTaskWithAnswersDataset', + 'InContextLearningLMTaskDataset', + 'InContextLearningCodeEvalDataset', + 'InContextLearningMultipleChoiceTaskDataset', + 'InContextLearningSchemaTaskDataset', + 'get_icl_task_dataloader', + 'MultiTokenEOSCriteria', + 'strip_data', + 'tokenizer_needs_prefix_space', + 'trim_context', + 'get_continuation_span', + 'make_padded_input', + 'convert_tokens_to_tensors', + 'get_fewshot_sample_idxs', + 'stop_sequences_criteria', +] diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py new file mode 100644 index 0000000000..447855f953 --- /dev/null +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -0,0 +1,1792 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import copy +import json +import logging +import os +import random +import warnings +from typing import Any, Dict, Iterable, List, Optional, Sequence, Union + +import torch +import transformers +from composer.core import DataSpec +from composer.core.data_spec import _default_split_batch, _split_list +from composer.datasets.utils import stop_sequences_criteria +from composer.utils import MissingConditionalImportError, dist, get_file +from datasets import Dataset as HFDataset +from datasets import IterableDataset, load_dataset +from torch.utils.data import DataLoader, Dataset + +from llmfoundry.eval.datasets.utils import (convert_tokens_to_tensors, + get_continuation_span, + get_fewshot_sample_idxs, + make_padded_input, strip_data, + tokenizer_needs_prefix_space, + trim_context) +from llmfoundry.utils.warnings import VersionedDeprecationWarning + +log = logging.getLogger(__name__) + +# Allow models to have slightly more tokens than were used in the most verbose CoT in the dataset +_MAX_ANSWER_BUFFER_LENGTH = 10 + +__all__ = [ + 'InContextLearningDataset', + 'InContextLearningLMTaskDataset', + 'InContextLearningMultipleChoiceTaskDataset', + 'InContextLearningSchemaTaskDataset', + 'InContextLearningCodeEvalDataset', + 'InContextLearningGenerationTaskWithAnswersDataset', + 'get_icl_task_dataloader', +] + + +class InContextLearningDataset(Dataset): + r"""A base dataset that constructs batches for in-context learning task. + + evaluations. The dataset format is expected to be a local jsonl file, a + cloud link to a jsonl file, or a Hugging Face dataset link. 'context' refers + to the input a model will receive before generating an output. For example, + the question in question answering tasks, the preceding text in a language + modeling task, or the document and question regarding the document in a + document understanding task. 'example' refers to a loaded dictionary, + generally containing a context, an answer, and any other information needed + to run the task. 'answer' refers to the desired output of the model. + + When creating a new ICL Dataset, it is likely that you will need to reimplement the following methods: + + - construct_context(): Takes a single example dictionary and formulates the context as a string for that eval question. + - get_answer_from_example(): Takes a single example dictionary and formulates the correct, ground truth answer as a string. + - tokenize_example(): Tokenizes the example and adds any extra content from the original dictionary that needs to be passed downstream. + - read_dataset(): Loads the dataset and does basic parsing. If additional parsing must be done, this is a good place to do so (See InContextLearningGenerationTaskWithAnswersDataset.read_dataset()) + + Additionally, base_batch and batch_mapping must be defined. + + - base_batch (Dict): The base dictionary that the dataset will use to construct a batch. This should contain static values, like generation_kwargs or mode, + and empty lists for values that will need to be accumulated from each example. + NOTE: Sometimes you will need to set base_batch directly after the init call, e.g. in order to use class variables + like self.pad_tok_id or self.max_answer_length. If you manually set generation_kwargs this way, you'll need to call self.update_generation_kwargs() + after setting self.base_batch. + - batch_mapping (Dict): A mapping with keys that are keys in the batch and values that are columns in the loaded dataset. + collate_fn will use this mapping to create batches from self.dataset. + + Args: + dataset_uri (str): A local path, a remote path beginning with ``s3://`` or another backend, or a HuggingFace dataset uri prepended with ``hf://``. + Alternate backends must be supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + A local dataset must consist of rows of JSON data points with task dependent fields. + The default keys expected are "context" and "answer". + tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to map between strings and token ids. + max_seq_len (int): The maximum sequence length supported by the model. + pad_tok_id (int): The special token used for padding batches. + num_fewshot (int): The number of complete fewshot examples to prepend before each test example. These are not identical across examples. + fewshot_random_seed (int): Random seed to use for fewshot sampling. + prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'Translate english to french.'). + example_delimiter (str): Separator inserted before (context, answer) pairs (e.g. '\\n') for fewshot sampling and prompting. + continuation_delimiter: (str): Separator inserted between context and answer in each example (e.g. '\\nA: '). + destination_path (str): Temporary path to store downloaded datasets. + prelimiter (str): Text to be prepended before each context, including few shot examples (e.g. "Question: "). + context_key (str): The key in the loaded dataset that contains the context. + answer_key (str): The key in the loaded dataset that contains the answer. + strip_dataset (bool): Boolean for whether to strip whitespace from data. Trailing whitespace can cause degenerative outputs, + so unless whitespace should be preserved (for example in code), this should be set to True. + padding_side (str): Side of the content and answer on which to apply padding. Can be either 'right' or 'left'. + tokenize_labels (bool): Whether or not the labels should be tokenized. Generally determined by which metric a dataset uses. + padding_size (int): The final size of the tensor after padding. Defaults to max_sequence_length. + base_batch (Dict): The base dictionary upon which a batch is created. See above for more details. + base_mapping (Dict): A mapping of batch keys to dataset columns, used to create batches. See above for more details. + hf_loading_vars (Dict): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. + hf_parsing_map (Dict): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. + Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. + generation_kwargs (Dict): A dictionary containing keyword arguments to be passed along to the model's generate function. + static_keys (List): A list of the key values which will be broadcast across a batch (e.g. it is the same for each batch element). + list_keys (List): A list of the batch keys whose values are lists which will be split using list methods during calls to split_batch. + tensor_keys (List): A list of the batch keys whose values are tensors which will be split using tensor methods during calls to split_batch. + """ + + def __init__( + self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + fewshot_random_seed: int, + prompt_string: str, + example_delimiter: str, + continuation_delimiter: str, + destination_path: str, + prelimiter: str = '', + context_key: str = 'context', + answer_key: str = 'answer', + strip_dataset: bool = True, + padding_side: str = 'right', + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + base_batch: Optional[Dict] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + static_keys: Optional[List] = None, + list_keys: Optional[List] = None, + tensor_keys: Optional[List] = None, + ): + self.tokenizer = tokenizer + self.prefix_space = tokenizer_needs_prefix_space(self.tokenizer) + + self.max_seq_len = max_seq_len + self.pad_tok_id = pad_tok_id + self.num_fewshot = num_fewshot + self.padding_side = padding_side + self.padding_size = padding_size if padding_size else self.max_seq_len + self.prelimiter = prelimiter + self.example_delimiter = example_delimiter + self.continuation_delimiter = continuation_delimiter + self.context_key = context_key + self.answer_key = answer_key + self.tokenize_labels = tokenize_labels + self.batch_mapping = batch_mapping or {} + self.base_batch = base_batch or {} + if generation_kwargs: + self.update_generation_kwargs(generation_kwargs) + + self.static_keys = static_keys + self.list_keys = list_keys + self.tensor_keys = tensor_keys + + hf_loading_vars = hf_loading_vars or {} + self.dataset: HFDataset = self.read_dataset(dataset_uri, + destination_path, + hf_loading_vars, + hf_parsing_map) + self.strip_data = strip_dataset + if self.strip_data: + self.dataset = self.dataset.map(strip_data) + + fewshot_rng = random.Random(fewshot_random_seed) + self.dataset: HFDataset = self.dataset.map( + self._prep_example, + with_indices=True, + fn_kwargs={ + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'fewshot_rng': fewshot_rng, + }, + ) + + def __getitem__(self, index: int) -> Dict: + return self.dataset[index] + + def __len__(self) -> int: + return len(self.dataset) + + def get_num_samples_in_batch(self, batch: Dict) -> int: + return batch['input_ids'].shape[0] + + def update_generation_kwargs(self, generation_kwargs: Dict) -> None: + r"""Updates self.base_batch with the passed in generation_kwargs. + + This must be run after self.base_batch is set (for example, if + self.base_batch is set after __init__() is run, likely because + base_batch needs a class variable like self.pad_tok_id or + self.max_answer_length). + + Args: + generation_kwargs (Dict): Keyword arguments that be written into base_batch['generation_kwargs'] + """ + if generation_kwargs: + if 'generation_kwargs' not in self.base_batch: + self.base_batch['generation_kwargs'] = {} + self.base_batch['generation_kwargs'].update(generation_kwargs) + + def read_dataset( + self, + dataset_uri: str, + destination_path: str, + hf_loading_vars: Optional[Dict[str, Any]] = None, + hf_parsing_map: Optional[Dict[str, Any]] = None) -> 'HFDataset': + """Reads a dataset and handles parsing it from HuggingFace. + + Args: + dataset_uri (str): A local path, a remote path beginning with ``s3://`` or another backend, or a HuggingFace dataset uri. + Alternate backends must be supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + destination_path (str): A local path where the data will be stored + hf_loading_vars (Dict): If parsing from HuggingFace, keyword args that will be passed into load_dataset + hf_parsing_map (Dict): Dictionary in the form of {icl_key: [hf_col1, hf_col2]} that will map one or more hf columns, in order, to ICL dataset columns + + Returns: + dataset: A loaded HF dataset + """ + from datasets import \ + Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] + from datasets import \ + load_dataset # pyright: ignore[reportGeneralTypeIssues] + if 'hf://' in dataset_uri: + dataset_uri = dataset_uri.replace('hf://', '') + if hf_loading_vars is None: + hf_loading_vars = {} + dataset = load_dataset(dataset_uri, **hf_loading_vars) + if hf_parsing_map: + dataset_parsing_func = lambda example: { + k: ' '.join([str(example[col]) for col in v]) + for k, v in hf_parsing_map. + items( # pyright: ignore[reportOptionalMemberAccess] + ) + } + assert isinstance(dataset, HFDataset) + dataset = dataset.map(dataset_parsing_func, + remove_columns=dataset.column_names) + else: + with dist.local_rank_zero_download_and_wait(destination_path): + if dist.get_local_rank() == 0: + get_file(dataset_uri, destination_path, overwrite=True) + dataset = load_dataset('json', + data_files=destination_path, + split='train', + streaming=False) + assert isinstance(dataset, HFDataset) + return dataset + + def _generate_few_shot_prompt( + self, + num_fewshot: int, + example_idx: int, + preamble: str, + fewshot_rng: random.Random, + ) -> str: + """Formats the fewshot prompt for test example `example_idx`. + + Randomly selects `num_fewshot` samples from the dataset (excluding the example at `example_idx`) and constructs + contexts with answers appended. + + Returns the formatted prompt_string + concatenated list of formatted few shot examples as a string. + + Args: + num_fewshot (int): Number of examples to prepend + example_idx (int): Current example idx + preamble (str): Text to occur at the beginning of the task. Generally instructions or a prompt. + fewshot_rng (random.Random): Seeded sampler to chose samples with + + Returns: + str: The original preamble with num_fewshot examples appended + """ + few_shot_text = preamble + + if num_fewshot > 0: + fewshot_idxs = get_fewshot_sample_idxs( + len(self.dataset), + num_fewshot, + example_idx, + fewshot_rng, + ) + for fewshot_idx in fewshot_idxs: + ctxt = self.construct_context( + self.dataset[fewshot_idx], + few_shot_text, + add_answer=True, + ) + few_shot_text += ctxt + + return few_shot_text + + def construct_context(self, + example: Dict, + preceding_text: str = '', + add_answer: bool = False) -> str: + """Takes an example and constructs a context, i.e. the input the model. + + reads for this example. Optionally adds the correct answer (for fewshot + examples) and handles example delimiters. + + Args: + example (Dict): The example from which to construct the context + preceding_text (str): Any preceding text, used as a check for prepending self.example_delimiter + add_answer (bool): Bool for whether or not to add the answer on the end of the context (e.g. for fewshot examples) + + Returns: + str: The constructed context. The default output context is + formatted as follows: f'{self.prelimiter}{example[self.context_key]}{self.continuation_delimiter}' + """ + ctxt = example[self.context_key] + ctxt = f'{self.prelimiter}{ctxt}' + if len(preceding_text) > 0: + ctxt = f'{self.example_delimiter}{ctxt}' + ctxt = f'{ctxt}{self.continuation_delimiter}' + if add_answer: + ctxt = f'{ctxt}{self.get_answer_from_example(example, in_context=add_answer)}' + return ctxt + + def get_answer_from_example(self, + example: Dict[str, Any], + in_context: bool = False) -> str: + """Returns the answer from the example. + + Args: + example (Dict): The example from which to retrieve the answer + + Returns: + str: The answer in the example + """ + cont = example[self.answer_key] + if self.prefix_space and not cont.startswith(' ') and not in_context: + cont = f' {cont}' + return cont + + def _fix_eos_on_preamble(self, input_ids: List[int]) -> List[int]: + """If the input_ids is empty then input_ids will be a 0-length List. + + unless the tokenizer adds special tokens to empty strings (e.g. OPT + tokenizer). If there is an EOS token added, we need to remove it so it + is not in the middle of the prompt, as the specific eval question's + prompt will follow the input_ids. + + Args: + input_ids (List): The tokenized input + + Returns: + input_ids: The tokenized input conditionally edited + """ + if (self.tokenizer.eos_token_id is not None and len(input_ids) > 1 and + input_ids[-1] == self.tokenizer.eos_token_id): + input_ids = input_ids[:-1] + return input_ids + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, + example: Dict) -> Dict[str, Any]: + """Runs text through the tokenizer and handle special cases. + + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctxt (str): The specific example's derived context + example (Dict): The example as a dictionary. Used for additional processing in inherited classes. + + Returns: + Dict: Dictionary with the tokenized data + """ + tokenized_example = {} + # Always add special tokens to preamble + preamble = self.tokenizer(prompt_and_fewshot)['input_ids'] + assert isinstance(preamble, list) + preamble = self._fix_eos_on_preamble(preamble) + if self.strip_data: + # rstrip context because a prompt ending in a space results in degenerate output + ctxt = ctxt.rstrip() + # Never add special tokens to context + tokenized_context = self.tokenizer( + ctxt, add_special_tokens=False)['input_ids'] + assert isinstance(preamble, list) + assert isinstance(tokenized_context, list) + + tokenized_context = preamble + tokenized_context + + if self.tokenize_labels: + # Never add special tokens to answer + tokenized_answer = self.tokenizer( + self.get_answer_from_example(example), + add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_answer, list) + trimmed_context = trim_context(tokenized_context, tokenized_answer, + self.padding_size) + assert isinstance(trimmed_context, list) + continuation_indices = get_continuation_span( + trimmed_context, tokenized_answer) + padded_context = make_padded_input(trimmed_context, + tokenized_answer, + self.padding_size, + self.pad_tok_id, + self.padding_side) + + tokenized_example[self.context_key] = padded_context + tokenized_example[self.answer_key] = tokenized_answer + tokenized_example['continuation_indices'] = continuation_indices + else: + assert isinstance(tokenized_context, list) + trimmed_context = trim_context( + tokenized_context, + [], + self.padding_size, + ) + assert isinstance(trimmed_context, list) + padded_context = make_padded_input(trimmed_context, [], + self.padding_size, + self.pad_tok_id, + self.padding_side) + + tokenized_example[self.context_key] = padded_context + tokenized_example[self.answer_key] = self.get_answer_from_example( + example) + + return tokenized_example + + def _prep_example( + self, + example: Dict, + example_idx: int, + num_fewshot: int, + prompt_string: str, + fewshot_rng: random.Random, + ) -> Dict[str, Any]: + """Prepares a single example from a HF Dataset into tokenized format. + + with prompt and fewshot examples. + + Each task consists of a context and a continuation as well as an optional prompt and optional list of + example context/continuation pairs which precede the test context/continuation pair. + + Args: + example (Dict): A Dictionary from the hf dataset + example_idx (int): The index of example + num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair + prompt_string (str): The prompt to prepend to all inputs + fewshot_rng (random.Random): Random number generator to use for fewshot sampling + + Returns: + Dict: Contains a dictionary with the tokenized data + """ + prompt_and_fewshot = self._generate_few_shot_prompt( + num_fewshot, example_idx, prompt_string, fewshot_rng) + ctxt = self.construct_context(example, + prompt_and_fewshot, + add_answer=False) + tokenized_example = self.tokenize_example(prompt_and_fewshot, ctxt, + example) + return tokenized_example + + def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """The function that the dataloader uses to accumulate data into. + + batches. + + Args: + data (List): List of tokenized datapoints (dicts returned by self._tokenize_example) + + Returns: + Dict: Dictionary for a single batch + """ + batch = copy.deepcopy(self.base_batch) + for data_pair in data: + for batch_key, data_key in self.batch_mapping.items(): + batch[batch_key].append(data_pair[data_key]) + if 'continuation_indices' in data_pair: + batch['continuation_indices'].append( + data_pair['continuation_indices']) + + batch = convert_tokens_to_tensors(batch, self.tokenize_labels) + batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) + return batch + + def split_batch(self, batch: Any, + microbatch_size: Union[int, float]) -> Sequence[Any]: + """Handling for certain specialty columns that must be split into. + + batches in different formats. + + Args: + batch (Dict): Batch of data + microbatch_size (int | float): Size of microbatches + + Returns: + List: List of chunked batches + """ + # Don't split kwargs that don't change + # Normally split torch tensors + # List split lists of strings + if isinstance(microbatch_size, float): + raise ValueError( + 'split_batch does not support floating point microbatch_size.') + chunked = {} + for k, v in batch.items(): + if k in self.static_keys: + # Defer broadcasting until we know num_chunks + pass + elif k in self.list_keys: + chunked[k] = _split_list(v, microbatch_size) + elif k in self.tensor_keys: + chunked[k] = _default_split_batch(v, microbatch_size) + else: + raise ValueError(f'Unexpected key {k} in batch splitting') + num_chunks = len(chunked['input_ids']) + for k, v in batch.items(): + if k in self.static_keys: + chunked[k] = [v] * num_chunks + + batched_list = [ + {k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks) + ] + return batched_list + + +class InContextLearningGenerationTaskWithAnswersDataset(InContextLearningDataset + ): + """A dataset that constructs batches for in-context learning generation. + + tasks with answers. Generation tasks evaluate a model's ability to + generate responses and score them against a set of gold-standard answers. + + The input format is expected to be a jsonl file with the following fields: + - context: The question + - answer: The preferred answer to the question + - aliases: A list of aliases for the answer + + See InContextLearningDataset for more details. + + Additional Args: + cot_delimiter (str): Delimiter to place between the chain of thought and continuations. + early_stopping_criteria (Optional[List[str]]): Optional strings to trigger early stopping. + do_normalization (bool): Flag indicating whether to normalize generations before providing output. + """ + + def __init__(self, + cot_delimiter: str = '', + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True, + *args: Any, + **kwargs: Any): + if kwargs['tokenizer'].eos_token_id is None: + raise ValueError( + '`InContextLearningGenerationTaskWithAnswersDataset` tokenizer must have non-null `eos_token_id`' + ) + self.cot_delimiter = cot_delimiter + self.has_cot = False + self.max_answer_length = 0 + static_keys = [ + 'mode', 'cot_delimiter', 'generation_kwargs', 'do_normalization', + 'stopping_criteria' + ] + tensor_keys = ['input_ids', 'attention_mask'] + list_keys = ['labels'] + super().__init__(padding_side='left', + tokenize_labels=False, + static_keys=static_keys, + list_keys=list_keys, + tensor_keys=tensor_keys, + *args, + **kwargs) + # NOTE: set these after init call because they take class vars + self.early_stopping_criteria = early_stopping_criteria + self.base_batch = { + 'input_ids': [], + 'mode': 'generate', + 'labels': [], + 'cot_delimiter': self.cot_delimiter, + 'stopping_criteria': early_stopping_criteria, + 'do_normalization': do_normalization, + 'generation_kwargs': { + 'pad_token_id': self.pad_tok_id, + 'use_cache': True, + 'eos_token_id': self.tokenizer.eos_token_id, + 'max_new_tokens': max(self.max_answer_length, 1) + }, + } + self.batch_mapping = { + 'input_ids': self.context_key, + 'labels': 'aliases', + } + if 'generation_kwargs' in kwargs: + self.update_generation_kwargs(kwargs['generation_kwargs']) + + def read_dataset( + self, + dataset_uri: str, + destination_path: str, + hf_loading_vars: Dict, + hf_parsing_map: Dict, + ) -> 'HFDataset': + dataset = super().read_dataset(dataset_uri, destination_path, + hf_loading_vars, hf_parsing_map) + self.has_cot = 'chain_of_thought' in dataset.features + dataset = dataset.map( + lambda examples: { + 'context': + examples['context'], + 'answer': + examples['answer'], + 'aliases': + set([examples['answer']] + examples.get('aliases', [])), + 'chain_of_thought': + examples.get('chain_of_thought', ''), + }) + self.max_answer_length = self._get_max_answer_length(dataset) + # NOTE: This is the only time we use the class variable padding_size. + if self.max_seq_len < self.max_answer_length: + log.warning(f'`max_seq_len` {self.max_seq_len} was less than `max_answer_len`: {self.max_answer_length}' \ + + ' setting `max_seq_len`=`max_answer_len`') + self.max_seq_len = self.max_answer_length + self.padding_size = self.max_seq_len - self.max_answer_length + return dataset + + def get_answer_from_example(self, + example: Dict, + in_context: bool = False) -> str: + """Returns the answer from the example. Applies chain of thought if. + + self.has_cot is marked as true. + + Args: + example (Dict): The example from which to retrieve the answer + + Returns: + str: The answer in from the example with chain of thought and delimiter if needed + """ + if self.has_cot: + return f'{example["chain_of_thought"]}{self.cot_delimiter}{example[self.answer_key]}' + else: + return example[self.answer_key] + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, + example: Dict) -> Dict[str, Any]: + """Run text through the tokenizer and handle special cases. + + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctx (str): The specific example's derived context + example (Dict): The example as a dictionary. + + Returns: + Dict: Dictionary with the tokenized data + """ + tokenized_example = super().tokenize_example(prompt_and_fewshot, ctxt, + example) + tokenized_example['aliases'] = list(example.get('aliases', [])) + return tokenized_example + + def _get_max_answer_length(self, dataset: Iterable[dict]) -> int: + """Loops over the dataset and finds the longest answer length. + + Returns: + int: The maximum answer length with an additional buffer of 10 if chain of thought is present + """ + max_answer_length = 0 + for example in dataset: + all_answers = [example[self.answer_key]] + list( + example.get('aliases', [])) + for answer in all_answers: + if self.has_cot: + response = ( + f'{example["chain_of_thought"]}{self.cot_delimiter}{answer}' + ) + else: + response = answer + tokenized_response = self.tokenizer(response)['input_ids'] + assert isinstance(tokenized_response, list) + max_answer_length = max(max_answer_length, + len(tokenized_response)) + max_answer_length = max_answer_length + ( + _MAX_ANSWER_BUFFER_LENGTH if len(self.cot_delimiter) > 0 else 0) + return max_answer_length + + def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + batch = super().collate_fn(data) + batch_size = batch['input_ids'].shape[0] + stopping_criteria = None + if self.early_stopping_criteria: + if stop_sequences_criteria is None: # pyright: ignore [reportUnnecessaryComparison] + raise MissingConditionalImportError( + extra_deps_group='nlp', + conda_package='transformers', + conda_channel='conda-forge') + stopping_criteria = stop_sequences_criteria( + self.tokenizer, self.early_stopping_criteria, batch_size) + batch['generation_kwargs']['stopping_criteria'] = stopping_criteria + return batch + + +class InContextLearningLMTaskDataset(InContextLearningDataset): + """A dataset that constructs batches for in-context learning language. + + modeling evaluation. Language modeling tasks test a model's ability to + properly predict tokens based on preceding tokens. + + The input format is expected to be a jsonl file with the following fields: + - context: Preceding text + - continuation: The expected continuation + + See InContextLearningDataset for more details. + """ + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(answer_key='continuation', + static_keys=['mode'], + tensor_keys=[ + 'input_ids', 'continuation_indices', 'labels', + 'attention_mask' + ], + base_batch={ + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [] + }, + batch_mapping={ + 'input_ids': 'context', + 'labels': 'context' + }, + padding_side='right', + *args, + **kwargs) + + +class InContextLearningMultipleChoiceTaskDataset(InContextLearningDataset): + """A dataset that construct batches for in-context learning multiple choice. + + evaluation. + + If each question has N answer choices, we construct N distinct inputs per question. In order to ensure + consistency across multi-GPU, we set the batch size to be `min(N, batch_size)` so that all N + inputs per question can stored in the same batch. + + The default input format is a jsonl file with the following fields: + - query: The preceding text, question, or document relevant to the choices + - gold: Index of the correct choice under 'choices' + - choices: A list of strings, each being one of the potential choices + + Each batch then consists of ``|batch_size // N|`` distinct questions and has the following the structure. + - input_ids: Input tensor ``|batch x seqlen x # tokens|`` + - continuation_indices: List of ``|batch|`` consisting of tensors indicating which indices in the sequence correspond to the question answer (aka continuation) + - mode: Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics + - labels: Identical to the input, used by the model to calculate loss/metrics + - gold_indices: List of length ``|batch_size // N|`` indicating for each question, which of the answers is correct (via an integer [0, N-1]) + - choice_groupings: Indicates which indices of the batch correspond to which questions + + Additional Args: + choices_key (str): The key under which the choices are stored in the saved dataset. Defaults to 'choices'. + """ + + def __init__(self, + choices_key: str = 'choices', + static_keys: Optional[List] = None, + list_of_tensors_keys: Optional[List] = None, + list_of_tuples_keys: Optional[List] = None, + list_of_primitives: Optional[List] = None, + *args: Any, + **kwargs: Any): + self.choices_key = choices_key + base_batch = { + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [], + 'gold_indices': [], + 'choice_groupings': [], + } + context_key = kwargs.pop('context_key', 'query') + static_keys = kwargs.pop('static_keys', ['mode', 'generation_kwargs']) + tensor_keys = kwargs.pop('tensor_keys', + ['input_ids', 'labels', 'attention_mask']) + self.list_of_tensors_keys = list_of_tensors_keys or [ + 'continuation_indices' + ] + self.list_of_tuples_keys = list_of_tuples_keys or ['choice_groupings'] + self.list_of_primitives = list_of_primitives or ['gold_indices'] + super().__init__(context_key=context_key, + base_batch=base_batch, + static_keys=static_keys, + tensor_keys=tensor_keys, + padding_side='right', + *args, + **kwargs) + self.num_choices = len(self.dataset[0][self.choices_key]) + self.batch_mapping_per_choice = { + 'input_ids': 'context', + 'labels': 'context' + } + self.batch_map_per_example = {'gold_indices': 'gold'} + + def get_answer_from_example(self, + example: Dict, + in_context: bool = False) -> str: + """Returns the correct answer from the example's choices. + + Args: + example (Dict): The example from which to retrieve the answer + + Returns: + str: The full string of the correct answer based on the 'gold' key + """ + choices = example[self.choices_key] + gold_idx = example['gold'] + return choices[gold_idx] + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, + example: Dict) -> Dict[str, Any]: + """Runs text through the tokenizer and handle special cases. + + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctx (str): The specific example's derived context + example (Dict): The example as a dictionary. + + Returns: + Dict: Dictionary with the tokenized data + """ + # NOTE: some of this is repeated from super class but for loop makes things considerably different + tokenized_example = {} + # Always add special tokens to preamble + preamble = self.tokenizer(prompt_and_fewshot)['input_ids'] + assert isinstance(preamble, list) + preamble = self._fix_eos_on_preamble(preamble) + if self.strip_data: + # rstrip context because a prompt ending in a space results in degenerate output + ctxt = ctxt.rstrip() + # Never add special tokens to context + tokenized_context = self.tokenizer( + ctxt, add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_context, list) + tokenized_context = preamble + tokenized_context + + tokenized_example[self.context_key] = [] + tokenized_example[self.answer_key] = [] + tokenized_example['continuation_indices'] = [] + # NOTE: Treating tokenize_labels as True for all MC datasets (required for our MC accuracy metric) + for choice in example[self.choices_key]: + if self.prefix_space: + choice = f' {choice}' if not choice.startswith(' ') else choice + + # Never add special tokens to answer + tokenized_answer = self.tokenizer( + choice, add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_context, list) + assert isinstance(tokenized_answer, list) + trimmed_context = trim_context(tokenized_context, tokenized_answer, + self.padding_size) + assert isinstance(trimmed_context, list) + continuation_indices = get_continuation_span( + trimmed_context, tokenized_answer) + padded_context = make_padded_input( + trimmed_context, + tokenized_answer, + self.padding_size, + self.pad_tok_id, + self.padding_side, + ) + + tokenized_example[self.context_key].append(padded_context) + tokenized_example[self.answer_key].append(tokenized_answer) + tokenized_example['continuation_indices'].append( + continuation_indices) + + tokenized_example['gold'] = example['gold'] + return tokenized_example + + def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """The function that the dataloader uses to accumulate data into. + + batches. We run each distinct query + answer choice through the model + separately and determine which answer has the lowest per-token- + perplexity. + + If each question has N possible choices, all N must be grouped together as distinct elements of the batch + since the batch may consist of multiple questions, the choice_groupings indicates + which contiguous sequences of elements in the batch correspond to which question + gold_indices indicates which of the [0, N-1] choices is the correct one for each question. + Args: + data (List): List of tokenized datapoints (dicts returned by self._tokenize_example) + + Returns: + Dict: Dictionary for a single batch + """ + batch = copy.deepcopy(self.base_batch) + for data_pair in data: + choice_start_idx = len(batch['continuation_indices']) + # NOTE: not using batch_mapping + for i, context_enc in enumerate(data_pair[self.context_key]): + batch['input_ids'].append(context_enc) + batch['continuation_indices'].append( + data_pair['continuation_indices'][i]) + batch['labels'].append(context_enc) + + batch['gold_indices'].append(data_pair['gold']) + choice_end_idx = len(batch['continuation_indices']) + batch['choice_groupings'].append((choice_start_idx, choice_end_idx)) + + batch = convert_tokens_to_tensors(batch, self.tokenize_labels) + batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) + return batch + + def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int: + return batch['input_ids'].shape[0] // self.num_choices + + def split_batch(self, batch: Any, + microbatch_size: Union[int, float]) -> Sequence[Any]: + """Split batch while ensuring all continuations are in the same. + + microbatch. + + In ICL Multiple Choice, we duplicate each data point for each possible continuation. + When splitting a batch, we have logical example, which refer to one possible question, + and real example, which refers to one possible continuation. As example count and + microbatch_size are tracked in logical example, we split logical attributes by + microbatch_size and real attributes by microbatch_size * num_choices. + Args: + batch (Dict): Batch of data + microbatch_size (int | float): Size of microbatches + + Returns: + list: List of chunked batches + """ + if isinstance(microbatch_size, float): + raise ValueError( + 'split_batch does not support floating point microbatch_size.') + chunked = {} + for k, v in batch.items(): + if k in self.static_keys: + # Defer broadcasting primitives until we know num_chunks + pass + elif type(v) == list: + # list of tensors - 'continuation_indices' + if k in self.list_of_tensors_keys: + chunked[k] = _split_list(v, + microbatch_size * self.num_choices) + # list of tuples - 'choice_groupings' + elif k in self.list_of_tuples_keys: + chunked[k] = _split_list(v, microbatch_size) + # list - 'gold_indices' + elif k in self.list_of_primitives: + chunked[k] = _default_split_batch(v, microbatch_size) + else: + raise ValueError(f'Unexpected key {k} in list splitting') + elif k in self.tensor_keys: + chunked[k] = _default_split_batch( + v, microbatch_size * self.num_choices) + else: + raise ValueError(f'Unexpected key {k} in batch splitting') + num_chunks = len(chunked['input_ids']) + # Broadcast primitives to all chunks + for k, v in batch.items(): + if k in self.static_keys: + chunked[k] = [v] * num_chunks + + return [ + {k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks) + ] + + +class InContextLearningSchemaTaskDataset( + InContextLearningMultipleChoiceTaskDataset): + """A dataset that constructs batches for in-context learning schema. + + evaluation. A schema task involves sentences with a fill-in-the-blank where + the user needs to choose the correct word to fill in from a set of N + options. We use the partial evaluation technique from + https://arxiv.org/abs/1806.02847 to determine the model's choice of fill-in + word. + + The default input format is a jsonl file with the following fields: + - context_options: List of strings corresponding to possible preceding context options for the continuation + - gold: Index of the correct context from 'context_options' + - continuation: The finishing continuation + + Each batch then consists of ``batch_size // N`` distinct tasks and has the following the structure + - input_ids: Input tensor ``batch x seqlen x # of tokens`` + - continuation_indices: List of ``batch`` consisting of tensors indicating which indices in the sequence correspond to the question answer (aka continuation) + - mode: Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics + - labels: Identical to the input, used by the model to calculate loss/metrics + - gold_indices: List of length ``batch_size // N`` indicating for each question, which of the answers is correct (via an integer [0, N-1]) + - choice_groupings: Indicates which indices of the batch correspond to which questions + """ + + def __init__(self, + choices_key: str = 'context_options', + *args: Any, + **kwargs: Any): + static_keys = ['mode'] + tensor_keys = ['input_ids', 'labels', 'attention_mask'] + list_of_tensors_keys = ['continuation_indices'] + super().__init__(choices_key=choices_key, + context_key=choices_key, + static_keys=static_keys, + tensor_keys=tensor_keys, + list_of_tensors_keys=list_of_tensors_keys, + *args, + **kwargs) + self.base_batch = { + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [], + 'gold_indices': [], + 'choice_groupings': [], + } + + def construct_context(self, + example: Dict[str, Any], + preceding_text: str = '', + add_answer: bool = False) -> str: + """Takes a example and constructs a context with the correct context. + + for. + + the example's continuation. + + Args: + example (Dict): The example from which to construct the context + preceding_text (str): Any preceding text, needed to if self.example_delimiter is needed at the beginning + add_answer (bool): This will always be true when calling this function for SchemaTaskDataset + + Returns: + str: The single correct context for a given continuation + """ + context_options = example[self.choices_key] + gold_idx = example['gold'] + continuation = example['continuation'] + context = context_options[gold_idx] + if len(preceding_text) > 0: + context = f'{self.example_delimiter}{context}' + context = f'{self.prelimiter}{context}{self.continuation_delimiter}{continuation}' + return context + + def _construct_multiple_contexts(self, + example: Dict, + preceding_text: str = '') -> List[str]: + """Takes a example and constructs all contexts. + + Optionally, appends this to preceding text (such as a prompt or fewshot examples). + + Args: + example (Dict): The example from which to construct the context + preceding_text (str): Any preceding text, needed to if self.example_delimiter is needed at the beginning + + Returns: + list: All context options for the selected example with formatting + """ + context_options = example[self.choices_key] + if len(preceding_text) > 0: + if self.strip_data: + cont_del = self.continuation_delimiter.rstrip() + else: + cont_del = self.continuation_delimiter + context_options = [ + f'{self.prelimiter}{self.example_delimiter}{c}{cont_del}' + for c in context_options + ] + else: + context_options = [f'{self.prelimiter}{c}' for c in context_options] + return context_options + + def _prep_example( + self, + example: Dict, + example_idx: int, + num_fewshot: int, + prompt_string: str, + fewshot_rng: random.Random, + ) -> Dict[str, Any]: + """Prepares a single example from a HF Dataset into tokenized format. + + with prompt and fewshot examples. + + Each task consists of multiple contexts and a single, correct continuation. Will prepend fewshot examples and + prompt if present. + + Args: + example (Dict): A dictionary from the hf dataset + example_idx (int): The index of example + num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair + prompt_string (str): The prompt to prepend to all inputs + fewshot_rng (random.Random): Random number generator to use for fewshot sampling + + Returns: + Dict: Contains a dictionary with the tokenized data + """ + prompt_and_fewshot = self._generate_few_shot_prompt( + num_fewshot, example_idx, prompt_string, fewshot_rng) + ctxt = self._construct_multiple_contexts(example, prompt_and_fewshot) + tokenized_example = self.tokenize_example(prompt_and_fewshot, ctxt, + example) + return tokenized_example + + def tokenize_example(self, prompt_and_fewshot: str, + context_options: List[str], + example: Dict) -> Dict[str, Any]: + """Runs text through the tokenizer and handle special cases. + + Args: + prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context + ctx (str): The specific example's derived context + example (Dict): The example as a dictionary. + + Returns: + Dict: Dictionary with the tokenized data + """ + tokenized_example = {} + preamble = self.tokenizer(prompt_and_fewshot)['input_ids'] + assert isinstance(preamble, list) + preamble = self._fix_eos_on_preamble(preamble) + encoded_contexts = [ + preamble + + # pyright: ignore[reportOperatorIssue, reportGeneralTypeIssues] + self.tokenizer(c, add_special_tokens=False)[ + 'input_ids'] # pyright: ignore[reportOperatorIssue, ] + for c in context_options + ] + continuation = example['continuation'] + if self.prefix_space: + continuation = (f' {continuation}' if + not continuation.startswith(' ') else continuation) + tokenized_continuation = self.tokenizer( + continuation, add_special_tokens=False)['input_ids'] + + tokenized_example[self.context_key] = [] + tokenized_example['continuation_indices'] = [] + tokenized_example[self.answer_key] = [] + for context in encoded_contexts: + assert isinstance(context, list) + assert isinstance(tokenized_continuation, list) + trimmed_context = trim_context(context, tokenized_continuation, + self.padding_size) + assert isinstance(trimmed_context, list) + continuation_indices = get_continuation_span( + trimmed_context, tokenized_continuation) + padded_context = make_padded_input(trimmed_context, + tokenized_continuation, + self.padding_size, + self.pad_tok_id, + self.padding_side) + tokenized_example[self.context_key].append(padded_context) + tokenized_example['continuation_indices'].append( + continuation_indices) + tokenized_example[self.answer_key].append(tokenized_continuation) + + tokenized_example['gold'] = example['gold'] + return tokenized_example + + +class InContextLearningCodeEvalDataset(InContextLearningDataset): + """A dataset that constructs batches for in-context learning code. + + evaluation. + + The input format is expected to be a jsonl file with the following fields: + + - task_id: Label of given task + - prompt: The code snippet that must be completed + - entry_point: The entry to the function/code snippet to generate + - canonical_solution: Working solution + - test: The checker code that will run to completion if the code generation is valid and otherwise throw assertion + - test_inputs: List of test inputs + - test_outputs: List of test outputs + - language: The language of the code snippet + + Each batch then consists of the following the structure + + - input_ids: Input tensor batch x seqlen x num tokens + - mode: Indicates to the model that this is an ICL task and may rely on a custom code path to properly update metrics + - mode: Always set to 'generate' + - labels: Exact solution for the coding problem + - prompts: Prompt for the task + - entry_points: List of entry points + - test_inputs: List of test inputs + - test_outputs: List of test outputs + - languages: List of languages + - pass_at_k: Passed value for pass_at_k + - generation_kwargs: Dictionary of kwargs needed for generation. Includes the following, which will be individually overwritten + by keys in generation_kwargs if set (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig + for more details): + + - pad_token_id: ID for padding token, derived automatically + - num_beams: How many beams to search for generations, default set to 1 + - do_sample: Determines whether model is sampling or greedily decoding. Always set to True + - use_cache: Whether or not to use past key values to speed up sampling. Always set to True + + Additional Args: + generations_per_sample (int) (defaults to 1): The number of independently computed returned sequences for each element in the batch + pass_at_k (int) (defaults to 1): k for how many chances the model gets to write passing code + """ + + def __init__( + self, + generations_per_sample: int, + pass_at_k: Union[int, list[int]] = 1, + *args: Any, + **kwargs: Any, + ): + if isinstance(pass_at_k, int): + pass_at_k = [pass_at_k] + if generations_per_sample < max(pass_at_k): + raise ValueError( + f'generations_per_sample ({generations_per_sample}) must be greater than or equal to pass_at_k ({pass_at_k}) for code evaluation.' + ) + batch_mapping = { + 'input_ids': 'prompt', + 'prompts': 'prompt_text', + 'tests': 'test', + 'labels': 'canonical_solution', + 'entry_points': 'entry_point', + 'test_inputs': 'test_inputs', + 'test_outputs': 'test_outputs', + 'languages': 'language', + 'sample_id': 'sample_id', + } + # Linting complains if these are not set in init + self.max_prompt_length = 0 + self.max_answer_length = 0 + static_keys = [ + 'mode', + 'pass_at_k', + 'generation_kwargs', + 'generations_per_sample', + 'dataset_size', + ] + list_keys = [ + 'prompts', + 'tests', + 'entry_points', + 'test_inputs', + 'test_outputs', + 'languages', + 'labels', + 'sample_id', + ] + tensor_keys = ['input_ids', 'attention_mask'] + super().__init__( + context_key='prompt', + answer_key='canonical_solution', + strip_dataset=False, + static_keys=static_keys, + list_keys=list_keys, + tensor_keys=tensor_keys, + tokenize_labels=False, + padding_side='left', + batch_mapping=batch_mapping, + *args, + **kwargs, + ) + self._set_max_prompt_and_answer_lengths() + if self.max_seq_len < self.max_prompt_length: + log.warning(f'`max_seq_len` {self.max_seq_len} was less than `max_prompt_len`: {self.max_prompt_length}' \ + + ' setting `max_seq_len`=`max_prompt_len`') + self.max_seq_len = self.max_prompt_length + dataset_size = len(self.dataset) + self.dataset = self.dataset.map(self._trim_padding) + self.dataset = self.repeat_dataset(self.dataset, generations_per_sample) + + if self.max_answer_length < self.max_seq_len - self.max_prompt_length: + max_new_tokens = self.max_answer_length + else: + max_new_tokens = self.max_seq_len - self.max_prompt_length + + self.base_batch = { + 'input_ids': [], + 'mode': 'generate', + 'labels': [], + 'prompts': [], + 'tests': [], + 'entry_points': [], + 'test_inputs': [], + 'test_outputs': [], + 'languages': [], + 'pass_at_k': pass_at_k, + 'generation_kwargs': { + 'pad_token_id': self.pad_tok_id, + 'num_beams': 1, # single beam + 'do_sample': True, + 'temperature': 0.2, # good default for code + 'use_cache': True, + 'eos_token_id': self.tokenizer.eos_token_id, + 'max_new_tokens': max(max_new_tokens, 1) + }, + 'sample_id': [], + 'pass_at_k': list(pass_at_k), + 'generations_per_sample': generations_per_sample, + 'dataset_size': dataset_size, + } + if 'generation_kwargs' in kwargs: + self.update_generation_kwargs(kwargs['generation_kwargs']) + + def repeat_dataset(self, dataset: HFDataset, repetitions: int) -> HFDataset: + + def _repeat_dataset(): + for i, sample in enumerate(dataset): + for _ in range(repetitions): + assert isinstance(sample, dict) + yield {'sample_id': i, **sample} + + from datasets import \ + Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] + + repeated_dataset = HFDataset.from_generator(_repeat_dataset) + assert isinstance(repeated_dataset, HFDataset) + return repeated_dataset + + def _set_max_prompt_and_answer_lengths(self): + """Iterates through the dataset and finds the maximum prompt length and. + + sequence lengths. + + Returns: + None + """ + max_prompt_length = 0 + max_answer_length = 0 + for example in self.dataset: + assert isinstance(example, Dict) + unpadded_example = [ + token for token in example[self.context_key] + if token != self.pad_tok_id + ] + max_prompt_length = max(max_prompt_length, len(unpadded_example)) + + tokenized_answer = self.tokenizer( + example['canonical_solution'], + add_special_tokens=False)['input_ids'] + assert isinstance(tokenized_answer, list) + len_tokenized_answer = len(tokenized_answer) + max_answer_length = max(max_answer_length, len_tokenized_answer) + + self.max_prompt_length = max_prompt_length + self.max_answer_length = max_answer_length + _MAX_ANSWER_BUFFER_LENGTH + + def _trim_padding(self, example: Dict): + """Adjusts padding to the maximum prompt length rather than max_seq_len. + + Needs to be done after the dataset has been processed because we don't + know the maximum prompt length until after we've tokenized it. + + Returns: + dataset: A HuggingFace Dataset with different padding lengths for example[self.context_key] + """ + # Remove padding tokens applied during tokenization + unpadded_prompt = [ + token for token in example[self.context_key] + if token != self.pad_tok_id + ] + # Reapply padding only to max_prompt_length + full_prompt = trim_context(unpadded_prompt, [], self.max_prompt_length) + padded_context = make_padded_input(full_prompt, [], + self.max_prompt_length, + self.pad_tok_id, self.padding_side) + + example[self.context_key] = padded_context + return example + + def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, + example: Dict) -> Dict[str, Any]: + """Adds extra code task details to the example dictionary. + + See InContextLearningDataset for more details + """ + tokenized_example = super().tokenize_example(prompt_and_fewshot, ctxt, + example) + tokenized_example['prompt_text'] = example['prompt'] + tokenized_example['task_id'] = example['task_id'] + tokenized_example['canonical_solution'] = example['canonical_solution'] + tokenized_example['test'] = example['test'] + tokenized_example['entry_point'] = example['entry_point'] + tokenized_example['test_inputs'] = example['test_inputs'] + tokenized_example['test_outputs'] = example['test_outputs'] + tokenized_example['language'] = example['language'] + return tokenized_example + + +def build_icl_dataloader( + icl_task_type: str, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + batch_size: int, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + prompt_string: str, # e.g. 'translate english to french:' + example_delimiter: str, # e.g. '\n' + continuation_delimiter: str, # e.g. '' + hf_loading_vars: Dict, + hf_parsing_map: Dict, + destination_path: str, + prelimiter: str, # e.g. 'Question: ' + cot_delimiter: str, # e.g. ' ### ' + fewshot_random_seed: int, + pass_at_k: int, + generations_per_sample: int, + generation_kwargs: Dict, + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True) -> DataSpec: + """Factory method that builds the specific dataset for the specified. + + icl_task_type. See documentation for `get_icl_task_dataloader` for argument + documentation. + + When writing a dataset for a new task, here you will need to: + 1. add the dataset to the factory and choose an appropriate string + 2. set the batch size for that task (see InContextLearningMultipleChoiceTaskDataset for why + this might be different) + 3. set the `split_batch` function if necessary + """ + if icl_task_type == 'multiple_choice': + dataset = InContextLearningMultipleChoiceTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + ) + batch_size = max(dataset.num_choices, batch_size) + effective_batchsize = batch_size // dataset.num_choices + elif icl_task_type == 'schema': + dataset = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + ) + batch_size = max(dataset.num_choices, batch_size) + effective_batchsize = batch_size // dataset.num_choices + elif icl_task_type == 'language_modeling': + dataset = InContextLearningLMTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + ) + effective_batchsize = batch_size + elif icl_task_type == 'generation_task_with_answers' or icl_task_type == 'question_answering': + if icl_task_type == 'question_answering': + warnings.warn( + VersionedDeprecationWarning( + "ICL task type 'question_answering' is now deprecated. Use identifier 'generation_task_with_answers'", + 'v0.9.0')) + dataset = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + cot_delimiter=cot_delimiter, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization, + generation_kwargs=generation_kwargs, + ) + effective_batchsize = batch_size + elif icl_task_type == 'code_evaluation': + warnings.warn( + VersionedDeprecationWarning( + "ICL task type 'code_evaluation' is deprecated and will no longer be supported. ", + 'v0.9.0')) + dataset = InContextLearningCodeEvalDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + fewshot_random_seed=fewshot_random_seed, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + pass_at_k=pass_at_k, + generations_per_sample=generations_per_sample, + generation_kwargs=generation_kwargs, + ) + effective_batchsize = batch_size + else: + raise Exception(f'Unrecognized ICL task type: {icl_task_type}') + + sampler = dist.get_sampler(dataset, drop_last=False, shuffle=False) + + split_batch = None + if isinstance( + dataset, + ( + InContextLearningMultipleChoiceTaskDataset, + InContextLearningGenerationTaskWithAnswersDataset, + InContextLearningCodeEvalDataset, + ), + ): + split_batch = dataset.split_batch + + return DataSpec( + DataLoader( + dataset, + batch_size=effective_batchsize, + sampler=sampler, + collate_fn=dataset.collate_fn, + ), + device_transforms=None, + get_num_samples_in_batch=dataset.get_num_samples_in_batch, + split_batch=split_batch, + ) + + +def partition_dataset_by_category(dataset_uri: str, destination_path: str, + hf_loading_vars: Dict, + hf_parsing_map: Dict) -> Dict[str, str]: + """If has_categories is enabled, we partition the dataset into a separate. + + dataset for each category value in the data and write each partition to a + local file. + + Args: + dataset_uri (str): Location of dataset. + destination_path (str): Base destination path, we will write a separate partition off this URI for each category. + + Raises: + MissingConditionalImportError: If datasets not installed raise exception. + Exception: If 'category' key missing from dataset, raise exception. + + Returns: + Dict[str, str]: Mapping of category names to partitioned dataset local files names. + """ + if dataset_uri.startswith('hf://'): + dataset_uri = dataset_uri.replace('hf://', '') + dataset = load_dataset(dataset_uri, **hf_loading_vars) + assert isinstance(dataset, HFDataset) or isinstance( + dataset, IterableDataset) + if hf_parsing_map: + dataset_parsing_func = lambda example: { + k: ' '.join([str(example[col]) for col in v]) + for k, v in hf_parsing_map.items() + } + assert hasattr(dataset, 'column_names') + dataset = dataset.map(dataset_parsing_func, + remove_columns=dataset.column_names) + else: + with dist.local_rank_zero_download_and_wait(destination_path): + if dist.get_local_rank() == 0: + get_file(dataset_uri, destination_path, overwrite=True) + dataset = load_dataset('json', + data_files=destination_path, + split='train', + streaming=False) + assert isinstance(dataset, HFDataset) or isinstance(dataset, + IterableDataset) + assert hasattr(dataset, 'features') + assert dataset.features is not None + if 'category' not in dataset.features.keys(): + raise Exception(f"""Attempted to partition dataset by `category` \ + but it doesn't have a `category` key. \ + Got keys: {str(list(dataset.features.keys()))}""") + categories = sorted( + set(dataset['category'] + )) # pyright: ignore[reportIndexIssue, reportGeneralTypeIssues] + output_files = {} + for cat in categories: + path = destination_path.split('/') + cat_dest = '/'.join(path[:-1]) + f'/{cat}_{path[-1]}' + tmp_path_to_broadcast = str(os.path.abspath(cat_dest)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + if dist.get_local_rank() == 0: + subset = [ + l for l in dataset if + l['category'] == cat # pyright: ignore[reportGeneralTypeIssues] + ] # pyright: ignore[reportArgumentType, reportCallIssue] + with open(gathered_paths[0], 'w', encoding='utf8') as f: + for l in subset: + f.write(json.dumps(l, ensure_ascii=False) + '\n') + output_files[cat] = cat_dest + return output_files + + +def get_icl_task_dataloader( + icl_task_type: str, + dataset_uri: str, + tokenizer: Union[transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast], + batch_size: int, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + prompt_string: str, # e.g. 'translate english to french:' + example_delimiter: str, # e.g. '\n' + continuation_delimiter: str = '', + destination_path: str = '', + question_prelimiter: str = '', # e.g. 'Question: ' + fewshot_random_seed: int = 1234, + pass_at_k: int = 1, + generations_per_sample: int = 1, + cot_delimiter: str = '', + has_categories: bool = False, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True) -> Union[DataSpec, Dict[str, DataSpec]]: + r"""Constructs a dataloader (or dataloaders if has_categories is True) + + capable of evaluating LLMs on in-context learning language modeling tasks, + for example LAMBADA. An example usage is below: + + .. testsetup:: + + import transformers + from composer.models import HuggingFaceModel + from composer.trainer import Trainer + dataset_uri = "/tmp/dataset_uri.jsonl" + dataset = RandomTextClassificationDataset(size=16, use_keys=True) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8) + hf_model, tokenizer = HuggingFaceModel.hf_from_composer_checkpoint('composer-hf-checkpoint.pt') + # At this point, hf_model is randomly initialized + composer_model = HuggingFaceModel(hf_model, hf_tokenizer) + + Example: + + .. testcode:: + + + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri, + tokenizer, + batch_size=2, + max_seq_len=2048, + pad_tok_id=tokenizer.pad_token_id, + num_fewshot=10, + prompt_string='translate english to french', + example_delimiter='\\n', + continuation_delimiter='' + ) + eval_evaluator = Evaluator( + label="lambada", + dataloader=dl, + metric_names=['InContextLearningLMAccuracy'] + ) + trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=eval_evaluator, + optimizers=optimizer, + max_duration="1ep", + ) + + Args: + icl_task_type (str): Name of icl_task type. One of ['multiple_choice', 'schema', 'language_modeling', 'generation_task_with_answers', 'code_evaluation'] + dataset_uri (str): A local path, a remote path beginning with ``s3://`` or another backend, or a HuggingFace dataset uri prepended with ``hf://``. + Alternate backends must be supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + A local dataset must consist of rows of JSON data points with task dependant fields. + The default keys expected are "context" and "answer". + tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to map between strings and token ids. + batch_size (int): Size of a batch used for eval + max_seq_len (int): The maximum sequence length supported by the model. + pad_tok_id (int): The special token used for padding batches. + num_fewshot (int): The number of complete fewshot examples to prepend before each test example. These are not identical across examples. + prompt_string (str, default = ''): Prompt string to put once before all fewshot examples/test examples (e.g. 'Translate english to french.'). + example_delimiter (str, default = '\\n'): Separator inserted before (context, answer) pairs (e.g. '\\n') for fewshot sampling and prompting. + continuation_delimiter: (str, default = ' '): Separator inserted between context and answer in each example (e.g. '\\nA: '). + destination_path: (str, default = ''): This is the local file where remote datasets will be saved. + question_prelimiter: (str, default = ''): Text to be prepended before each context, including few shot examples (e.g. "Question: "). + fewshot_random_seed (int, default = 1234): Random seed to use for fewshot sampling + pass_at_k (int): k for how many chances the model gets to write passing code. + generations_per_sample (int): How many outputs to generate per prompt. Passed in generation_kwargs under "num_return_sequences" and overwritten by generation_kwargs dict. + cot_delimiter (str): Delimiter to place between chain of thoughts and continuations. + has_categories: (bool): If ``True``, we will search the dataset file for a category key, and partition the dataset into a separate dataloader for each category occurring in the data. + hf_loading_vars (Dict, default = None): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. + hf_parsing_map (Dict, default = None): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. + Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. + generation_kwargs (Dict, default = None): A dictionary containing keyword arguments to be passed along to the model's generate function. Overwrites any previously specified generation + keyword args in this function (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig + for more details) + early_stopping (List, default = None): A list of strings that, when found in a model's output, will be treated as a stopping criteria at metric computation time. + Used in generation tasks with CoT + do_normalization (bool, default = True): Whether or not to normalize the outputs and labels in InContextLearningGenerationTaskWithAnswersDataset. Only used in generation tasks. + + Returns: + DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided. + """ + if hf_loading_vars is None: + hf_loading_vars = {} + if hf_parsing_map is None: + hf_parsing_map = {} + if generation_kwargs is None: + generation_kwargs = {} + if early_stopping_criteria is None: + early_stopping_criteria = [] + + if has_categories: + result_dls = {} + output_files = partition_dataset_by_category(dataset_uri, + destination_path, + hf_loading_vars, + hf_parsing_map) + categories = sorted(output_files.keys()) + for category in categories: + partition_uri = output_files[category] + result_dls[category] = build_icl_dataloader( + icl_task_type=icl_task_type, + dataset_uri=partition_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=partition_uri + '_tmp', + prelimiter=question_prelimiter, + cot_delimiter=cot_delimiter, + fewshot_random_seed=fewshot_random_seed, + pass_at_k=pass_at_k, + generations_per_sample=generations_per_sample, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization, + ) + return result_dls + else: + return build_icl_dataloader( + icl_task_type=icl_task_type, + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=question_prelimiter, + cot_delimiter=cot_delimiter, + fewshot_random_seed=fewshot_random_seed, + pass_at_k=pass_at_k, + generations_per_sample=generations_per_sample, + generation_kwargs=generation_kwargs, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization, + ) diff --git a/llmfoundry/eval/datasets/utils.py b/llmfoundry/eval/datasets/utils.py new file mode 100644 index 0000000000..6433e7cb56 --- /dev/null +++ b/llmfoundry/eval/datasets/utils.py @@ -0,0 +1,285 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utility and helper functions for datasets.""" +from __future__ import annotations + +import logging +import random +from typing import Any, Dict, List, Optional, Set + +import torch +import transformers + +__all__ = [ + 'MultiTokenEOSCriteria', + 'strip_data', + 'tokenizer_needs_prefix_space', + 'trim_context', + 'get_continuation_span', + 'make_padded_input', + 'convert_tokens_to_tensors', + 'get_fewshot_sample_idxs', + 'stop_sequences_criteria', +] + +log = logging.getLogger(__name__) + + +def strip_data(example: Dict) -> Dict: + """Remove white space from the begging and end of string values in a. + + dictionary. + + Args: + example: Dictionary to be stripped + + Returns: + dict: The same dictionary with .strip() applied to any value in the dict that is a string + """ + return { + k: v.strip() if isinstance(v, str) else v for k, v in example.items() + } + + +def tokenizer_needs_prefix_space( + tokenizer: transformers.PreTrainedTokenizerBase) -> bool: + """Test for whether a prefix space is needed before the continuation. + + Sentencepiece tokenization should not have a prefix space, but gpt2 style + BPE should. + + Args: + tokenizer: Tokenizer to test + + Returns: + bool: Whether or not the tokenizer needs a prefix space + """ + test_tokens = tokenizer(' a', add_special_tokens=False)['input_ids'] + assert isinstance(test_tokens, list) + return len(test_tokens) == 1 + + +def trim_context(context_enc: List, continuation_enc: List, + max_seq_len: int) -> List: + """Trims a list of tokens down to `max_seq_len` if the length of the list. + + plus the continuation is more than `max_seq_len`. It will always trim tokens + from the left, i.e. tokens at the beginning of the context will be removed. + + Args: + context_enc (list): List of tokens in the context + continuation_enc (list): List of tokens in the continuation + max_seq_len (int): Maximum length the model can ingest + + Returns: + list: The encoded context trimmed from the left + """ + if len(continuation_enc) + len(context_enc) > max_seq_len: + context_max_subseq_len = max_seq_len - len(continuation_enc) + + if context_max_subseq_len < 0: + # can't support continuations which are longer than the max seq len + raise Exception( + f'Dataset included continuation longer than the max seq len') + + # clip from the end + context_enc = context_enc[-(context_max_subseq_len):] + return context_enc + + +def get_continuation_span(context_enc: List, + continuation_enc: List) -> torch.Tensor: + """Gets the list of indices of the continuation tokens for language. + + modeling. + + or generation tasks. + + Args: + context_enc (list): List of context tokens + continuation_enc (list): List of continuation tokens + + Returns: + torch.tensor: A tensor containing indices corresponding to continuation tokens + """ + return torch.tensor( + range(len(context_enc), + len(context_enc) + len(continuation_enc))) + + +def make_padded_input(context_enc: List, + continuation_enc: List, + max_seq_len: int, + pad_tok_id: int, + padding_side: str = 'right') -> torch.Tensor: + """Takes an encoded context and continuation and clips the beginning of the. + + context if they're too long. Adds the padding token to the specified side. + + Args: + context_enc (List): The encoded input to the model + continuation_enc (List): The encoded desired output for the example + max_seq_list (int): Maximum length sequences can be + pad_tok_id (int): The token id we pad with + padding_side (str): Which side to pad the context on. Can be 'right' or 'left + + Returns: + input (torch.tensor): The padded and encoded context + continuation_span (torch.tensor): The _inclusive_ range of indices corresponding to the continuation + """ + inp = torch.tensor( + (context_enc + continuation_enc), + dtype=torch.long, + ) + (inp_len,) = inp.shape + + # Sometimes tokenizers that have neither a pad_tok_id or eos_tok_id will pass None in as the padding + # token and cause errors + if not isinstance(pad_tok_id, int): + raise ValueError( + f'`pad_tok_id` must be an integer. Found {type(pad_tok_id)} instead' + ) + # pad length from seq to padding_length + if padding_side == 'right': + inp = torch.cat( + [ + inp, # [seq] + torch.LongTensor((max_seq_len - inp_len) * [pad_tok_id]), + ], + dim=0, + ) + elif padding_side == 'left': + inp = torch.cat( + [ + torch.LongTensor((max_seq_len - inp_len) * [pad_tok_id]), + inp, # [seq] + ], + dim=0, + ) + else: + raise ValueError( + f"Unknown padding_side {padding_side}. padding_side must be either 'left' or 'right'" + ) + + return inp + + +def convert_tokens_to_tensors(batch: Dict, + tokenize_labels: bool) -> Dict[str, Any]: + """HF Datasets converts tensors into lists when we store them, and we don't. + + want to use `type='torch'` because some content in the dataset, like + generation args or single ints, should not be converted. + + Here, we convert those lists of tokens back into tensors in order to feed them into the model. + + Args: + batch (dict): A dictionary of batched inputs + tokenize_labels (bool): Whether or not the labels are tokenized (and need to be stacked) + + Returns: + dict: The batch with torch tensors in the corresponding keys instead of lists of lists + """ + batch['input_ids'] = torch.stack(list(map(torch.tensor, + batch['input_ids']))) + if tokenize_labels: + batch['labels'] = torch.stack(list(map(torch.tensor, batch['labels']))) + batch['continuation_indices'] = list( + map(torch.tensor, batch['continuation_indices'])) + return batch + + +def get_fewshot_sample_idxs(dataset_size: int, num_fewshot: int, + example_idx: int, rng: random.Random) -> Set[int]: + """Samples indices without replacement. If num_fewshot exceeds the number. + + of unique examples in the dataset, then we will have fewer than num_fewshot examples in context. + + Args: + dataset_size (int): Length of the dataset + num_fewshot (int): Number of examples to prepend + example_idx (int): Current example's index (excluded from fewshot choices) + rng (random.Random): RNG for repeatable sample selection + + Returns: + list: Indices of the examples chosen for fewshot selection + """ + num_fewshot = min(dataset_size - 1, num_fewshot) + fewshot_idxs = set(rng.sample(range(0, dataset_size), num_fewshot)) + + if example_idx in fewshot_idxs: + fewshot_idxs.remove(example_idx) + if len(fewshot_idxs) >= dataset_size - 1: + return fewshot_idxs + + replacement_sample = rng.choice(range(0, dataset_size)) + while replacement_sample in fewshot_idxs or replacement_sample == example_idx: + replacement_sample = rng.choice(range(0, dataset_size)) + fewshot_idxs.add(replacement_sample) + return fewshot_idxs + + +class MultiTokenEOSCriteria(transformers.StoppingCriteria): + """Criteria to stop on the specified multi-token sequence. + + Slightly modified from: https://github.com/EleutherAI/lm-evaluation-harness/blob/78545d42f2ca95c6fe0ed220d456eeb94f4485e9/lm_eval/utils.py#L614-L649 + """ + + def __init__( + self, + stop_sequence: str, + tokenizer: transformers.PreTrainedTokenizerBase, + batch_size: int, + ) -> None: + self.done_tracker = [False] * batch_size + self.stop_sequence = stop_sequence + self.stop_sequence_ids = tokenizer.encode(stop_sequence, + add_special_tokens=False) + + # sentence piece tokenizers add a superfluous underline token before string-initial \n + # that throws off our calculation of the stop sequence length + # so we remove any token ids that produce empty strings + self.stop_sequence_ids = [ + id for id in self.stop_sequence_ids if tokenizer.decode(id) != '' + ] + + # we look back for 1 more token than it takes to encode our stop sequence + # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']` + # and we don't want to mistakenly not stop a generation because our + # (string) stop sequence was output in a different tokenization + + self.stop_sequence_id_len = len(self.stop_sequence_ids) + 1 + self.tokenizer = tokenizer + + def __call__(self, + input_ids: torch.LongTensor, + scores: Optional[torch.FloatTensor] = None, + **kwargs: Dict[str, Any]) -> bool: + # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence + lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:] + lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) + for i, done in enumerate(self.done_tracker): + if i >= len(lookback_tokens_batch): + # The last batch of a dataset may be smaller than `batch_size` + # Automatically set those indices in the done_tracker to True + # since those indices don't show up in the current batch + self.done_tracker[i] = True + break + elif not done: + self.done_tracker[ + i] = self.stop_sequence in lookback_tokens_batch[i] + return False not in self.done_tracker + + +def stop_sequences_criteria( + tokenizer: transformers.PreTrainedTokenizerBase, + stop_sequences: List[str], + batch_size: int, +) -> transformers.StoppingCriteriaList: + return transformers.StoppingCriteriaList([ + *[ + MultiTokenEOSCriteria(sequence, tokenizer, batch_size) + for sequence in stop_sequences + ], + ]) diff --git a/llmfoundry/eval/metrics/__init__.py b/llmfoundry/eval/metrics/__init__.py new file mode 100644 index 0000000000..079439da59 --- /dev/null +++ b/llmfoundry/eval/metrics/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""A collection of common torchmetrics.""" + +from llmfoundry.eval.metrics.nlp import ( + InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningLMExpectedCalibrationError, + InContextLearningMCExpectedCalibrationError, InContextLearningMetric, + InContextLearningMultipleChoiceAccuracy) + +__all__ = [ + 'InContextLearningMetric', + 'InContextLearningLMAccuracy', + 'InContextLearningMultipleChoiceAccuracy', + 'InContextLearningGenerationExactMatchAccuracy', + 'InContextLearningCodeEvalAccuracy', + 'InContextLearningLMExpectedCalibrationError', + 'InContextLearningMCExpectedCalibrationError', +] diff --git a/llmfoundry/eval/metrics/nlp.py b/llmfoundry/eval/metrics/nlp.py new file mode 100644 index 0000000000..f5a50721e3 --- /dev/null +++ b/llmfoundry/eval/metrics/nlp.py @@ -0,0 +1,730 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""A collection of common torchmetrics for NLP tasks.""" + +import copy +import functools +import logging +import os +import re +import string +import warnings +from typing import Any, Callable, Dict, List + +import numpy as np +import torch +from composer.utils import dist +from composer.utils.eval_client import (EvalClient, LambdaEvalClient, + LocalEvalClient, + MosaicMLLambdaEvalClient) +from torch import Tensor +from torch.nn import functional as F +from torchmetrics import Metric + +log = logging.getLogger(__name__) + +__all__ = [ + 'InContextLearningMetric', + 'InContextLearningLMAccuracy', + 'InContextLearningMultipleChoiceAccuracy', + 'InContextLearningGenerationExactMatchAccuracy', + 'InContextLearningCodeEvalAccuracy', + 'InContextLearningLMExpectedCalibrationError', + 'InContextLearningMCExpectedCalibrationError', +] + + +class InContextLearningMetric(Metric): + + def __init__(self, *args, **kwargs): # pyright: ignore + super().__init__(*args, **kwargs) + self.needs_batch = True + + def _wrap_update(self, update: Callable) -> Callable: + """Overwrite default _wrap_update to return result of update(). + + Torch metrics wraps update with following wrapped_func but explicitly + does not return the value. In general, torchmetrics update() does not + return a value, but we want to in order to pass it on to + state.metric_outputs. + """ + + @functools.wraps(update) + def wrapped_func(*args: Any, **kwargs: Any) -> None: + self._computed = None + self._update_count += 1 + with torch.set_grad_enabled(self._enable_grad): + try: + update_result = update(*args, **kwargs) + except RuntimeError as err: + if 'Expected all tensors to be on' in str(err): + raise RuntimeError( + 'Encountered different devices in metric calculation (see stacktrace for details).' + \ + ' This could be due to the metric class not being on the same device as input.' + \ + f' Instead of `metric={self.__class__.__name__}(...)` try to do' + \ + f' `metric={self.__class__.__name__}(...).to(device)` where' + \ + ' device corresponds to the device of the input.', + ) from err + raise err + + if self.compute_on_cpu: + self._move_list_states_to_cpu() + return update_result + + return wrapped_func + + def update( + self, + batch: dict, + outputs: torch.Tensor, + labels: torch.Tensor, + ): + """Abstract interface for computing an in-context learning metrics. + + The `output_logits` argument is deprecated and will be removed in v0.21 while it's functionality will + be moved to `outputs`. + + Args: + batch (dict): Batch must consist minimally of `input_ids` as well as any other structure needed + to compute the metric. + output_logits (torch.Tensor): The model outputs evaluated on the batch `input_ids` + labels (torch.Tensor): The correct outputs. + + Raises: + NotImplementedError: Abstract method must be implemented by subclasses + """ + raise NotImplementedError + + +class InContextLearningGenerationExactMatchAccuracy(InContextLearningMetric): + r"""Computes exact match for in-context learning generation tasks. + + ICL generation tasks consist of some number of prompted generation tasks with correct answers + followed by a test task where the model must correctly produce one of a number of valid answers. + + For example, the model may be provided the context below and evaluated on its ability to correctly predict the continuation. + + Context: `Question: Who was president of the United States in 2012?\nAnswer: Barack Obama\nQuestion: Is water wet?\nAnswer: ` + Answers: [`yes`] + + The model will be expected to correctly produce one of the answers, following some optional normalization. + + Adds metric state variables: + correct (float): The number of instances where the prediction was a prefix for any of the answer aliases. + total (float): The number of total instances that were predicted. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state('correct', + default=torch.tensor(0.), + dist_reduce_fx='sum') + self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum') + self.metric_result_dict = { + 'cleaned_output': [], + 'original_label': [], + 'cleaned_label': [], + 'result': [], + } + + def normalize_answer(self, answer: str): + """Lower text and remove punctuation, articles and extra whitespace. + + Copied from https://github.com/mandarjoshi90/triviaqa/blob/master/evaluation/triviaqa_evaluation.py + """ + + def remove_articles(text: str) -> str: + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text: str) -> str: + return ' '.join(text.split()) + + def handle_punc(text: str) -> str: + exclude = set(string.punctuation + + ''.join([u'‘', u'’', u'´', u'`'])) + return ''.join(ch if ch not in exclude else ' ' for ch in text) + + def lower(text: str) -> str: + return text.lower() + + def replace_underscore(text: str) -> str: + return text.replace('_', ' ') + + return white_space_fix( + remove_articles(handle_punc(lower( + replace_underscore(answer))))).strip() + + def update( + self, + batch: Dict[str, Any], + outputs: List[str], + labels: List[List[str]], + ): + cot_delimiter = batch.get('cot_delimiter', '') + do_normalization = batch.get('do_normalization', True) + stopping_criteria = batch.get('stopping_criteria', None) + metric_result_dict = copy.deepcopy(self.metric_result_dict) + for sample_output, sample_labels in zip(outputs, labels): + final_answer = sample_output + + if stopping_criteria is not None and len(stopping_criteria) > 0: + final_answer = re.split('|'.join(stopping_criteria), + final_answer)[0] + + if cot_delimiter is not None and len(cot_delimiter) > 0: + final_answer = final_answer.split(cot_delimiter)[-1] + + if do_normalization: + cleaned_final_answer = self.normalize_answer(final_answer) + cleaned_sample_labels = { + self.normalize_answer(label) for label in sample_labels + } + else: + # even if normalization is off, we should still strip leading/trailing whitespaces + cleaned_final_answer = final_answer.strip() + cleaned_sample_labels = { + sample_label.strip() for sample_label in sample_labels + } + metric_result_dict['original_label'].append(sample_labels) + metric_result_dict['cleaned_output'].append(cleaned_final_answer) + metric_result_dict['cleaned_label'].append(cleaned_sample_labels) + + if any( + cleaned_final_answer.startswith(label) + for label in cleaned_sample_labels): + self.correct += torch.tensor(1.0) + metric_result_dict['result'].append(1) + else: + metric_result_dict['result'].append(0) + + self.total += torch.tensor(1.0) + + return metric_result_dict + + def compute(self): + assert isinstance(self.correct, Tensor) + assert isinstance(self.total, Tensor) + return self.correct / self.total + + +class InContextLearningLMAccuracy(InContextLearningMetric): + r"""Computes accuracy for In-context learning language modeling tasks. + + ICL LM tasks consist of some number of example language modeling tasks (referred to as the 'context'), followed by a test task where the model must correctly predict all the tokens + following tokens in some passage (referred to as the 'continuation'). + + For example, the model may be provided the context below and evaluated on its ability to correctly predict the continuation. Note: it doesn't matter + whether the model correctly predicts the context tokens. + + Context: `The dog is->fuzzy\nthe water is->hot\nthe tree is->` + Continuation: `green` + + Adds metric state variables: + correct (float): The number of instances where the prediction masked the target. + total (float): The number of total instances that were predicted. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state('correct', + default=torch.tensor(0.), + dist_reduce_fx='sum') + self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum') + self.metric_result_dict = { + 'context': [], + 'label': [], + 'output': [], + 'result': [] + } + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + + metric_result_dict = copy.deepcopy(self.metric_result_dict) + for batch_idx, cont_idx in enumerate(batch['continuation_indices']): + cont_tok_pred = outputs[batch_idx].index_select(dim=0, + index=cont_idx - + 1).argmax(dim=-1) + cont_tok_targ = labels[batch_idx].index_select(dim=0, + index=cont_idx - 1) + + metric_result_dict['context'].append( + batch['input_ids'][batch_idx][:cont_idx[0]]) + metric_result_dict['label'].append(cont_tok_targ) + metric_result_dict['output'].append(cont_tok_pred) + + correct = (cont_tok_pred == cont_tok_targ).all().int() + self.correct += correct + metric_result_dict['result'].append(int(correct)) + + self.total += torch.tensor(1.0) + + return metric_result_dict + + def compute(self): + assert isinstance(self.correct, Tensor) + assert isinstance(self.total, Tensor) + return self.correct / self.total + + +class InContextLearningMultipleChoiceAccuracy(InContextLearningMetric): + r"""Computes accuracy for In-context learning multiple choice tasks. + + ICL MC tasks consists of a series of questions with some number of possible choices (only one of which can be correct). + At inference time each possible choice is given to the model as a separate input and the one for which the model assigns + the lowest perplexity to the choice is considered the model's choice. The model is correct if it "chooses" the right answer. + + Context: `The dog is->fuzzy\nthe water is->hot\nthe tree is->` + Continuation: `green` + + Adds metric state variables: + correct (float): The number of instances where the prediction masked the target. + total (float): The number of total instances that were predicted. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state('correct', + default=torch.tensor(0.0), + dist_reduce_fx='sum') + self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.metric_result_dict = { + 'context': [], + 'correct_choice': [], + 'correct_choice_idx': [], + 'selected_choice': [], + 'selected_choice_idx': [], + 'all_choices': [], + 'result': [], + } + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + + perplexities = [] + for batch_idx, cont_idx in enumerate(batch['continuation_indices']): + # continuation indices refer to indices in the original input's token space + cont_tok_logits = outputs[batch_idx].index_select(dim=0, + index=cont_idx - + 1) + # labels have been shifted left by one index, so the cont_idx needs to be shifted as well. + cont_tok_targ = labels[batch_idx].index_select(dim=0, + index=cont_idx - 1) + cross_entropy = F.cross_entropy(cont_tok_logits, cont_tok_targ) + perplexity = torch.exp(cross_entropy) + perplexities.append(perplexity) + + metric_result_dict = copy.deepcopy(self.metric_result_dict) + for (start, end), gold_idx in zip(batch['choice_groupings'], + batch['gold_indices']): + subset = perplexities[start:end] + idx_min = subset.index(min(subset)) + + if idx_min == gold_idx: + self.correct += torch.tensor(1.0) + metric_result_dict['result'].append(1) + else: + metric_result_dict['result'].append(0) + + question = batch['input_ids'][ + start][:batch['continuation_indices'][start][0]] + + correct_choice = batch['input_ids'][start:end][gold_idx][ + batch['continuation_indices'][start:end][gold_idx][0]: + batch['continuation_indices'][start:end][gold_idx][-1] + 1] + selected_choice = batch['input_ids'][start:end][idx_min][ + batch['continuation_indices'][start:end][idx_min][0]: + batch['continuation_indices'][start:end][idx_min][-1] + 1] + metric_result_dict['context'].append(question) + metric_result_dict['correct_choice'].append(correct_choice) + metric_result_dict['correct_choice_idx'].append(gold_idx) + metric_result_dict['selected_choice'].append(selected_choice) + metric_result_dict['selected_choice_idx'].append(idx_min) + all_choices = batch['input_ids'][start:end] + # Unpads the choices. Necessary in case different choices have different token lengths. + if 'attention_mask' in batch: + all_choices_list = [ + choice[batch['attention_mask'][i]] + for i, choice in enumerate(all_choices) + ] + metric_result_dict['all_choices'].append(all_choices_list) + + self.total += torch.tensor(1.0) + + # Don't return all_choices if we didn't fill it up (i.e. didn't use causal lms) + if metric_result_dict['all_choices'] == []: + metric_result_dict.pop('all_choices') + + return metric_result_dict + + def compute(self): + assert isinstance(self.correct, Tensor) + assert isinstance(self.total, Tensor) + return self.correct.float() / self.total + + +class InContextLearningCodeEvalAccuracy(InContextLearningMetric): + r"""Computes accuracy for In-context learning (ICL) code evaluation tasks. + + ICL code eval tasks consist of some number of example code eval tasks (referred to as the 'context'), followed by a test task where the model must + complete the code, where we term the code completion a 'continuation'. + + In each case, the model constructs a given number of continuations (termed pass@K for K continuations), and each continuation is run against a set of test cases. The model is considered + correct if at least one of the proposed continuations passes all the test cases. + + Runs on AWS Lambdas by default. + + Adds metric state variables: + correct (float): The number of instances where the predictions passed all the test cases. + total (float): The number of total instances that were predicted. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self._initialized = False + self.dataset_size = 0 + self.pass_at_k = [] + self.num_generations = 0 + self.eval_device = os.environ.get('CODE_EVAL_DEVICE', None) + if self.eval_device is not None: + self.eval_device = self.eval_device.upper() + self.metric_result_dict = { + 'context': [], + 'output': [], + 'result': [], + 'sample_id': [] + } + + def get_client(self) -> EvalClient: + """Returns a client for the appropriate remote platform.""" + client = None + if self.eval_device == 'LOCAL': + warnings.warn( + 'Running code eval locally may be insecure. Please set environment variable CODE_EVAL_DEVICE ' + + + 'to LAMBDA to run on remote. To use Lambdas, spin up your instance that checks code, set the URL as ' + + 'CODE_EVAL_URL and the API key as CODE_EVAL_APIKEY.') + log.debug('Running code eval locally.') + client = LocalEvalClient() + elif self.eval_device == 'LAMBDA': + client = LambdaEvalClient() + elif self.eval_device == 'MOSAICML': + client = MosaicMLLambdaEvalClient() + elif self.eval_device is None: + raise ValueError( + 'Attempting to use InContextLearningCodeEvalAccuracy but environment ' + + + 'variable `CODE_EVAL_DEVICE` is not set. Please set it to `CODE_EVAL_DEVICE` ' + + + 'to one of `LOCAL` (for unsafe local eval), `LAMBDA` (for AWS lambda ' + + 'evaluation), or `MOSAICML` (for lambda eval through MAPI).') + else: + raise ValueError( + 'Environment variable `CODE_EVAL_DEVICE` must be one of `LOCAL`, ' + + f'`LAMBDA`, or `MOSAICML` but got {self.eval_device}.') + + return client + + def estimator(self, n: int, c: int, k: int) -> float: + """Computes the pass@k metric. + + Given the number of generated samples, n, the number of correct samples, c, and the k of interest, + this function calculates pass@k as 1 - comb(n - c, k) / comb(n, k) as per the definition of + pass@k in the HumanEval paper (https://arxiv.org/abs/2107.03374) and it's associated implementation: + https://github.com/openai/human-eval. + """ + if n - c < k: + return 1.0 + return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1))) + + def _initialize_state(self, batch: dict[str, Any]): + device = batch['input_ids'].device + self.dataset_size = batch['dataset_size'] + self.pass_at_k = batch['pass_at_k'] + self.num_generations = batch['generations_per_sample'] + + # We need to defer the accumulator initialization because it depends on dataset size + self.add_state('correct', + default=torch.zeros(self.dataset_size, device=device), + dist_reduce_fx='sum') + self.add_state('total', + default=torch.zeros(self.dataset_size, device=device), + dist_reduce_fx='sum') + dist.barrier() + self._initialized = True + + def update(self, batch: Dict[str, Any], outputs: List[str], + labels: List[str]): + """Updates the pass@k accuracy of code generation. + + Given a batch of prompts, test cases, and code generations, evaluates the code generations + against the test cases and augments the pass@k accuracy of the batch to the values so far. + + Args: + batch (Dict[str, Any]): A batch of data produced by the InContextLearningCodeEvalDataset, with + the prompt, test cases, and entry points. This will be a dictionary that must have the following + arguments: + { + 'prompts': List[str], + 'test_inputs': List[List[str]], + 'test_outputs': List[List[str]], + 'entry_points': List[str], + 'languages': List[str], + 'generation_kwargs': Dict[str, Any] + } + outputs (List[str]): A list of code generations in the format of HF generate with beam search, + which is the a list of strings in groups of beam_size e.g. for beam size 2 and batch size 2, the list + will be of the format [prompt 1 gen 1, prompt 1 gen 2, prompt 2 gen 1, prompt 2 gen 2] + labels (List[str]): A list of the correct code generations, for compatibility with existing HF generate + functionalities. This is not used. + """ + if not self._initialized: + self._initialize_state(batch) + + del labels # never used + client = self.get_client() + + metric_result_dict = copy.deepcopy(self.metric_result_dict) + for sample_id, code_gen, sample_prompt, test_inputs, test_outputs, entry_point, language in zip( + batch['sample_id'], outputs, batch['prompts'], + batch['test_inputs'], batch['test_outputs'], + batch['entry_points'], batch['languages']): + + idx = sample_id + self.total[idx] += 1.0 + metric_result_dict['sample_id'].append(sample_id) + + code_gen = re.split( + r'\n[A-Za-z0-9#`]', + code_gen)[0] # remove everything after function ends + final_code = sample_prompt + code_gen # combine prompt with the code generation + metric_result_dict['context'].append(sample_prompt) + metric_result_dict['output'].append(code_gen) + + test_results = [] + for test_input, test_output in zip(test_inputs, test_outputs): + payload = { + 'code': final_code, + 'input': test_input, + 'output': test_output, + 'entry_point': entry_point, + 'language': language, + } + + result = client.invoke([[[payload]]])[0][0][0] + test_results.append(result) + + if all(test_results): + self.correct[idx] += 1.0 + metric_result_dict['result'].append(1) + else: + metric_result_dict['result'].append(0) + + client.close() # pyright: ignore [reportOptionalMemberAccess] + return metric_result_dict + + def compute(self): + assert isinstance(self.correct, Tensor) + assert isinstance(self.total, Tensor) + complete = self.total == self.num_generations # so that eval subset batches can be used + + if complete.sum() < (self.total != 0).sum(): + warnings.warn( + 'Some samples in the dataset have less than the expected number of generations. ' + + + 'This is expected if you are using a subset of the dataset for evaluation.' + ) + + if (self.correct > self.total).any().item(): + raise ValueError( + 'Internal error some samples have more correct than total generations. This should not happen.' + ) + + results = {} + n = self.num_generations + + for k in self.pass_at_k: + pass_at_k = sum([ + self.estimator(n, int(c.item()), k) + for c in self.correct[complete] + ]) / complete.sum().item() + results[f'pass@{k}'] = torch.tensor(pass_at_k) + + if len(results) == 1: # backwards compatibility + return list(results.values())[0] + + return results + + +class InContextLearningExpectedCalibrationError(InContextLearningMetric): + """Generic class for Expected Calibration Error (ECE). + + Citation: https://arxiv.org/pdf/1706.04599.pdf + + Expected calibration error is calculated by dividing predictions into buckets based on the model's confidence (a probability value between 0 and 1). + We then calculate the accuracy within each bucket and calculate the average gap between confidence and accuracy + across buckets, weighted by the number of samples in each bucket. + + Each task must implement its own definition of "confidence" to be computed via the `update` method. + + Adds metric state variables: + bucket_totals (float): The number of instances where the prediction masked the target per bucket. + bucket_correct (float): The number of total instances that were predicted per bucket. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + n_buckets (int): Number of distinct buckets to split the confidence distribution into + """ + + def __init__(self, dist_sync_on_step: bool = False, n_buckets: int = 10): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.n_buckets = n_buckets + if n_buckets < 1: + raise Exception('`n_buckets`') + self.add_state('bucket_totals', + default=torch.zeros(n_buckets), + dist_reduce_fx='sum') + self.add_state('bucket_correct', + default=torch.zeros(n_buckets), + dist_reduce_fx='sum') + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + pass + + def compute(self): + assert isinstance(self.bucket_correct, Tensor) + assert isinstance(self.bucket_totals, Tensor) + + result = torch.tensor(0.0, device=self.bucket_correct.device) + total_obs = torch.sum(self.bucket_totals) + for i in range(self.n_buckets): + if self.bucket_totals[i] == 0: + continue + + acc_bucket_i = self.bucket_correct[i] / self.bucket_totals[i] + upper_bound = (i + 1) / self.n_buckets + lower_bound = i / self.n_buckets + conf_bucket_i = torch.tensor((upper_bound + lower_bound) / 2, + device=self.bucket_correct.device) + result += (self.bucket_totals[i] / + total_obs) * torch.abs(acc_bucket_i - conf_bucket_i) + return result + + +class InContextLearningMCExpectedCalibrationError( + InContextLearningExpectedCalibrationError): + r"""Computes Expected Calibration Error (ECE) for In-context learning (ICL) + + multiple choice (MC) tasks. (source: https://arxiv.org/abs/2012.00955). + + For MC tasks, the model confidence is defined as the softmax of average per-token probability assigned to the top question choice. + + See `InContextLearningExpectedCalibrationError` for more info. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + + outputs = torch.softmax(outputs, dim=2) + probabilities = [] + for batch_idx, cont_idx in enumerate(batch['continuation_indices']): + cont_tok_logits = outputs[batch_idx].index_select(dim=0, + index=cont_idx - + 1) + cont_tok_targ = labels[batch_idx].index_select(dim=0, + index=cont_idx - 1) + probability = cont_tok_logits.index_select( + dim=1, index=cont_tok_targ).diagonal().mean() + probabilities.append(probability) + + for (start, end), gold_idx in zip(batch['choice_groupings'], + batch['gold_indices']): + subset = probabilities[start:end] + idx_max = subset.index(max(subset)) + confidence = torch.tensor(subset).max() / torch.tensor(subset).sum() + + assert confidence >= 0.0 and confidence <= 1.0 + bucket_idx = int(confidence * self.n_buckets) + if bucket_idx == self.n_buckets: + bucket_idx -= 1 + + if idx_max == gold_idx: + self.bucket_correct[ + bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] + + self.bucket_totals[ + bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] + + +class InContextLearningLMExpectedCalibrationError( + InContextLearningExpectedCalibrationError): + r"""Computes Expected Calibration Error (ECE) for In-context learning (ICL) + + language modeling (LM) tasks. (cite: https://arxiv.org/pdf/1706.04599.pdf). + + For LM tasks, the model confidence is defined as the minimum probability assigned to all tokens in the continuation. + + See `InContextLearningExpectedCalibrationError` for more info. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): + + outputs = torch.softmax(outputs, dim=2) + for batch_idx, cont_idx in enumerate(batch['continuation_indices']): + cont_tok_logits = outputs[batch_idx].index_select(dim=0, + index=cont_idx - + 1) + cont_tok_pred = cont_tok_logits.argmax(dim=-1) + confidence = cont_tok_logits.max(dim=-1).values.min() + cont_tok_targ = labels[batch_idx].index_select(dim=0, + index=cont_idx - 1) + assert confidence >= 0.0 and confidence <= 1.0 + bucket_idx = int(confidence * self.n_buckets) + if bucket_idx == self.n_buckets: + bucket_idx -= 1 + + if (cont_tok_pred == cont_tok_targ).all(): + self.bucket_correct[ + bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] + + self.bucket_totals[ + bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 3c0a7ebd6e..24593144aa 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -26,6 +26,34 @@ entry_points=True, description=_fc_description) +_ffns_description = ( + 'The ffns registry is used to register functions that build ffn layers.' + + 'See ffn.py for examples.') +ffns = create_registry('llmfoundry', + 'ffns', + generic_type=Callable, + entry_points=True, + description=_ffns_description) + +_ffns_with_norm_description = ( + 'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.' + + 'See ffn.py for examples.') +ffns_with_norm = create_registry('llmfoundry', + 'ffns_with_norm', + generic_type=Callable, + entry_points=True, + description=_ffns_with_norm_description) + +_ffns_with_megablocks_description = ( + 'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.' + + 'See ffn.py for examples.') +ffns_with_megablocks = create_registry( + 'llmfoundry', + 'ffns_with_megablocks', + generic_type=Callable, + entry_points=True, + description=_ffns_with_megablocks_description) + _attention_classes_description = ( 'The attention_classes registry is used to register classes that implement attention layers. See ' + 'attention.py for expected constructor signature.') @@ -45,8 +73,33 @@ entry_points=True, description=_attention_implementations_description) +_param_init_fns_description = ( + 'The param_init_fns registry is used to register functions that initialize parameters.' + + + 'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.' +) +param_init_fns = create_registry('llmfoundry', + 'param_init_fns', + generic_type=Callable[..., None], + entry_points=True, + description=_param_init_fns_description) + +_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules. +These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents. +They should take in the module, init_div_is_residual, and div_is_residual arguments.""" +module_init_fns = create_registry('llmfoundry', + 'module_init_fns', + generic_type=Callable[..., bool], + entry_points=True, + description=_module_init_fns_description) + __all__ = [ 'norms', + 'param_init_fns', + 'module_init_fns', + 'ffns', + 'ffns_with_norm', + 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', 'fcs', diff --git a/llmfoundry/metrics/__init__.py b/llmfoundry/metrics/__init__.py index 6c71a3ea08..8ca2db5bd2 100644 --- a/llmfoundry/metrics/__init__.py +++ b/llmfoundry/metrics/__init__.py @@ -1,14 +1,15 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from composer.metrics import (InContextLearningCodeEvalAccuracy, - InContextLearningLMAccuracy, - InContextLearningLMExpectedCalibrationError, - InContextLearningMCExpectedCalibrationError, - InContextLearningMultipleChoiceAccuracy, - InContextLearningQAAccuracy, MaskedAccuracy) -from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity +from composer.metrics import (LanguageCrossEntropy, LanguagePerplexity, + MaskedAccuracy) +from llmfoundry.eval.metrics import ( + InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningLMExpectedCalibrationError, + InContextLearningMCExpectedCalibrationError, + InContextLearningMultipleChoiceAccuracy) from llmfoundry.metrics.token_acc import TokenAccuracy from llmfoundry.registry import metrics @@ -19,7 +20,8 @@ metrics.register('mc_expected_calibration_error', func=InContextLearningMCExpectedCalibrationError) metrics.register('mc_accuracy', func=InContextLearningMultipleChoiceAccuracy) -metrics.register('qa_accuracy', func=InContextLearningQAAccuracy) +metrics.register('qa_accuracy', + func=InContextLearningGenerationExactMatchAccuracy) metrics.register('code_eval_accuracy', func=InContextLearningCodeEvalAccuracy) metrics.register('language_cross_entropy', func=LanguageCrossEntropy) metrics.register('language_perplexity', func=LanguagePerplexity) @@ -54,11 +56,8 @@ 'InContextLearningLMExpectedCalibrationError', 'InContextLearningMCExpectedCalibrationError', 'InContextLearningMultipleChoiceAccuracy', - 'InContextLearningQAAccuracy', + 'InContextLearningGenerationExactMatchAccuracy', 'InContextLearningCodeEvalAccuracy', - 'LanguageCrossEntropy', - 'LanguagePerplexity', - 'MaskedAccuracy', 'DEFAULT_CAUSAL_LM_TRAIN_METRICS', 'DEFAULT_CAUSAL_LM_EVAL_METRICS', 'DEFAULT_ENC_DEC_METRICS', diff --git a/llmfoundry/models/hf/__init__.py b/llmfoundry/models/hf/__init__.py index 3c35080d6e..2ed7b2d6e1 100644 --- a/llmfoundry/models/hf/__init__.py +++ b/llmfoundry/models/hf/__init__.py @@ -6,6 +6,7 @@ prepare_hf_enc_dec_model_for_fsdp, prepare_hf_model_for_fsdp) from llmfoundry.models.hf.hf_t5 import ComposerHFT5 +from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP __all__ = [ 'ComposerHFCausalLM', @@ -13,4 +14,5 @@ 'prepare_hf_causal_lm_model_for_fsdp', 'prepare_hf_enc_dec_model_for_fsdp', 'prepare_hf_model_for_fsdp', + 'HuggingFaceModelWithFSDP', ] diff --git a/llmfoundry/models/hf/hf_fsdp.py b/llmfoundry/models/hf/hf_fsdp.py index ddeb938c8d..87bffc3af8 100644 --- a/llmfoundry/models/hf/hf_fsdp.py +++ b/llmfoundry/models/hf/hf_fsdp.py @@ -14,6 +14,12 @@ if TYPE_CHECKING: from peft import PeftModel +__all__ = [ + 'prepare_hf_model_for_fsdp', + 'prepare_hf_causal_lm_model_for_fsdp', + 'prepare_hf_enc_dec_model_for_fsdp', +] + # helper functions def rhasattr(obj: Any, attr: str) -> bool: @@ -143,8 +149,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, causal_base_model = hf_get_causal_base_model(model) # OPT and olmo have an extra layer of wrapping, so special case here - if isinstance(causal_base_model, - OPTDecoder) or model.config.model_type == 'olmo': + if isinstance(causal_base_model, OPTDecoder): underlying_model = maybe_get_underlying_model(model) underlying_model.model._fsdp_wrap = False model_block = hf_get_hidden_layers(causal_base_model) @@ -251,7 +256,7 @@ def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, if encoder_block_type == decoder_block_type: return - # need to wrap encoder blocks separately for ProhpetNet and Marian + # need to wrap encoder blocks separately for ProphetNet and Marian model.fsdp_wrap_fn = lambda module: isinstance(module, encoder_block_type) model.activation_checkpointing_fn = lambda module: isinstance( module, encoder_block_type) diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py index 2ba88d390c..4b7be9ee08 100644 --- a/llmfoundry/models/hf/model_wrapper.py +++ b/llmfoundry/models/hf/model_wrapper.py @@ -19,6 +19,8 @@ if TYPE_CHECKING: from peft import PeftConfig +__all__ = ['HuggingFaceModelWithFSDP'] + # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 diff --git a/llmfoundry/models/inference_api_wrapper/__init__.py b/llmfoundry/models/inference_api_wrapper/__init__.py index 9bb2ece2b2..905abf2fa1 100644 --- a/llmfoundry/models/inference_api_wrapper/__init__.py +++ b/llmfoundry/models/inference_api_wrapper/__init__.py @@ -2,16 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.inference_api_wrapper.fmapi import ( - FMAPICasualLMEvalWrapper, FMAPIChatAPIEvalWrapper) + FMAPICasualLMEvalWrapper, FMAPIChatAPIEvalWrapper, FMAPIEvalInterface) from llmfoundry.models.inference_api_wrapper.interface import \ InferenceAPIEvalWrapper from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( - OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper) + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface) __all__ = [ 'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper', + 'OpenAIEvalInterface', 'InferenceAPIEvalWrapper', 'FMAPICasualLMEvalWrapper', 'FMAPIChatAPIEvalWrapper', + 'FMAPIEvalInterface', ] diff --git a/llmfoundry/models/inference_api_wrapper/fmapi.py b/llmfoundry/models/inference_api_wrapper/fmapi.py index 58ea302ace..d0c987304a 100644 --- a/llmfoundry/models/inference_api_wrapper/fmapi.py +++ b/llmfoundry/models/inference_api_wrapper/fmapi.py @@ -15,6 +15,7 @@ __all__ = [ 'FMAPICasualLMEvalWrapper', 'FMAPIChatAPIEvalWrapper', + 'FMAPIEvalInterface', ] log = logging.getLogger(__name__) diff --git a/llmfoundry/models/inference_api_wrapper/interface.py b/llmfoundry/models/inference_api_wrapper/interface.py index 4c30e7822d..a939d03d68 100644 --- a/llmfoundry/models/inference_api_wrapper/interface.py +++ b/llmfoundry/models/inference_api_wrapper/interface.py @@ -5,14 +5,16 @@ import torch from composer.core.types import Batch -from composer.metrics import InContextLearningMetric from composer.models import ComposerModel from omegaconf import DictConfig from torchmetrics import Metric from transformers import AutoTokenizer +from llmfoundry.eval.metrics import InContextLearningMetric from llmfoundry.metrics import DEFAULT_CAUSAL_LM_EVAL_METRICS +__all__ = ['InferenceAPIEvalWrapper'] + class InferenceAPIEvalWrapper(ComposerModel): diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py index bacf71b8e2..9f2cf3315c 100644 --- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py +++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py @@ -23,6 +23,7 @@ __all__ = [ 'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper', + 'OpenAIEvalInterface', ] if TYPE_CHECKING: diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 5784fcd7e9..e31029024c 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -3,13 +3,18 @@ from llmfoundry.models.layers.attention import ( GroupedQueryAttention, MultiheadAttention, MultiQueryAttention, - attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, - scaled_multihead_dot_product_attention) -from llmfoundry.models.layers.blocks import MPTBlock + attn_bias_shape, build_alibi_bias, build_attn_bias, check_alibi_support, + flash_attn_fn, scaled_multihead_dot_product_attention) +from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding +from llmfoundry.models.layers.dmoe import DroplessMLP, LearnedRouter, dMoE from llmfoundry.models.layers.fc import * -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn -from llmfoundry.models.layers.norm import LPLayerNorm +from llmfoundry.models.layers.ffn import MPTGLU, MPTMLP +from llmfoundry.models.layers.layer_builders import (build_attention_layer, + build_fc, build_ffn, + build_norm) +from llmfoundry.models.layers.norm import (LPLayerNorm, LPRMSNorm, RMSNorm, + TritonRMSNorm, rms_norm) __all__ = [ 'scaled_multihead_dot_product_attention', @@ -20,10 +25,22 @@ 'attn_bias_shape', 'build_attn_bias', 'build_alibi_bias', - 'MPTMLP', + 'check_alibi_support', 'MPTBlock', - 'LPLayerNorm', + 'FusedNormAttentionNorm', 'SharedEmbedding', - 'FFN_CLASS_REGISTRY', + 'dMoE', + 'LearnedRouter', + 'DroplessMLP', + 'MPTMLP', + 'MPTGLU', + 'build_attention_layer', 'build_ffn', + 'build_fc', + 'build_norm', + 'LPLayerNorm', + 'LPRMSNorm', + 'RMSNorm', + 'TritonRMSNorm', + 'rms_norm', ] diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 6614d5d161..d4a34eecaa 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -18,6 +18,18 @@ attention_implementations) from llmfoundry.models.layers.layer_builders import build_fc, build_norm +__all__ = [ + 'scaled_multihead_dot_product_attention', + 'flash_attn_fn', + 'MultiheadAttention', + 'MultiQueryAttention', + 'GroupedQueryAttention', + 'attn_bias_shape', + 'build_attn_bias', + 'build_alibi_bias', + 'check_alibi_support', +] + def is_flash_v2_installed(v2_version: str = '2.0.0'): assert version.parse(v2_version) >= version.parse('2.0.0') diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 1ad9ec954f..d56c4753af 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -8,15 +8,20 @@ import torch import torch.nn as nn -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn +from llmfoundry.layers_registry import ffns_with_norm from llmfoundry.models.layers.layer_builders import (build_attention_layer, - build_norm) + build_ffn, build_norm) try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip except: unpad_input, pad_input = None, None +__all__ = [ + 'MPTBlock', + 'FusedNormAttentionNorm', +] + attn_config_defaults: Dict = { 'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, @@ -73,12 +78,15 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() + ffn_type = ffn_config['ffn_type'] + ffn_has_norm = ffn_type in ffns_with_norm + if self.fuse_norm_attn_norm: self.norm_attn_norm = FusedNormAttentionNorm( d_model=d_model, n_heads=n_heads, attn_config=attn_config, - ffn_config=ffn_config, + ffn_has_norm=ffn_has_norm, fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, @@ -116,8 +124,7 @@ def __init__( }, ) self.norm_2 = None - if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], - '_has_norm', False): + if not ffn_has_norm: self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, @@ -125,12 +132,14 @@ def __init__( ) self.ffn = build_ffn( + name=ffn_type, d_model=d_model, expansion_ratio=expansion_ratio, device=device, bias=not no_bias, - **ffn_config, + ffn_kwargs=ffn_config, ) + self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) self.use_pad_tok_in_ffn = use_pad_tok_in_ffn @@ -198,7 +207,7 @@ def __init__( d_model: int, n_heads: int, attn_config: Optional[Dict] = None, - ffn_config: Optional[Dict] = None, + ffn_has_norm: bool = False, fc_type: str = 'torch', resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -208,7 +217,6 @@ def __init__( ): super().__init__() assert attn_config is not None - assert ffn_config is not None assert isinstance(attn_config['attn_type'], str) # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs @@ -238,9 +246,9 @@ def __init__( **attn_config_subset_for_attn_class }, ) + self.norm_2 = None - if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', - False): + if not ffn_has_norm: self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, diff --git a/llmfoundry/models/layers/custom_embedding.py b/llmfoundry/models/layers/custom_embedding.py index 20a2be3a55..fba823a4f7 100644 --- a/llmfoundry/models/layers/custom_embedding.py +++ b/llmfoundry/models/layers/custom_embedding.py @@ -5,6 +5,8 @@ import torch.nn.functional as F from torch import Tensor +__all__ = ['SharedEmbedding'] + class SharedEmbedding(nn.Embedding): diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py index 1a981b61c5..f2b255294c 100644 --- a/llmfoundry/models/layers/dmoe.py +++ b/llmfoundry/models/layers/dmoe.py @@ -1,10 +1,18 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Callable, Optional import torch +__all__ = [ + 'dMoE', + 'LearnedRouter', + 'MLP', + 'GLU', + 'DroplessMLP', +] + # Add option to route tokens uniformly across experts. We use # a custom autograd op router backwards is still run for benchmarking. @@ -24,7 +32,8 @@ class LearnedRouter(torch.nn.Module): def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int, moe_jitter_eps: float, moe_normalize_expert_weights: bool, - uniform_expert_assignment: bool, device: torch.device) -> None: + uniform_expert_assignment: bool, + device: Optional[torch.device]) -> None: super().__init__() self.hidden_size: int = hidden_size self.moe_num_experts: int = moe_num_experts @@ -84,7 +93,7 @@ def __init__( ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, - device: torch.device, + device: Optional[torch.device], ) -> None: super().__init__() @@ -117,9 +126,14 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: class GLU(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, - moe_num_experts: int, activation_fn: Callable, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + activation_fn: Callable, + device: Optional[torch.device], + ): super().__init__() self.hidden_size = hidden_size self.ffn_hidden_size = ffn_hidden_size @@ -157,9 +171,16 @@ def forward(self, x: torch.Tensor, expert_idx: torch.Tensor): class DroplessMLP(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str, - moe_num_experts: int, activation_fn: Callable, bias: bool, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + mlp_type: str, + moe_num_experts: int, + activation_fn: Callable, + bias: bool, + device: Optional[torch.device], + ): super().__init__() self.moe_num_experts = moe_num_experts @@ -209,12 +230,20 @@ def forward(self, x: torch.Tensor, scores: torch.Tensor, class dMoE(torch.nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, - moe_num_experts: int, moe_top_k: int, mlp_type: str, - activation_fn: Callable, moe_jitter_eps: float, - moe_normalize_expert_weights: bool, - uniform_expert_assignment: bool, bias: bool, - device: torch.device): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + moe_top_k: int, + mlp_type: str, + activation_fn: Callable, + moe_jitter_eps: float, + moe_normalize_expert_weights: bool, + uniform_expert_assignment: bool, + bias: bool, + device: Optional[torch.device], + ): super().__init__() # Token router. diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index f0b499875a..c64e87cb9a 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -10,8 +10,11 @@ import torch import torch.nn as nn +from torch.distributed import ProcessGroup from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard +from llmfoundry.layers_registry import (ffns, ffns_with_megablocks, + ffns_with_norm) from llmfoundry.models.layers.dmoe import dMoE from llmfoundry.models.layers.layer_builders import build_fc @@ -29,6 +32,17 @@ log = logging.getLogger(__name__) +__all__ = [ + 'MPTMLP', + 'MPTGLU', + 'build_mptglu', + 'build_mptmlp', + 'build_te_ln_mlp', + 'build_torch_dmoe', + 'build_mb_moe', + 'build_mb_dmoe', +] + _FFN_ACT_FN_DEFAULT = { 'name': 'gelu', 'approximate': 'none', @@ -172,25 +186,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) -FFN_CLASS_REGISTRY = { - 'mptmlp': MPTMLP, - 'mptglu': MPTGLU, - 'torch_dmoe': dMoE, -} - -if is_te_imported: - import transformer_engine.pytorch as te - te.LayerNormMLP._has_norm = True - FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP +def build_mptglu( + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, +) -> nn.Module: + return MPTGLU( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + act_fn=resolve_ffn_act_fn(ffn_act_fn), + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) -if is_megablocks_imported: - import megablocks - FFN_CLASS_REGISTRY['mb_moe'] = megablocks.layers.moe.MoE - FFN_CLASS_REGISTRY['mb_dmoe'] = megablocks.layers.dmoe.dMoE +def build_mptmlp( + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, +) -> nn.Module: + return MPTMLP( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + act_fn=resolve_ffn_act_fn(ffn_act_fn), + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) -def build_ffn( +def build_te_ln_mlp( d_model: int, expansion_ratio: Union[int, float], fc_type: str = 'torch', @@ -200,131 +236,225 @@ def build_ffn( bias: bool = True, **kwargs: Any, ) -> nn.Module: - ffn_type = kwargs.pop('ffn_type') - if ffn_type in ['mptmlp', 'mptglu']: - if len(kwargs) > 0: - raise ValueError( - f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}' - ) - return FFN_CLASS_REGISTRY[ffn_type]( - d_model=d_model, - expansion_ratio=expansion_ratio, - fc_type=fc_type, - act_fn=resolve_ffn_act_fn(ffn_act_fn), - ffn_hidden_size=ffn_hidden_size, - device=device, - bias=bias, + assert te is not None + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) + if ffn_act_fn is not None: + raise ValueError( + f'Transformer Engine block does not support custom activation functions.' ) - elif ffn_type == 'te_ln_mlp': - if te is None: - raise RuntimeError( - 'Requirements for TransformerEngine not installed; see install instructions in `README.md`.' - ) - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) - if ffn_act_fn is not None: - raise ValueError( - f'Transformer Engine block does not support custom activation functions.' - ) - return te.LayerNormMLP( - hidden_size=d_model, - ffn_hidden_size=ffn_hidden_size, - bias=bias, - **kwargs, - ) - elif ffn_type in ('mb_moe', 'mb_dmoe'): - if megablocks is None: - raise RuntimeError( - 'Requirements for megablocks not installed; see install instructions in `README.md`.' - ) - args = kwargs['args'] - args.bias = bias - args.hidden_size = d_model - args.device = device + return te.LayerNormMLP( + hidden_size=d_model, + ffn_hidden_size=ffn_hidden_size, + bias=bias, + **kwargs, + ) - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) - args.ffn_hidden_size = ffn_hidden_size - - if ffn_act_fn is not None: - args.activation_fn = resolve_ffn_act_fn(ffn_act_fn) - - moe_world_size = 1 - expert_parallel_group = args.expert_parallel_group - if expert_parallel_group is not None: - moe_world_size = expert_parallel_group.size() - if kwargs.get('moe_world_size') != moe_world_size: - raise RuntimeError( - f'MoE expert_parallel_group configured with incorrect world size.' - ) - if ffn_type == 'mb_moe': - ffn = megablocks.layers.moe.MoE(args) - - # Fused initialization setup - # For param_init_fn, enables shape based init of stacked layers - ffn.experts.mlp._stack_dim = 0 - elif ffn_type == 'mb_dmoe': - ffn = megablocks.layers.dmoe.dMoE(args) - - # Fused initialization setup - # For param_init_fn, enables shape based init of fused layers - n_exp = min(1, args.moe_num_experts // moe_world_size) - ffn.experts.mlp._fused = (0, [ - (n + 1) * args.ffn_hidden_size for n in range(n_exp - 1) - ]) +def build_torch_dmoe( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, + **kwargs: Any, +) -> nn.Module: + moe_num_experts = kwargs.pop('moe_num_experts') + moe_top_k = kwargs.pop('moe_top_k') + mlp_type = kwargs.pop('mlp_type') + moe_jitter_eps = kwargs.pop('moe_jitter_eps') + moe_normalize_expert_weights = kwargs.pop('moe_normalize_expert_weights') + uniform_expert_assignment = kwargs.pop('uniform_expert_assignment') + + fc_type = kwargs.pop('fc_type', 'torch') + del fc_type # Unused + + if len(kwargs) > 0: + raise ValueError(f'Invalid arguments to torch dmoe: {kwargs}.') + + return dMoE( + hidden_size=d_model, + ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size), + moe_num_experts=moe_num_experts, + moe_top_k=moe_top_k, + mlp_type=mlp_type, + bias=bias, + moe_jitter_eps=moe_jitter_eps, + activation_fn=resolve_ffn_act_fn(ffn_act_fn), + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + device=torch.device(device) if device is not None else None, + ) + + +def _mb_setup_args( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int], + ffn_act_fn: Optional[dict], + device: Optional[str], + bias: bool, + kwargs: dict[str, Any], +) -> tuple['megablocks.layers.arguments.Arguments', int, ProcessGroup]: + if megablocks is None: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' + ) + args = kwargs['args'] + args.bias = bias + args.hidden_size = d_model + args.device = device + + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) + args.ffn_hidden_size = ffn_hidden_size + + if ffn_act_fn is not None: + args.activation_fn = resolve_ffn_act_fn(ffn_act_fn) + + moe_world_size = 1 + expert_parallel_group = args.expert_parallel_group + if expert_parallel_group is not None: + moe_world_size = expert_parallel_group.size() + if kwargs.get('moe_world_size') != moe_world_size: + raise RuntimeError( + f'MoE expert_parallel_group configured with incorrect world size.') + + return args, moe_world_size, expert_parallel_group + + +def _patch_ffn_mb( + ffn: nn.Module, + moe_world_size: int, + expert_parallel_group: ProcessGroup, + device_mesh: DeviceMesh, + args: 'megablocks.layers.arguments.Arguments', +): + # Attach args to MLP directly for use in param_init_fn + ffn.experts.mlp.hidden_size = args.ffn_hidden_size + ffn.experts.mlp.expert_parallel_group = expert_parallel_group + ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group + + if moe_world_size > 1: + expert_mesh = device_mesh['expert_parallel'] + expert_placements: List[Placement] = [Shard(0)] + # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() + dtensorified_params = [ + (name, + dtensorify_param(param=parameter, + mesh=expert_mesh, + placements=expert_placements)) + for name, parameter in ffn.experts.mlp.named_parameters() + ] + for name, dtensorified_param in dtensorified_params: + ffn.experts.mlp.register_parameter(name, dtensorified_param) + + if device_mesh.mesh.ndim == 2: + submesh = device_mesh['weight_parallel'] + elif device_mesh.mesh.ndim == 3: + raise RuntimeError(f'HSDP + MoE is not supported.') else: - raise RuntimeError(f'Invalid ffn_type option: {ffn_type}.') - - # Attach args to MLP directly for use in param_init_fn - ffn.experts.mlp.hidden_size = args.ffn_hidden_size - ffn.experts.mlp.expert_parallel_group = expert_parallel_group - ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group - - if moe_world_size > 1: - device_mesh = kwargs['device_mesh'] - - expert_mesh = device_mesh['expert_parallel'] - expert_placements: List[Placement] = [Shard(0)] - # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() - dtensorified_params = [ - (name, - dtensorify_param(param=parameter, - mesh=expert_mesh, - placements=expert_placements)) - for name, parameter in ffn.experts.mlp.named_parameters() - ] - for name, dtensorified_param in dtensorified_params: - ffn.experts.mlp.register_parameter(name, dtensorified_param) - - device_mesh = kwargs['device_mesh'] - if device_mesh.mesh.ndim == 2: - submesh = device_mesh['weight_parallel'] - elif device_mesh.mesh.ndim == 3: - raise RuntimeError(f'HSDP + MoE is not supported.') - else: - raise ValueError( - f'{device_mesh.mesh.ndim=} not supported for MoE.') - - ffn.experts._fsdp_kwargs_dict = { - 'device_mesh': submesh, - } - return ffn - elif ffn_type == 'torch_dmoe': - return dMoE( - hidden_size=d_model, - ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size), - moe_num_experts=kwargs.pop('moe_num_experts'), - moe_top_k=kwargs.pop('moe_top_k'), - mlp_type=kwargs.pop('mlp_type'), - bias=bias, - moe_jitter_eps=kwargs.pop('moe_jitter_eps'), - activation_fn=resolve_ffn_act_fn(ffn_act_fn), - moe_normalize_expert_weights=kwargs.pop( - 'moe_normalize_expert_weights'), - uniform_expert_assignment=kwargs.pop('uniform_expert_assignment'), - device=device, # pyright: ignore[reportGeneralTypeIssues] + raise ValueError(f'{device_mesh.mesh.ndim=} not supported for MoE.') + + ffn.experts._fsdp_kwargs_dict = { + 'device_mesh': submesh, + } + + +def build_mb_moe( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, + **kwargs: Any, +) -> nn.Module: + if not is_megablocks_imported: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' + ) + + args, moe_world_size, expert_parallel_group = _mb_setup_args( + d_model=d_model, + expansion_ratio=expansion_ratio, + ffn_hidden_size=ffn_hidden_size, + ffn_act_fn=ffn_act_fn, + device=device, + bias=bias, + kwargs=kwargs, + ) + + ffn = megablocks.layers.moe.MoE(args) + + # Fused initialization setup + # For param_init_fn, enables shape based init of stacked layers + ffn.experts.mlp._stack_dim = 0 + + _patch_ffn_mb( + ffn=ffn, + moe_world_size=moe_world_size, + expert_parallel_group=expert_parallel_group, + device_mesh=kwargs['device_mesh'], + args=args, + ) + + return ffn + + +def build_mb_dmoe( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, + ffn_act_fn: Optional[dict] = None, + device: Optional[str] = None, + bias: bool = True, + **kwargs: Any, +) -> nn.Module: + if not is_megablocks_imported: + raise RuntimeError( + 'Requirements for megablocks not installed; see install instructions in `README.md`.' ) - raise ValueError(f'{ffn_type=} not recognized.') + args, moe_world_size, expert_parallel_group = _mb_setup_args( + d_model=d_model, + expansion_ratio=expansion_ratio, + ffn_hidden_size=ffn_hidden_size, + ffn_act_fn=ffn_act_fn, + device=device, + bias=bias, + kwargs=kwargs, + ) + + ffn = megablocks.layers.dmoe.dMoE(args) + + # Fused initialization setup + # For param_init_fn, enables shape based init of fused layers + n_exp = min(1, args.moe_num_experts // moe_world_size) + ffn.experts.mlp._fused = (0, [ + (n + 1) * args.ffn_hidden_size for n in range(n_exp - 1) + ]) + + _patch_ffn_mb( + ffn=ffn, + moe_world_size=moe_world_size, + expert_parallel_group=expert_parallel_group, + device_mesh=kwargs['device_mesh'], + args=args, + ) + + return ffn + + +ffns.register('mptglu', func=build_mptglu) +ffns.register('mptmlp', func=build_mptmlp) +ffns.register('torch_dmoe', func=build_torch_dmoe) + +if is_te_imported: + ffns_with_norm.register('te_ln_mlp', func=build_te_ln_mlp) + +if is_megablocks_imported: + ffns_with_megablocks.register('mb_moe', func=build_mb_moe) + ffns_with_megablocks.register('mb_dmoe', func=build_mb_dmoe) diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 6a725d469a..ceb41d8d41 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -5,9 +5,18 @@ import torch -from llmfoundry.layers_registry import attention_classes, fcs, norms +from llmfoundry.layers_registry import (attention_classes, fcs, ffns, + ffns_with_megablocks, ffns_with_norm, + norms) from llmfoundry.utils.registry_utils import construct_from_registry +__all__ = [ + 'build_attention_layer', + 'build_ffn', + 'build_fc', + 'build_norm', +] + def build_norm( name: str, @@ -25,6 +34,50 @@ def build_norm( kwargs=kwargs) +def build_ffn( + name: str, + d_model: int, + expansion_ratio: float, + device: Optional[str], + bias: bool, + ffn_kwargs: Dict[str, Any], +): + + registry_to_use = ffns + if name in ffns_with_norm: + registry_to_use = ffns_with_norm + + if name in ffns_with_megablocks: + registry_to_use = ffns_with_megablocks + + kwargs = { + 'd_model': d_model, + 'expansion_ratio': expansion_ratio, + 'device': device, + 'bias': bias, + **{k: v for k, v in ffn_kwargs.items() if k != 'ffn_type'}, + } + + def _validation_function(maybe_module: Any): + if not isinstance(maybe_module, torch.nn.Module): + raise ValueError(f'Function {name} must return a torch.nn.Module.') + + result = construct_from_registry( + name=name, + registry=registry_to_use, + post_validation_function=_validation_function, + partial_function=False, + kwargs=kwargs) + + if name in ffns_with_norm: + result._has_norm = True + + if name in ffns_with_megablocks: + result._uses_megablocks = True + + return result + + def build_attention_layer( name: str, attn_kwargs: Dict[str, Any], diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py index 92d295c71c..23b92015e7 100644 --- a/llmfoundry/models/layers/norm.py +++ b/llmfoundry/models/layers/norm.py @@ -7,6 +7,14 @@ from llmfoundry.layers_registry import norms +__all__ = [ + 'LPLayerNorm', + 'LPRMSNorm', + 'RMSNorm', + 'TritonRMSNorm', + 'rms_norm', +] + norms.register(name='layernorm', func=torch.nn.LayerNorm) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 4b98fa611d..dbee232f3d 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,6 +8,7 @@ from transformers import PretrainedConfig +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import (check_alibi_support, is_flash_v2_installed) from llmfoundry.models.layers.blocks import attn_config_defaults @@ -17,8 +18,7 @@ # Otherwise, certain modules are missing. # isort: off from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) -from llmfoundry.models.layers.layer_builders import build_norm, build_fc # type: ignore (see note) +from llmfoundry.models.layers.layer_builders import build_norm, build_fc, build_ffn # type: ignore (see note) from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note) from llmfoundry.layers_registry import norms # type: ignore (see note) from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note) @@ -290,7 +290,7 @@ def _validate_config(self) -> None: ) elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']: self.ffn_config['fc_type'] = self.fc_type - elif self.ffn_config['ffn_type'] in ['mb_moe', 'mb_dmoe']: + elif self.ffn_config['ffn_type'] in ffns_with_megablocks: self.ffn_config['return_bias'] = False elif self.ffn_config['ffn_type'] == 'te_ln_mlp': self.ffn_config['bias'] = not self.no_bias diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4a8f3943af..1ef62a3b19 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -20,6 +20,7 @@ from composer.models import HuggingFaceModel from composer.utils import dist +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import is_flash_v2_installed if is_flash_v2_installed(): @@ -42,12 +43,11 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import norms, param_init_fns from llmfoundry.models.layers.attention import (attn_bias_shape, build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig from llmfoundry.models.utils.config_moe_args import config_moe_args @@ -62,8 +62,8 @@ init_empty_weights # type: ignore (see note) from llmfoundry.models.utils.param_init_fns import ( generic_param_init_fn_, # type: ignore (see note) - MODEL_INIT_REGISTRY, ) +from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note) from llmfoundry.models.utils.act_ckpt import (pass_on_block_idx, build_act_ckpt_mod_to_blocks, @@ -324,7 +324,7 @@ def __init__(self, config: MPTConfig): self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() - if block_args['ffn_config']['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], config.d_model, @@ -332,6 +332,7 @@ def __init__(self, config: MPTConfig): config.n_layers, ) self.mb_args = block_args['ffn_config'].get('args') + self.blocks = nn.ModuleList([ MPTBlock( device=config.init_device, @@ -676,7 +677,7 @@ def forward( # Param Initialization, needed for device='meta' fast initialization def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] - MODEL_INIT_REGISTRY[init_fn_name]( + param_init_fns.get(init_fn_name)( module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, @@ -836,7 +837,7 @@ def forward( # Param Initialization, needed for device='meta' fast initialization def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] - MODEL_INIT_REGISTRY[init_fn_name]( + param_init_fns.get(init_fn_name)( module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, @@ -1026,7 +1027,7 @@ def get_targets(self, batch: Mapping) -> torch.Tensor: return targets def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: - if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # Clear MegaBlocks MoE load balancing loss cache try: # Add try/catch to avoid transformers complaining and raising errors from megablocks.layers.moe import clear_load_balancing_loss @@ -1053,7 +1054,7 @@ def loss(self, outputs: CausalLMOutputWithPast, else: loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() - if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss try: # Add try/catch to avoid transformers complaining and raising errors from megablocks.layers.moe import batched_load_balancing_loss diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py index 41313b8729..45a5f757f6 100644 --- a/llmfoundry/models/utils/__init__.py +++ b/llmfoundry/models/utils/__init__.py @@ -1,20 +1,24 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.models.utils.act_ckpt import (build_act_ckpt_mod_to_blocks, + check_mapping_blocks_overlap, + pass_on_block_idx) from llmfoundry.models.utils.config_moe_args import config_moe_args from llmfoundry.models.utils.meta_init_context import (init_empty_weights, init_on_device) from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params, mpt_get_total_params) -from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY, - generic_param_init_fn_) +from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ __all__ = [ 'init_empty_weights', 'init_on_device', 'generic_param_init_fn_', - 'MODEL_INIT_REGISTRY', 'config_moe_args', 'mpt_get_active_params', 'mpt_get_total_params', + 'build_act_ckpt_mod_to_blocks', + 'pass_on_block_idx', + 'check_mapping_blocks_overlap', ] diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index fea68492c1..ef9a851a09 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -5,9 +5,16 @@ import torch -from llmfoundry.layers_registry import attention_classes, norms +from llmfoundry.layers_registry import (attention_classes, ffns, + ffns_with_megablocks, ffns_with_norm, + norms) from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY + +__all__ = [ + 'build_act_ckpt_mod_to_blocks', + 'pass_on_block_idx', + 'check_mapping_blocks_overlap', +] def pass_on_block_idx(parent: torch.nn.Module): @@ -28,14 +35,19 @@ def get_act_ckpt_module(mod_name: str) -> Any: mod_type = attention_classes.get(mod_name) elif mod_name.lower() == 'norm_attn_norm': mod_type = FusedNormAttentionNorm - elif mod_name in FFN_CLASS_REGISTRY: - mod_type = FFN_CLASS_REGISTRY[mod_name] + elif mod_name in ffns: + mod_type = ffns.get(mod_name) + elif mod_name in ffns_with_norm: + mod_type = ffns_with_norm.get(mod_name) + elif mod_name in ffns_with_megablocks: + mod_type = ffns_with_megablocks.get(mod_name) elif mod_name in norms: mod_type = norms.get(mod_name) else: msg = ', '.join( - list(attention_classes.get_all()) + - list(FFN_CLASS_REGISTRY.keys()) + list(norms.get_all()) + + list(attention_classes.keys()) + list(ffns.get_all()) + + list(ffns_with_norm.get_all()) + + list(ffns_with_megablocks.get_all()) + list(norms.get_all()) + ['MPTBlock']) raise ValueError( f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' @@ -98,7 +110,7 @@ def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list: candidate_block_ids.extend(to_add) else: raise ValueError( - f'target_blocks must be either a single intege, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}' + f'target_blocks must be either a single integer, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}' ) candidate_block_ids = list(set(candidate_block_ids)) diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index 1f7132c281..859dd3c52b 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -9,8 +9,13 @@ from packaging import version from torch import distributed +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size +__all__ = [ + 'config_moe_args', +] + def create_process_group_ranks(ranks: tuple[int]): """Creates a new distributed group. @@ -69,7 +74,7 @@ def config_megablocks_moe_args( groups can be initialized and shared across all blocks in the network. Args: - ffn_config (dict): FFN configuation before the MegaBlocks MoE is configured. + ffn_config (dict): FFN configuration before the MegaBlocks MoE is configured. d_model (int): Hidden size of the network. expansion_ratio (Union[int, float]): Expansion ratio in FFN. n_layers (int): Number of blocks used in the network. @@ -169,7 +174,7 @@ def config_moe_args( """Configures `ffn_config` for MoE. Args: - ffn_config (dict): FFN configuation before the MoE is configured. + ffn_config (dict): FFN configuration before the MoE is configured. d_model (int): Hidden size of the network. expansion_ratio (int, float): Expansion ratio in FFN. n_layers (int): Number of blocks used in the network. @@ -177,7 +182,7 @@ def config_moe_args( Returns: ffn_config (dict): FFN configuration with MoE configured. """ - if ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if ffn_config['ffn_type'] in ffns_with_megablocks: return config_megablocks_moe_args( ffn_config=ffn_config, d_model=d_model, diff --git a/llmfoundry/models/utils/meta_init_context.py b/llmfoundry/models/utils/meta_init_context.py index d72a289a73..66f06db581 100644 --- a/llmfoundry/models/utils/meta_init_context.py +++ b/llmfoundry/models/utils/meta_init_context.py @@ -23,6 +23,11 @@ import torch.nn as nn from torch.distributed._tensor import DTensor +__all__ = [ + 'init_empty_weights', + 'init_on_device', +] + @contextmanager def init_empty_weights(include_buffers: bool = False): diff --git a/llmfoundry/models/utils/mpt_param_count.py b/llmfoundry/models/utils/mpt_param_count.py index d90929713b..ca487ecca0 100644 --- a/llmfoundry/models/utils/mpt_param_count.py +++ b/llmfoundry/models/utils/mpt_param_count.py @@ -16,6 +16,13 @@ from torch import Tensor, nn from torch.distributed._tensor import DTensor +from llmfoundry.layers_registry import ffns_with_megablocks + +__all__ = [ + 'mpt_get_active_params', + 'mpt_get_total_params', +] + def module_n_params(module: nn.Module) -> int: """Gets the number of parameters in this module excluding child modules. @@ -127,7 +134,7 @@ def megablocks_n_active_params(mpt_model) -> int: # type: ignore def mpt_get_total_params(mpt_model) -> int: # type: ignore - """Calculates the total paramter count of an MPT model. + """Calculates the total parameter count of an MPT model. Note: Must be called before model parameters are sharded by FSDP. @@ -138,14 +145,14 @@ def mpt_get_total_params(mpt_model) -> int: # type: ignore Returns: An int for the total number of parameters in this MPT model. """ - if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks: return megablocks_n_total_params(mpt_model) else: return sum(p.numel() for p in mpt_model.parameters()) def mpt_get_active_params(mpt_model) -> int: # type: ignore - """Calculates the total paramter count of an MPT model. + """Calculates the total parameter count of an MPT model. Note: Must be called before model parameters are sharded by FSDP. @@ -156,7 +163,7 @@ def mpt_get_active_params(mpt_model) -> int: # type: ignore Returns: An int for the active number of parameters in this MPT model. """ - if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'): + if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks: params = megablocks_n_active_params(mpt_model) else: params = sum(p.numel() for p in mpt_model.parameters()) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index bd409dee36..06bdd84438 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -12,7 +12,8 @@ from torch import nn from torch.distributed._tensor import DTensor -from llmfoundry.layers_registry import fcs, norms +from llmfoundry.layers_registry import (fcs, module_init_fns, norms, + param_init_fns) from llmfoundry.models.layers.dmoe import GLU, MLP try: @@ -25,6 +26,10 @@ except: megablocks = None +__all__ = [ + 'generic_param_init_fn_', +] + def torch_default_param_init_fn_( module: nn.Module, @@ -53,7 +58,7 @@ def fused_init_helper_( Args: module (nn.Module): The module to initialize. init_fn_ (Callable): Initialization method. - name_param (str): Name of parameter to initalize within the module. + name_param (str): Name of parameter to initialize within the module. """ _fused = getattr(module, '_fused', None) if _fused is None: @@ -89,7 +94,7 @@ def stacked_init_helper_( init_fn_: Callable, name_param: str = 'weight', ): - """Initializes parameters stacked along a new dimention. + """Initializes parameters stacked along a new dimension. Parameter initialization is often based on the parameters shape. If a layer is stacked, initialization should be based on the shapes of the original tensor instead of the @@ -99,7 +104,7 @@ def stacked_init_helper_( Args: module (nn.Module): The module to initialize. init_fn_ (Callable): Initialization method. - name_param (str): Name of parameter to initalize within the module. + name_param (str): Name of parameter to initialize within the module. """ stack_dim = getattr(module, '_stack_dim', None) if stack_dim is None: @@ -113,12 +118,12 @@ def stacked_param_init_helper( init_fn_: Callable, stack_dim: int, ): - """Initialize parameters stacked along a new dimention. + """Initialize parameters stacked along a new dimension. Args: param (torch.Tensor): Tensor to initialize. init_fn_ (Callable): Initialization method. - stack_dim (int): Dimention along with parameters are stacked + stack_dim (int): Dimension along with parameters are stacked """ p_ndims = param.ndim @@ -147,39 +152,14 @@ def _flip_fan_mode(init_fn_: Callable): return _init_fn_ -def generic_param_init_fn_( +def fc_init( module: nn.Module, init_fn_: Callable, - n_layers: int, - d_model: Optional[int] = None, - init_div_is_residual: Union[int, float, str, bool] = True, - emb_init_std: Optional[float] = None, - emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + init_div_is_residual: Union[int, float, str, bool], + div_is_residual: Optional[float], **kwargs: Any, -) -> None: - del kwargs # unused, just to capture any extra args from the config - # enable user to divide _is_residual weights by - - # a value which defaults to math.sqrt(2 * cfg.n_layers) - init_div_is_residual = init_div_is_residual - - if init_div_is_residual is False: - # not used, for pyright - div_is_residual = 1.0 - elif init_div_is_residual is True: - div_is_residual = math.sqrt(2 * n_layers) - elif isinstance(init_div_is_residual, float) or isinstance( - init_div_is_residual, int): - div_is_residual = init_div_is_residual - elif init_div_is_residual.isnumeric(): - # do not trust YAML parsing to always convert numbers to numbers - div_is_residual = float(init_div_is_residual) - else: - # not used, for pyright - div_is_residual = 1.0 - raise ValueError( - f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' - ) +) -> bool: + del kwargs # unused, just to capture any extra args if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))): # Linear @@ -195,8 +175,21 @@ def generic_param_init_fn_( module, '_is_residual', False): with torch.no_grad(): module.weight.div_(div_is_residual) # type: ignore + return True + + return False + - elif isinstance(module, nn.Embedding): +def embedding_init( + module: nn.Module, + init_fn_: Callable, + emb_init_std: Optional[float], + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]], + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if isinstance(module, nn.Embedding): # Embedding if emb_init_std is not None: std = emb_init_std @@ -223,8 +216,19 @@ def generic_param_init_fn_( emb_init_fn_(module.weight) - elif isinstance(module, - tuple(set([norms.get(name) for name in norms.get_all()]))): + return True + + return False + + +def norm_init( + module: nn.Module, + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if isinstance(module, + tuple(set([norms.get(name) for name in norms.get_all()]))): # Norm if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor): @@ -232,7 +236,22 @@ def generic_param_init_fn_( if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor): torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.MultiheadAttention): + return True + + return False + + +def multihead_attention_init( + module: nn.Module, + init_fn_: Callable, + d_model: Optional[int], + init_div_is_residual: Union[int, float, str, bool], + div_is_residual: float, + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if isinstance(module, nn.MultiheadAttention): # torch's MultiheadAttention if module._qkv_same_embed_dim: assert module.in_proj_weight is not None @@ -267,7 +286,19 @@ def generic_param_init_fn_( if module.out_proj.bias is not None: torch.nn.init.zeros_(module.out_proj.bias) - elif te is not None and isinstance(module, te.LayerNormMLP): + return True + + return False + + +def te_layernorm_mlp_init( + module: nn.Module, + init_fn_: Callable, + **kwargs: Any, +) -> bool: + del kwargs # unused, just to capture any extra args + + if te is not None and isinstance(module, te.LayerNormMLP): if isinstance(module.layer_norm_weight, torch.Tensor): torch.nn.init.ones_(module.layer_norm_weight) if isinstance(module.layer_norm_bias, torch.Tensor): @@ -285,7 +316,19 @@ def generic_param_init_fn_( with torch.no_grad(): module.fc2_weight.div_(div_is_residual) # type: ignore - elif megablocks is not None and isinstance(module, ( + return True + + return False + + +def moe_init( + module: nn.Module, + init_fn_: Callable, + init_div_is_residual: Union[int, float, str, bool], + div_is_residual: float, + **kwargs: Any, +) -> bool: + if megablocks is not None and isinstance(module, ( megablocks.layers.moe.MoE, megablocks.layers.dmoe.dMoE, megablocks.layers.moe.ParallelMLP, @@ -294,32 +337,96 @@ def generic_param_init_fn_( if hasattr(module, 'bias') and module.bias is not None: # Initialize bias to 0 torch.nn.init.zeros_(module.bias) # type: ignore + return True elif megablocks is not None and isinstance(module, megablocks.layers.glu.SparseGLU): _megablocks_sparse_glu_generic_param_init_fn_( module, init_fn_, bool(init_div_is_residual), div_is_residual) + return True elif megablocks is not None and isinstance(module, megablocks.layers.mlp.SparseMLP): _megablocks_sparse_mlp_generic_param_init_fn_( module, init_fn_, bool(init_div_is_residual), div_is_residual) + return True elif megablocks is not None and isinstance(module, megablocks.layers.mlp.MLP): _megablocks_mlp_generic_param_init_fn_(module, init_fn_, bool(init_div_is_residual), div_is_residual) + return True elif isinstance(module, GLU): init_fn_(module.w1) init_fn_(module.v1) init_fn_(module.w2) + return True elif isinstance(module, MLP): init_fn_(module.w1) init_fn_(module.w2) + return True + + return False + + +def generic_param_init_fn_( + module: nn.Module, + init_fn_: Callable, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + **kwargs: Any, +) -> None: + del kwargs # unused, just to capture any extra args from the config + # enable user to divide _is_residual weights by + + # a value which defaults to math.sqrt(2 * cfg.n_layers) + init_div_is_residual = init_div_is_residual + + if init_div_is_residual is False: + # not used, for pyright + div_is_residual = 1.0 + elif init_div_is_residual is True: + div_is_residual = math.sqrt(2 * n_layers) + elif isinstance(init_div_is_residual, float) or isinstance( + init_div_is_residual, int): + div_is_residual = init_div_is_residual + elif init_div_is_residual.isnumeric(): + # do not trust YAML parsing to always convert numbers to numbers + div_is_residual = float(init_div_is_residual) else: + # not used, for pyright + div_is_residual = 1.0 + raise ValueError( + f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' + ) + + all_module_init_fns = [ + module_init_fns.get(name) for name in module_init_fns.get_all() + ] + did_init = False + for module_init_fn in all_module_init_fns: + did_init = module_init_fn( + module=module, + init_fn_=init_fn_, + d_model=d_model, + init_div_is_residual=init_div_is_residual, + div_is_residual=div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + ) + + if did_init: + break + + if not did_init: for _ in module.parameters(recurse=False): # raise error if uninitialized module has any parameters raise NotImplementedError( - f'{module.__class__.__name__} parameters are not initialized by param_init_fn.' - ) + f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' + + + 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' + + ', '.join(module_init_fns.get_all())) def _megablocks_sparse_mlp_generic_param_init_fn_( @@ -725,13 +832,18 @@ def xavier_normal_param_init_fn_( ) -MODEL_INIT_REGISTRY = { - 'default_': torch_default_param_init_fn_, - 'baseline_': baseline_param_init_fn_, - 'kaiming_uniform_': kaiming_uniform_param_init_fn_, - 'kaiming_normal_': kaiming_normal_param_init_fn_, - 'neox_init_': neox_param_init_fn_, - 'small_init_': small_param_init_fn_, - 'xavier_uniform_': xavier_uniform_param_init_fn_, - 'xavier_normal_': xavier_normal_param_init_fn_, -} +param_init_fns.register('default_', func=torch_default_param_init_fn_) +param_init_fns.register('baseline_', func=baseline_param_init_fn_) +param_init_fns.register('kaiming_uniform_', func=kaiming_uniform_param_init_fn_) +param_init_fns.register('kaiming_normal_', func=kaiming_normal_param_init_fn_) +param_init_fns.register('neox_init_', func=neox_param_init_fn_) +param_init_fns.register('small_init_', func=small_param_init_fn_) +param_init_fns.register('xavier_uniform_', func=xavier_uniform_param_init_fn_) +param_init_fns.register('xavier_normal_', func=xavier_normal_param_init_fn_) + +module_init_fns.register('fc', func=fc_init) +module_init_fns.register('embedding', func=embedding_init) +module_init_fns.register('norm', func=norm_init) +module_init_fns.register('multihead_attention', func=multihead_attention_init) +module_init_fns.register('te_layernorm_mlp', func=te_layernorm_mlp_init) +module_init_fns.register('moe', func=moe_init) diff --git a/llmfoundry/optim/__init__.py b/llmfoundry/optim/__init__.py index 527969bd63..26389665b5 100644 --- a/llmfoundry/optim/__init__.py +++ b/llmfoundry/optim/__init__.py @@ -26,4 +26,5 @@ 'DecoupledLionW', 'DecoupledClipLion', 'DecoupledAdaLRLion', + 'InverseSquareRootWithWarmupScheduler', ] diff --git a/llmfoundry/optim/adaptive_lion.py b/llmfoundry/optim/adaptive_lion.py index 0ce76e905e..9b2dac9d80 100644 --- a/llmfoundry/optim/adaptive_lion.py +++ b/llmfoundry/optim/adaptive_lion.py @@ -13,6 +13,11 @@ log = logging.getLogger(__name__) +__all__ = [ + 'DecoupledAdaLRLion', + 'DecoupledClipLion', +] + class DecoupledAdaLRLion(Optimizer): """DecoupledAdaLRLion. diff --git a/llmfoundry/optim/lion.py b/llmfoundry/optim/lion.py index 0caa7d2877..b04211649c 100644 --- a/llmfoundry/optim/lion.py +++ b/llmfoundry/optim/lion.py @@ -11,6 +11,10 @@ log = logging.getLogger(__name__) +__all__ = [ + 'DecoupledLionW', +] + class DecoupledLionW(Optimizer): metric_functions = { diff --git a/llmfoundry/optim/outlier_detection.py b/llmfoundry/optim/outlier_detection.py index 9df4381ba4..e430f4ccb5 100644 --- a/llmfoundry/optim/outlier_detection.py +++ b/llmfoundry/optim/outlier_detection.py @@ -4,6 +4,8 @@ import collections from typing import Optional +__all__ = ['OutlierDetector'] + class OutlierDetector: """OutlierDetector. diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index ef6629be10..6e1824ea08 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -13,7 +13,9 @@ from llmfoundry.interfaces import CallbackWithConfig from llmfoundry.layers_registry import (attention_classes, - attention_implementations, fcs, norms) + attention_implementations, fcs, ffns, + ffns_with_megablocks, ffns_with_norm, + module_init_fns, norms, param_init_fns) from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -131,6 +133,11 @@ 'metrics', 'dataloaders', 'norms', + 'param_init_fns', + 'module_init_fns', + 'ffns', + 'ffns_with_norm', + 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', 'fcs', diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index 298e1bc984..0ecaa45b5f 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -6,6 +6,10 @@ from transformers import PreTrainedTokenizer +__all__ = [ + 'TiktokenTokenizerWrapper', +] + DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.""" diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index e8c90a6007..2c3d7c9bc3 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -2,7 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.utils.builders import (build_algorithm, build_callback, - build_logger, build_optimizer, + build_composer_model, build_evaluators, + build_icl_data_and_gauntlet, + build_icl_evaluators, build_logger, + build_metric, build_optimizer, build_scheduler, build_tokenizer) from llmfoundry.utils.checkpoint_conversion_helpers import ( convert_and_save_ft_weights, get_hf_tokenizer_from_composer_state_dict, @@ -12,7 +15,7 @@ process_init_device, update_batch_size_info) from llmfoundry.utils.data_prep_utils import (DownloadingIterable, - merge_shard_groups, with_id) + merge_shard_groups) from llmfoundry.utils.huggingface_hub_utils import \ edit_files_for_hf_compatibility from llmfoundry.utils.logging_utils import SpecificWarningFilter @@ -30,40 +33,53 @@ from llmfoundry.utils.prompt_files import load_prompts, load_prompts_from_file from llmfoundry.utils.registry_utils import (TypedRegistry, construct_from_registry, - create_registry) -from llmfoundry.utils.warnings import VersionedDeprecationWarning + create_registry, import_file, + save_registry) +from llmfoundry.utils.warnings import (ExperimentalWarning, + VersionedDeprecationWarning, + experimental_class, + experimental_function) __all__ = [ 'build_algorithm', 'build_callback', + 'build_evaluators', + 'build_icl_data_and_gauntlet', + 'build_icl_evaluators', 'build_logger', 'build_optimizer', 'build_scheduler', 'build_tokenizer', - 'convert_and_save_ft_weights', + 'build_composer_model', + 'build_metric', 'get_hf_tokenizer_from_composer_state_dict', 'load_tokenizer', - 'calculate_batch_size_info', - 'log_config', + 'convert_and_save_ft_weights', 'pop_config', + 'calculate_batch_size_info', 'update_batch_size_info', 'process_init_device', + 'log_config', 'DownloadingIterable', 'merge_shard_groups', - 'with_id', 'edit_files_for_hf_compatibility', 'SpecificWarningFilter', 'download_from_http_fileserver', 'download_from_hf_hub', 'download_from_oras', + 'maybe_create_mosaicml_logger', + 'find_mosaicml_logger', + 'log_eval_analytics', + 'log_train_analytics', 'load_prompts', 'load_prompts_from_file', - 'VersionedDeprecationWarning', + 'TypedRegistry', 'create_registry', 'construct_from_registry', - 'TypedRegistry', - 'find_mosaicml_logger', - 'log_eval_analytics', - 'log_train_analytics', - 'maybe_create_mosaicml_logger', + 'import_file', + 'save_registry', + 'VersionedDeprecationWarning', + 'ExperimentalWarning', + 'experimental_function', + 'experimental_class', ] diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index a8c660df70..5d60cb0a1f 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -6,14 +6,13 @@ import logging import os import re +import warnings from collections import OrderedDict from typing import (Any, ContextManager, Dict, Iterable, List, Optional, Tuple, Union) import torch from composer.core import Algorithm, Callback, Evaluator -from composer.datasets.in_context_learning_evaluation import \ - get_icl_task_dataloader from composer.loggers import LoggerDestination from composer.models import ComposerModel from composer.optim.scheduler import ComposerScheduler @@ -27,8 +26,11 @@ from llmfoundry import registry from llmfoundry.callbacks import EvalGauntlet from llmfoundry.data.dataloader import build_dataloader +from llmfoundry.eval.datasets.in_context_learning_evaluation import \ + get_icl_task_dataloader from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.registry_utils import construct_from_registry +from llmfoundry.utils.warnings import VersionedDeprecationWarning log = logging.getLogger(__name__) @@ -496,8 +498,15 @@ def _validate_cfg(icl_cfg: DictConfig): icl_cfg.metric_names = [ 'InContextLearningMultipleChoiceAccuracy' ] - elif icl_cfg.icl_task_type == 'question_answering': - icl_cfg.metric_names = ['InContextLearningQAAccuracy'] + elif icl_cfg.icl_task_type == 'generation_task_with_answers' or icl_cfg.icl_task_type == 'question_answering': + if icl_cfg.icl_task_type == 'question_answering': + warnings.warn( + VersionedDeprecationWarning( + "ICL task type 'question_answering' is now deprecated. Use identifier 'generation_task_with_answers'", + 'v0.9.0')) + icl_cfg.metric_names = [ + 'InContextLearningGenerationExactMatchAccuracy' + ] elif icl_cfg.icl_task_type == 'code_evaluation': icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy'] else: diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index d2c3b733c0..a4fd005c3a 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -11,6 +11,7 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.utils import init_empty_weights log = logging.getLogger(__name__) @@ -131,7 +132,7 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set ffn_config.device_mesh to fsdp_config.device_mesh if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[ - 'ffn_config'].get('ffn_type', None) in {'mb_moe', 'mb_dmoe'}: + 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: # Raise ValueError if not using device mesh with MoE expert parallelism if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get( 'moe_world_size', 1) > 1: diff --git a/llmfoundry/utils/data_prep_utils.py b/llmfoundry/utils/data_prep_utils.py index d50977c097..058e73b393 100644 --- a/llmfoundry/utils/data_prep_utils.py +++ b/llmfoundry/utils/data_prep_utils.py @@ -9,7 +9,6 @@ from composer.utils import ObjectStore __all__ = [ - 'with_id', 'merge_shard_groups', 'DownloadingIterable', ] @@ -94,7 +93,7 @@ def __init__( Args: object_names (List[str]): Names of objects to download output_folder (str): Local folder to write downloaded files to - object_store (Optiona[ObjectStore]): Object store to download from + object_store (Optional[ObjectStore]): Object store to download from """ self.object_names = object_names self.object_store = object_store diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index fe24d0eae6..7a6be2be29 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -5,6 +5,30 @@ from collections.abc import Mapping from typing import Any, Dict, List +__all__ = [ + 'MissingHuggingFaceURLSplitError', + 'NotEnoughDatasetSamplesError', + 'UnknownExampleTypeError', + 'TooManyKeysInExampleError', + 'NotEnoughChatDataError', + 'ConsecutiveRepeatedChatRolesError', + 'InvalidLastChatMessageRoleError', + 'IncorrectMessageKeyQuantityError', + 'InvalidRoleError', + 'InvalidContentTypeError', + 'InvalidPromptTypeError', + 'InvalidResponseTypeError', + 'InvalidPromptResponseKeysError', + 'InvalidFileExtensionError', + 'UnableToProcessPromptResponseError', + 'ClusterDoesNotExistError', + 'FailedToCreateSQLConnectionError', + 'FailedToConnectToDatabricksError', + 'InputFolderMissingDataError', + 'OutputFolderNotEmptyError', + 'MisconfiguredHfDatasetError', +] + # Finetuning dataloader exceptions class MissingHuggingFaceURLSplitError(ValueError): @@ -204,3 +228,13 @@ def __init__(self, output_folder: str) -> None: self.output_folder = output_folder message = f'{output_folder} is not empty. Please remove or empty it and retry.' super().__init__(message) + + +class MisconfiguredHfDatasetError(ValueError): + """Error thrown when a HuggingFace dataset is misconfigured.""" + + def __init__(self, dataset_name: str, split: str) -> None: + self.dataset_name = dataset_name + self.split = split + message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. Please check your dataset format and make sure you can load your dataset locally.' + super().__init__(message) diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index 5a198bc8df..3903a9bed3 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -132,7 +132,8 @@ def edit_files_for_hf_compatibility( flatten_imports_prefix: Sequence[str] = ('llmfoundry',), remove_imports_prefix: Sequence[str] = ('composer', 'omegaconf', 'llmfoundry.metrics', - 'llmfoundry.utils.builders'), + 'llmfoundry.eval', + 'llmfoundry.utils.builders') ) -> None: """Edit files to be compatible with Hugging Face Hub. diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index a88e02a33a..3707da3883 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -249,7 +249,7 @@ def download_from_oras(model: str, credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three files: `username`, `password`, and `registry`, each of which contains the corresponding credential. save_dir (str): Path to the directory where files will be downloaded. - tokenizer_only (bool): If true, only download the tokenzier files. + tokenizer_only (bool): If true, only download the tokenizer files. concurrency (int): The number of concurrent downloads to run. """ if shutil.which(ORAS_CLI) is None: diff --git a/llmfoundry/utils/mosaicml_logger_utils.py b/llmfoundry/utils/mosaicml_logger_utils.py index e54f11ce32..b4a40821ed 100644 --- a/llmfoundry/utils/mosaicml_logger_utils.py +++ b/llmfoundry/utils/mosaicml_logger_utils.py @@ -10,6 +10,13 @@ MOSAICML_PLATFORM_ENV_VAR) from omegaconf import DictConfig, ListConfig +__all__ = [ + 'maybe_create_mosaicml_logger', + 'find_mosaicml_logger', + 'log_eval_analytics', + 'log_train_analytics', +] + _MODEL_KEYS_TO_LOG = [ 'pretrained_model_name_or_path', 'pretrained', diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index d9c23e6f26..1604a8a91f 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -1,9 +1,11 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import functools import importlib.util import os +from contextlib import contextmanager from pathlib import Path from types import ModuleType from typing import (Any, Callable, Dict, Generic, Optional, Sequence, Type, @@ -11,7 +13,13 @@ import catalogue -__all__ = ['TypedRegistry', 'create_registry', 'construct_from_registry'] +__all__ = [ + 'TypedRegistry', + 'create_registry', + 'construct_from_registry', + 'import_file', + 'save_registry', +] T = TypeVar('T') TypeBoundT = TypeVar('TypeBoundT', bound=Type) @@ -143,7 +151,7 @@ def construct_from_registry( ) if post_validation_function is not None: - post_validation_function(registered_constructor) + post_validation_function(constructed_item) return constructed_item @@ -174,3 +182,13 @@ def import_file(loc: Union[str, Path]) -> ModuleType: except Exception as e: raise RuntimeError(f'Error executing {loc}') from e return module + + +@contextmanager +def save_registry(): + """Save the registry state and restore after the context manager exits.""" + saved_registry_state = copy.deepcopy(catalogue.REGISTRY) + + yield + + catalogue.REGISTRY = saved_registry_state diff --git a/llmfoundry/utils/warnings.py b/llmfoundry/utils/warnings.py index 6c9106b2e7..fb0046f938 100644 --- a/llmfoundry/utils/warnings.py +++ b/llmfoundry/utils/warnings.py @@ -6,6 +6,9 @@ __all__ = [ 'VersionedDeprecationWarning', + 'ExperimentalWarning', + 'experimental_function', + 'experimental_class', ] diff --git a/mcli/README.md b/mcli/README.md index ced3c42adc..59fb723f57 100644 --- a/mcli/README.md +++ b/mcli/README.md @@ -23,4 +23,4 @@ All the details of multi-gpu and multi-node orchestration are handled automatica ## Using the MosaicML Python SDK to launch runs You can also use the [Python SDK](https://mcli.docs.mosaicml.com/en/stable/python/hello_world.html) to launch MosaicML platform jobs. -This can be used to programatically sweep hyperparameters or orchestrate training runs within a larger pipeline. +This can be used to programmatically sweep hyperparameters or orchestrate training runs within a larger pipeline. diff --git a/mcli/mcli-1b-max-seq-len-8k.yaml b/mcli/mcli-1b-max-seq-len-8k.yaml index 33a891c058..3c7c62f7d4 100644 --- a/mcli/mcli-1b-max-seq-len-8k.yaml +++ b/mcli/mcli-1b-max-seq-len-8k.yaml @@ -43,7 +43,7 @@ parameters: name: mpt_causal_lm init_device: meta d_model: 2048 - n_heads: 16 # Modified 24->16 so that d_head == 128 to statisfy FlashAttention + n_heads: 16 # Modified 24->16 so that d_head == 128 to satisfy FlashAttention n_layers: 24 expansion_ratio: 4 max_seq_len: ${max_seq_len} diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py index e78e76a912..fb6bde4115 100644 --- a/scripts/data_prep/convert_finetuning_dataset.py +++ b/scripts/data_prep/convert_finetuning_dataset.py @@ -59,7 +59,7 @@ def parse_args() -> Namespace: '--skip-preprocessing', action='store_true', help= - 'Whether to skip preprocesing (e.g., if the dataset is already formatted correctly)' + 'Whether to skip preprocessing (e.g., if the dataset is already formatted correctly)' ) parser.add_argument( '--out_root', diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index be986fc24d..636e85abed 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -191,8 +191,8 @@ def get_task_args( input_folder (str): Folder of text files to process n_groups (int): Number of groups to split the object names into tokenizer_name (str): Name of tokenizer to use - concat_tokens (int): Concantenate up to this many tokens - eos_text (str): Textend to append to each example to separate concatenated samples + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing @@ -219,7 +219,7 @@ def get_task_args( def download_and_convert_starargs(args: Tuple): """Helper function to call download_and_convert with star args. - This helps us use download_and_convert with mutiprocessing. + This helps us use download_and_convert with multiprocessing. """ return download_and_convert(*args) @@ -236,15 +236,15 @@ def download_and_convert( compression: str, trust_remote_code: bool, ): - """Downloads and converts text fies to MDS format. + """Downloads and converts text files to MDS format. Args: file_names (List[str]): Files to process output_folder (str): Folder to write MDS shards to input_folder (str): Folder of text files to process tokenizer_name (str): Name of tokenizer to use - concat_tokens (int): Concantenate up to this many tokens - eos_text (str): Textend to append to each example to separate concatenated samples + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing @@ -375,8 +375,8 @@ def convert_text_to_mds( tokenizer_name (str): Name of tokenizer to use output_folder (str): Folder to write MDS shards to input_folder (str): Folder of text files to process - concat_tokens (int): Concantenate up to this many tokens - eos_text (str): Textend to append to each example to separate concatenated samples + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing diff --git a/scripts/eval/README.md b/scripts/eval/README.md index c8e4ed44b2..3a748066ec 100644 --- a/scripts/eval/README.md +++ b/scripts/eval/README.md @@ -49,7 +49,7 @@ In order to do ICL evaluation you must specify a set of benchmarks you'd like to #### ICL task YAML format -Your YAML must have a config section entitled `icl_tasks` specifying the benchmarks to evaluate againts, this can either be a list of dictionaries of the form +Your YAML must have a config section entitled `icl_tasks` specifying the benchmarks to evaluate against, this can either be a list of dictionaries of the form ```jsx icl_tasks: @@ -145,7 +145,8 @@ You can use the default `icl_tasks` and `eval_gauntlet` configs or specify your ICL evaluation measures a model’s ability to solve novel problems by being provided examples in-context without ever being specifically trained to answer such questions. -Composer supports a number of different standard ICL formats and allows users to upload their own datasets that correspond to those formats. +We support a number of standard ICL formats and allow users to upload their own datasets that correspond to these formats. All of our ICL task types are implemented in `llm-foundry/llmfoundry/eval/datasets/in_context_learning_evaluation.py` while all of our ICL +metrics are implemented in `llm-foundry/llmfoundry/eval/metrics/nlp.py`. You can see which metrics work with which task types in the `llmfoundry.utils.builders.build_icl_evaluators` helper function. This document explains the ICL formats compatible with [Composer](https://github.com/mosaicml/composer), summarizes how to add new datasets in those formats, and catalogs the datasets currently used by the research team to evaluate models. @@ -153,19 +154,19 @@ This document explains the ICL formats compatible with [Composer](https://github ## Supported ICL formats -Composer currently supports five ICL formats: +llm-foundry currently supports five ICL formats: -1. [InContextLearningQATaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L103) -2. [InContextLearningLMTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L293) -3. [InContextLearningMultipleChoiceTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L444) -4. [InContextLearningSchemaTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L676) -5. [InContextLearningCodeEvalDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L852) +1. InContextLearningGenerationTaskWithAnswersDataset +2. InContextLearningLMTaskDataset +3. InContextLearningMultipleChoiceTaskDataset +4. InContextLearningSchemaTaskDataset +5. InContextLearningCodeEvalDataset ---- -### InContextLearningQATaskDataset +### InContextLearningGenerationTaskWithAnswersDataset -The ICL question answering (QA) task supports free response question answering evaluation using the model’s generate function. A QA dataset consists of a list of JSONs containing a question (under the key `context`), a correct answer (under the key `answer`), and a list of alternative spellings of the answer that would be considered permissible (under the key `aliases`). The QA task works with the NLP metric: [InContextLearningQAAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningQAAccuracy.html) which assigns a model's output to be "correct" if, conditioned on the context, the model's generate method produces a string that is a normalized prefix for either the `answer` or any of the `aliases`. +The ICL generation with answers task supports free response generation evaluation using the model’s generate function. A generation dataset consists of a list of JSONs containing a prompt (under the key `context`), a correct answer (under the key `answer`), and a list of alternative answers that would be considered permissible (under the key `aliases`). The generation task works with the NLP metric: InContextLearningGenerationExactMatchAccuracy which assigns a model's output to be "correct" if, conditioned on the context, the model's generate method produces a string that is a normalized prefix for either the `answer` or any of the `aliases`. Required keys for each datum: * `context`: str @@ -178,7 +179,7 @@ An example datum is below: {"context": "What star sign is Jamie Lee Curtis?", "answer": "Scorpio", "aliases": ["Scorpio", "Skorpio"]} ``` -The QA task expects a **prompt string**, a **continuation delimiter** to separate questions from answers, an **example delimiter** to separate few shot examples from one another, and a **question prelimiter** to put before each question. If using the following settings, with 2 examples in context, the above datum may be rendered to the model as: +The generation task expects a **prompt string**, a **continuation delimiter** to separate questions from answers, an **example delimiter** to separate few shot examples from one another, and a **question prelimiter** to put before each question. If using the following settings, with 2 examples in context, the above datum may be rendered to the model as: ```jsx prompt_string: "Answer the following trivia question:\n", example_delimiter: "\n", continuation_delimiter: " Answer: ", question_prelimiter: "Question: " @@ -203,9 +204,9 @@ Below is a complete YAML section that works with the TriviaQA dataset in [`scrip - 5 - 10 batch_size: 4 - icl_task_type: question_answering + icl_task_type: generation_task_with_answers metric_names: - - InContextLearningQAAccuracy + - InContextLearningGenerationExactMatchAccuracy prompt_string: '' # this goes at the beginning of each input example_delimiter: "\n" # this goes between fewshot examples continuation_delimiter: ' ' # this separates questions from answers @@ -215,7 +216,7 @@ Below is a complete YAML section that works with the TriviaQA dataset in [`scrip ### InContextLearningLMTaskDataset -The ICL language modeling (LM) task assesses the model’s ability to predict a precise sequence of tokens (called a continuation) following some context using the model’s `forward` function. An LM dataset consists of a list of JSONs containing a context (under the key `context`) and a continuation (under the key `continuation` that the model must correctly predict conditioned on the context. The LM task uses the NLP metric [InContextLearningLMAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningLMAccuracy.html), which assigns a model's output to be "correct" if, conditioned on the context tokens, the model's argmax output logits exactly match the tokens in the continuation. +The ICL language modeling (LM) task assesses the model’s ability to predict a precise sequence of tokens (called a continuation) following some context using the model’s `forward` function. An LM dataset consists of a list of JSONs containing a context (under the key `context`) and a continuation (under the key `continuation` that the model must correctly predict conditioned on the context. The LM task uses the NLP metric InContextLearningLMAccuracy, which assigns a model's output to be "correct" if, conditioned on the context tokens, the model's argmax output logits exactly match the tokens in the continuation. Required keys for each datum: * `context`: str @@ -256,7 +257,7 @@ Below is a YAML section that works with the Lambada OpenAI dataset in [`scripts/ ### InContextLearningMultipleChoiceTaskDataset -The ICL multiple choice (MC) task assesses the model’s ability to answer multiple choice questions by assigning highest per token probability to the correct answer. An MC dataset consists of a list of JSONs containing a query (under the key `query`), a list of choices (under the key `choices`), and the index indicating the correct answer (under the key `gold`). The MC task works with the NLP metric [InContextLearningMultipleChoiceAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningMultipleChoiceAccuracy.html), which separately runs the model's `forward()` method on the query prepended to each choice, and then determines the model to be correct if the correct choice has the lowest per token perplexity conditioned on the query. +The ICL multiple choice (MC) task assesses the model’s ability to answer multiple choice questions by assigning highest per token probability to the correct answer. An MC dataset consists of a list of JSONs containing a query (under the key `query`), a list of choices (under the key `choices`), and the index indicating the correct answer (under the key `gold`). The MC task works with the NLP metric InContextLearningMultipleChoiceAccuracy, which separately runs the model's `forward()` method on the query prepended to each choice, and then determines the model to be correct if the correct choice has the lowest per token perplexity conditioned on the query. Required keys for each datum: * `query`: str @@ -294,7 +295,6 @@ Below is a YAML section that works with the HellaSwag dataset in [`scripts/eval/ icl_task_type: multiple_choice metric_names: - InContextLearningMultipleChoiceAccuracy - - InContextLearningMCExpectedCalibrationError prompt_string: '' # this goes at the beginning of each input example_delimiter: "\n" # this goes between fewshot examples continuation_delimiter: ' ' # this separates questions from answers @@ -306,7 +306,7 @@ Below is a YAML section that works with the HellaSwag dataset in [`scripts/eval/ The ICL schema task assesses the model’s ability to determine which of some set of possible contexts (under the key `context_options`) makes a sequence of tokens (under the key `continuation`) most likely, with the correct context indicated by "gold". This task is based on [A Simple Method for Commonsense Reasoning](https://arxiv.org/abs/1806.02847). -The schema task works with the NLP metric [InContextLearningMultipleChoiceAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningMultipleChoiceAccuracy.html), which separately runs the model's `forward()` method on each context option prepended to the continuation and rates the model correct if it assigns minimum per token perplexity to the continuation conditioned on the true context. +The schema task works with the NLP metric InContextLearningMultipleChoiceAccuracy, which separately runs the model's `forward()` method on each context option prepended to the continuation and rates the model correct if it assigns minimum per token perplexity to the continuation conditioned on the true context. Required keys for each datum: * query: str diff --git a/scripts/eval/yamls/tasks_v0.1.yaml b/scripts/eval/yamls/tasks_v0.1.yaml index 44f031ae3a..6546b13dd7 100644 --- a/scripts/eval/yamls/tasks_v0.1.yaml +++ b/scripts/eval/yamls/tasks_v0.1.yaml @@ -10,12 +10,12 @@ icl_tasks: label: triviaqa_sm_sub dataset_uri: eval/local_data/world_knowledge/triviaqa_sm_sub.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers - label: gsm8k dataset_uri: eval/local_data/symbolic_problem_solving/gsm8k.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " question_prelimiter: "Q: " @@ -23,21 +23,21 @@ icl_tasks: label: agi_eval_sat_math dataset_uri: eval/local_data/symbolic_problem_solving/agi_eval_sat_math.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " - label: aqua dataset_uri: eval/local_data/symbolic_problem_solving/aqua.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " - label: svamp dataset_uri: eval/local_data/symbolic_problem_solving/svamp.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers continuation_delimiter: "\nUsing the formula below:\n" cot_delimiter: " #### " question_prelimiter: "Q: " diff --git a/scripts/eval/yamls/tasks_v0.2.yaml b/scripts/eval/yamls/tasks_v0.2.yaml index e23b4df1a5..ae39d87fbd 100644 --- a/scripts/eval/yamls/tasks_v0.2.yaml +++ b/scripts/eval/yamls/tasks_v0.2.yaml @@ -10,12 +10,12 @@ icl_tasks: label: triviaqa_sm_sub dataset_uri: eval/local_data/world_knowledge/triviaqa_sm_sub.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers - label: gsm8k dataset_uri: eval/local_data/symbolic_problem_solving/gsm8k.jsonl num_fewshot: [8, 5] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " question_prelimiter: "Q: " @@ -23,21 +23,21 @@ icl_tasks: label: agi_eval_sat_math dataset_uri: eval/local_data/symbolic_problem_solving/agi_eval_sat_math.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " - label: aqua dataset_uri: eval/local_data/symbolic_problem_solving/aqua.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: " #### " continuation_delimiter: "\nA: Let's think step by step. " - label: svamp dataset_uri: eval/local_data/symbolic_problem_solving/svamp.jsonl num_fewshot: [5] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers continuation_delimiter: "\nUsing the formula below:\n" cot_delimiter: " #### " question_prelimiter: "Q: " diff --git a/scripts/eval/yamls/tasks_v0.3.yaml b/scripts/eval/yamls/tasks_v0.3.yaml index e02178710e..396ceaaf85 100644 --- a/scripts/eval/yamls/tasks_v0.3.yaml +++ b/scripts/eval/yamls/tasks_v0.3.yaml @@ -3,7 +3,7 @@ icl_tasks: label: gsm8k dataset_uri: eval/local_data/symbolic_problem_solving/gsm8k_prepended_8shot.jsonl num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: "The answer is " continuation_delimiter: "\n\nA:" question_prelimiter: "" @@ -15,13 +15,13 @@ icl_tasks: label: triviaqa_sm_sub dataset_uri: eval/local_data/world_knowledge/triviaqa_sm_sub.jsonl num_fewshot: [3] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers do_normalization: true - label: svamp dataset_uri: eval/local_data/symbolic_problem_solving/svamp.jsonl num_fewshot: [5] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers cot_delimiter: "The answer is " continuation_delimiter: "\n\nA:" question_prelimiter: "Question: " diff --git a/scripts/inference/benchmarking/README.md b/scripts/inference/benchmarking/README.md index 837154a977..b3f01256ac 100644 --- a/scripts/inference/benchmarking/README.md +++ b/scripts/inference/benchmarking/README.md @@ -28,7 +28,7 @@ LLM inference consists of two stages: _prefill_ and _decode_. It's important to During _prefill_, the model processes the input tokens/prompt/context. This is done in a single forward pass, making this stage fast, with excellent use of GPU hardware (ie. high Model Flop Utilization aka [MFU](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train/benchmarking#mfu)). Typically, if people talk about LLM inference being slow, this is _not_ the stage that they are referring to. -During _decode_, the model generates output tokens one at a time, i.e. autoregressively. This requires making N forward passes of the model for N tokens. This stage is slow and inefficient, because it requires moving gigabytes of model weights and pre-filled values for every single forward pass. Here, latency scales (mostly) linearly with the number of output tokens. Why mostly linear? When generating long sequences, the quadratic memory and compute complexity of the attention operation become more prominant. +During _decode_, the model generates output tokens one at a time, i.e. autoregressively. This requires making N forward passes of the model for N tokens. This stage is slow and inefficient, because it requires moving gigabytes of model weights and pre-filled values for every single forward pass. Here, latency scales (mostly) linearly with the number of output tokens. Why mostly linear? When generating long sequences, the quadratic memory and compute complexity of the attention operation become more prominent. ##### KV cache @@ -132,5 +132,5 @@ The benchmark script supports calling models directly from huggingface (using `h The analysis is done on a single A100 80GB GPU, with input length 512, and output length 64, while varying the batch size. As in previous sections, the batch sizes swept are 1, 2, 4, 8, 16, 32, 64, unless the GPU ran out of memory, in which case that point is not shown. As seen here, both MPT-7B and MPT-30B are among the fastest for inference in the open-source community, with MPT-30B being faster than the respective LLAMA-30B model. -Among the 7B models, Falcon-7B tends to have higher througput at higher latencies than MPT-7B, though MPT-7B has higher throughput at lower latencies. +Among the 7B models, Falcon-7B tends to have higher throughput at higher latencies than MPT-7B, though MPT-7B has higher throughput at lower latencies. Previously, we found that Falcon-7b was significantly slower than both MPT-7B and LLAMA-7B. This slow speed was due to the KV-cache not being used properly during generation, however this appears to be [fixed](https://huggingface.co/tiiuae/falcon-7b/tree/main) as of July 13, 2022. diff --git a/scripts/inference/convert_hf_to_onnx.py b/scripts/inference/convert_hf_to_onnx.py index 1ba1123c86..9d1841b12f 100644 --- a/scripts/inference/convert_hf_to_onnx.py +++ b/scripts/inference/convert_hf_to_onnx.py @@ -160,7 +160,7 @@ def export_to_onnx( atol=1e-2, msg=f'output mismatch between the orig and onnx exported model', ) - print('exported model ouptut matches with unexported model!!') + print('exported model output matches with unexported model!!') if save_object_store is not None: print('Uploading files to object storage...') diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py index e78fecf59b..e6f9ae1448 100644 --- a/scripts/inference/endpoint_generate.py +++ b/scripts/inference/endpoint_generate.py @@ -42,7 +42,7 @@ def parse_args() -> Namespace: '-i', '--inputs', nargs='+', - help=f'List of strings, local datafiles (starting with {utils.PROMPTFILE_PREFIX}),' +\ + help=f'List of strings, local data files (starting with {utils.PROMPTFILE_PREFIX}),' +\ ' and/or remote object stores' ) parser.add_argument( diff --git a/scripts/inference/run_mpt_with_ft.py b/scripts/inference/run_mpt_with_ft.py index 10ccf6b78b..61d9f68d2c 100644 --- a/scripts/inference/run_mpt_with_ft.py +++ b/scripts/inference/run_mpt_with_ft.py @@ -197,7 +197,7 @@ def main(): type=int, default=0, choices=[0, 1, 2], - help='Whether to compute the cumulative log probsbility of sentences.' + + help='Whether to compute the cumulative log probability of sentences.' + ' 0: do not return the cumulative log probs' + ' 1: return the cumulative log probs of generated sequences' + ' 2: return the cumulative log probs of sequences') diff --git a/scripts/train/README.md b/scripts/train/README.md index 36974ec943..6730cb793b 100644 --- a/scripts/train/README.md +++ b/scripts/train/README.md @@ -276,7 +276,7 @@ If the dataset requires a [custom preprocessing function](#custom-data-preproces train_loader: name: finetuning dataset: - hf_name: mosaiml/doge-facts + hf_name: mosaicml/doge-facts preprocessing_fn: my_data.formatting:dogefacts_prep_fn split: train ... @@ -402,7 +402,7 @@ so you should be able to run the exact same YAML on 8 or 16 or 256 GPUs and get This is nice because it means you can write device-count-agnostic training configs, and not worry about OOM-ing or accidentally changing the optimization math. -In previous blogposts ([1](https://www.mosaicml.com/blog/farewell-oom), [2](https://www.mosaicml.com/blog/billion-parameter-gpt-training-made-easy)) +In previous blog posts ([1](https://www.mosaicml.com/blog/farewell-oom), [2](https://www.mosaicml.com/blog/billion-parameter-gpt-training-made-easy)) we also demonstrated auto microbatching, which takes things a step further by letting Composer determine the `device_train_microbatch_size` on its own. This makes our configs not only device-count-agnostic, but hardware-agnostic too! You can try out this feature by setting `device_train_microbatch_size: auto`, but bear in mind that FSDP support is still in alpha mode diff --git a/scripts/train/benchmarking/README.md b/scripts/train/benchmarking/README.md index 5414cdc7bf..f5da10ec6a 100644 --- a/scripts/train/benchmarking/README.md +++ b/scripts/train/benchmarking/README.md @@ -20,7 +20,7 @@ python submit_benchmarks.py --cluster [your_mosaicml_cluster] ARGS --RUN can be used to sweep a larger set of configurations. For example usage of `submit_benchmarks.py` see `sweep.sh` which lists all benchmarks in the tables. > **Note** -> The `collect_results.py` will by default find all runs with `tput` in the run name. To customize this project tag, use `--project` in both the submissing and collection scripts. +> The `collect_results.py` will by default find all runs with `tput` in the run name. To customize this project tag, use `--project` in both the submission and collection scripts. ## MFU and HFU @@ -55,7 +55,7 @@ hfu* = 4 * flops_per_seq * seq_per_sec / (gpu_num * GPU_AVAILABLE_FLOPS) hfu = (4 * flops_per_seq + 4 * attn_flops_per_seq) * seq_per_sec / (gpu_num * GPU_AVAILABLE_FLOPS) ``` -Note that these are approximations. Actual HFU would be higher since it includes the floating point operations for normalization, activation, and residual lyaers, as well as **all** recomputation. For example, our models use Flash Attention, which requires including an extra recompute factor for its recomputation in the forward pass. Therefore, the attention multipler would be 5 instead of 4. +Note that these are approximations. Actual HFU would be higher since it includes the floating point operations for normalization, activation, and residual layers, as well as **all** recomputation. For example, our models use Flash Attention, which requires including an extra recompute factor for its recomputation in the forward pass. Therefore, the attention multiplier would be 5 instead of 4. ## Results @@ -65,7 +65,7 @@ python submit_benchmarks.py -m 13b.yaml 30b.yaml -t fp16 -b 21 21 -s 11 14 --RUN ``` This will run 8 configs for 12 steps to get throughput numbers. `python collect_results.py` can then be used to parse all output training logs and create the tables below. -Our microbatching engine enables microbatch sizes that do not divde Global Batchsize while being mathematically faithful to the global batch size. For example, a total batch size of 48, and a micro batch of 11, means we will accumulate gradients across microbatches of 11, 11, 11, 11, 4. +Our microbatching engine enables microbatch sizes that do not divide global batch size while being mathematically faithful to the global batch size. For example, a total batch size of 48, and a micro batch of 11, means we will accumulate gradients across microbatches of 11, 11, 11, 11, 4. [comment]: # (TODO: Update tables with torch 2.0 after next Composer release) diff --git a/scripts/train/benchmarking/collect_results.py b/scripts/train/benchmarking/collect_results.py index d3691e951c..151286dbc6 100644 --- a/scripts/train/benchmarking/collect_results.py +++ b/scripts/train/benchmarking/collect_results.py @@ -150,8 +150,8 @@ def parse_run(run: msdk.Run) -> Dict[str, Any]: d_model = run.submitted_config.parameters['model']['d_model'] n_layers = run.submitted_config.parameters['model']['n_layers'] - # mfu is approximated using thoughtput and param count - # the number of paramters is approximately the number of multiply-accumulates (MAC) in the network + # mfu is approximated using throughput and param count + # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param # there are 3 passes of a NN (fwd, bwd, delta) - we multiply by 3 ie 2 * 3 * n_param # this gets us FLOPs / token diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py index aff570e3d4..5e83ae41b7 100644 --- a/scripts/train/benchmarking/submit_benchmarks.py +++ b/scripts/train/benchmarking/submit_benchmarks.py @@ -205,7 +205,7 @@ def get_global_train_batch_sizes(max_seq_len: int, if batch_sizes is None: batch_sizes = [] if pows: - # global batch size in tokens (defualt: .5M thru 8M) + # global batch size in tokens (default: .5M thru 8M) global_train_token_counts = [2**n for n in range(pows[0], pows[1] + 1)] batch_sizes += [t // max_seq_len for t in global_train_token_counts ] # global batch size in samples diff --git a/scripts/train/train.py b/scripts/train/train.py index 96066d5a1d..a49ae4e26d 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -12,7 +12,6 @@ import torch from composer import Trainer from composer.core.callback import Callback -from composer.metrics.nlp import InContextLearningMetric from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler, cyclic_schedule) from composer.utils import dist, get_device, reproducibility @@ -20,12 +19,14 @@ from omegaconf import OmegaConf as om from rich.traceback import install +from llmfoundry.eval.metrics.nlp import InContextLearningMetric from llmfoundry.utils import (find_mosaicml_logger, log_train_analytics, maybe_create_mosaicml_logger) install() from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_algorithm, build_callback, build_composer_model, build_evaluators, @@ -102,7 +103,7 @@ def validate_config(cfg: DictConfig): ) if cfg.model.get('ffn_config', {}).get('ffn_type', - 'mptmlp') in ('mb_moe', 'mb_dmoe'): + 'mptmlp') in ffns_with_megablocks: moe_world_size = cfg.model.get('ffn_config', {}).get('moe_world_size', 1) use_orig_params = cfg.get('fsdp_config', diff --git a/scripts/train/yamls/finetune/1b_local_data_sft.yaml b/scripts/train/yamls/finetune/1b_local_data_sft.yaml index d7b9db10d4..46141ce5ab 100644 --- a/scripts/train/yamls/finetune/1b_local_data_sft.yaml +++ b/scripts/train/yamls/finetune/1b_local_data_sft.yaml @@ -16,7 +16,7 @@ model: name: mpt_causal_lm init_device: meta d_model: 2048 - n_heads: 16 # Modified 24->16 so that d_head == 128 to statisfy FlashAttention + n_heads: 16 # Modified 24->16 so that d_head == 128 to satisfy FlashAttention n_layers: 24 expansion_ratio: 4 max_seq_len: ${max_seq_len} diff --git a/scripts/train/yamls/pretrain/mpt-1b.yaml b/scripts/train/yamls/pretrain/mpt-1b.yaml index 3744a455a8..effa60c59e 100644 --- a/scripts/train/yamls/pretrain/mpt-1b.yaml +++ b/scripts/train/yamls/pretrain/mpt-1b.yaml @@ -11,7 +11,7 @@ model: name: mpt_causal_lm init_device: meta d_model: 2048 - n_heads: 16 # Modified 24->16 so that d_head == 128 to statisfy FlashAttention + n_heads: 16 # Modified 24->16 so that d_head == 128 to satisfy FlashAttention n_layers: 24 expansion_ratio: 4 max_seq_len: ${max_seq_len} diff --git a/setup.py b/setup.py index 086e759384..eb6c88af9e 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,9 @@ content = f.read() # regex: '__version__', whitespace?, '=', whitespace, quote, version, quote # we put parens around the version so that it becomes elem 1 of the match -expr = re.compile(r"""^__version__\W+=\W+['"]([0-9\.]*)['"]""", re.MULTILINE) +expr = re.compile( + r"""^__version__\s*=\s*['"]([0-9]+\.[0-9]+\.[0-9]+(?:\.\w+)?)['"]""", + re.MULTILINE) repo_version = expr.findall(content)[0] # Use repo README for PyPi description @@ -51,10 +53,10 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,oci,gcs]>=0.21.1,<0.22', - 'mlflow>=2.10,<3', + 'mosaicml[libcloud,wandb,oci,gcs]>=0.21.3,<0.22', + 'mlflow>=2.12.1,<2.13', 'accelerate>=0.25,<0.26', # for HF inference `device_map` - 'transformers>=4.39.3,<4.40', + 'transformers>=4.40,<4.41', 'mosaicml-streaming>=0.7.5,<0.8', 'torch>=2.2.1,<2.3', 'datasets>=2.16,<2.17', diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index ff885ac735..fe58a44459 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -147,6 +147,7 @@ def test_train_multi_eval(tmp_path: pathlib.Path): tuple) +@pytest.mark.gpu def test_validate_config(): conf_path: str = os.path.join( REPO_DIR, diff --git a/tests/callbacks/test_eval_gauntlet_callback.py b/tests/callbacks/test_eval_gauntlet_callback.py index 3a1e371ab8..8d9938e3a1 100644 --- a/tests/callbacks/test_eval_gauntlet_callback.py +++ b/tests/callbacks/test_eval_gauntlet_callback.py @@ -9,9 +9,9 @@ import torch from composer.core import State from composer.loggers import InMemoryLogger, Logger -from composer.metrics import InContextLearningLMAccuracy from transformers import AutoTokenizer +from llmfoundry.eval.metrics.nlp import InContextLearningLMAccuracy from llmfoundry.utils.builders import build_icl_data_and_gauntlet diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index c99ae6baf2..3eb5e3773d 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -21,8 +21,7 @@ from omegaconf import OmegaConf as om from streaming import MDSWriter -from llmfoundry import build_finetuning_dataloader -from llmfoundry.data import build_dataloader +from llmfoundry.data import build_dataloader, build_finetuning_dataloader from llmfoundry.data.finetuning.collator import (_HF_IGNORE_INDEX, validate_target_settings) from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH, @@ -42,6 +41,7 @@ InvalidPromptTypeError, InvalidResponseTypeError, InvalidRoleError, + MisconfiguredHfDatasetError, NotEnoughDatasetSamplesError, TooManyKeysInExampleError, UnknownExampleTypeError) @@ -268,6 +268,47 @@ def test_sequence_id_wrapper(eos_token_id: Optional[int], raise NotImplementedError() +def test_invalid_jsonl_data(): + max_seq_len = 2 + decoder_only_format = True + packing_ratio = 'auto' + allow_pad_trimming = False + cfg = { + 'name': 'finetuning', + 'dataset': { + 'hf_name': 'iamroot/chat_malformatted_examples', + 'split': 'train', + 'max_seq_len': max_seq_len, + 'decoder_only_format': decoder_only_format, + 'allow_pad_trimming': allow_pad_trimming, + 'packing_ratio': packing_ratio, + 'shuffle': True, + }, + 'drop_last': False, + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0 + } + + cfg = om.create(cfg) + + tokenizer = build_tokenizer( + tokenizer_name='gpt2', + tokenizer_kwargs={'model_max_length': max_seq_len}) + + device_batch_size = 2 + + expected_keys = ['input_ids', 'attention_mask', 'labels'] + if not decoder_only_format: + expected_keys += ['decoder_attention_mask', 'decoder_input_ids'] + + with pytest.raises(MisconfiguredHfDatasetError): + build_finetuning_dataloader(cfg, tokenizer, + device_batch_size).dataloader + + @pytest.mark.parametrize('use_chat_formatting', [True, False]) @pytest.mark.parametrize('decoder_only_format', [True, False]) @pytest.mark.parametrize('allow_pad_trimming', [True, False]) diff --git a/tests/data/test_tasks.yaml b/tests/data/test_tasks.yaml index cec7984320..cf02ffcbbb 100644 --- a/tests/data/test_tasks.yaml +++ b/tests/data/test_tasks.yaml @@ -20,4 +20,4 @@ icl_tasks: label: triviaqa dataset_uri: scripts/eval/local_data/world_knowledge/triviaqa_small.jsonl # ADD YOUR OWN DATASET URI num_fewshot: [0, 1] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 632a79dac9..756912342f 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -252,7 +252,7 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool): def test_tokenize_no_labels_bos_pr(): # This tokenizer automatically adds bos tokens tokenizer = transformers.AutoTokenizer.from_pretrained( - 'mistralai/Mixtral-8x7B-v0.1') + 'ai21labs/Jamba-v0.1', add_bos_token=True) example = {'prompt': 'prompt', 'response': 'response'} @@ -270,7 +270,7 @@ def test_tokenize_no_labels_bos_pr(): # This tokenizer does not have the add_bos_token attribute tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b') - assert not hasattr(tokenizer, 'add_bos_token') + assert not tokenizer.add_bos_token tokenized_example = tokenize_formatted_example(example, tokenizer) diff --git a/tests/eval/local_data/gsm8k_small.jsonl b/tests/eval/local_data/gsm8k_small.jsonl new file mode 100644 index 0000000000..522966c902 --- /dev/null +++ b/tests/eval/local_data/gsm8k_small.jsonl @@ -0,0 +1,4 @@ +{"context": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", "chain_of_thought": "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", "answer": "18"} +{"context": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?", "chain_of_thought": "It takes 2/2=<<2/2=1>>1 bolt of white fiber\nSo the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric", "answer": "3"} +{"context": "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?", "chain_of_thought": "The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000", "answer": "70000"} +{"context": "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?", "chain_of_thought": "He sprints 3*3=<<3*3=9>>9 times\nSo he runs 9*60=<<9*60=540>>540 meters", "answer": "540"} diff --git a/tests/eval/local_data/hellaswag_small.jsonl b/tests/eval/local_data/hellaswag_small.jsonl new file mode 100644 index 0000000000..d2e37771c9 --- /dev/null +++ b/tests/eval/local_data/hellaswag_small.jsonl @@ -0,0 +1,4 @@ +{"query": "Removing ice from car: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. Then", "choices": [", the man adds wax to the windshield and cuts it.", ", a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled.", ", the man puts on a christmas coat, knitted with netting.", ", the man continues removing the snow on his car."], "gold": 3} +{"query": "Baking cookies: A female chef in white uniform shows a stack of baking pans in a large kitchen presenting them. The pans", "choices": ["contain egg yolks and baking soda.", "are then sprinkled with brown sugar.", "are placed in a strainer on the counter.", "are filled with pastries and loaded into the oven."], "gold": 3} +{"query": "Baking cookies: A female chef in white uniform shows a stack of baking pans in a large kitchen presenting them. The pans are filled with pastries and loaded into the oven. A knife", "choices": ["is seen moving on a board and cutting out its contents.", "hits the peeled cheesecake, followed by sliced custard and still cooked ice cream.", "etches a shape into the inside of the baked pans.", "is used to cut cylinder shaped dough into rounds."], "gold": 3} +{"query": "Baking cookies: A tray of potatoes is loaded into the oven and removed. A large tray of cake is flipped over and placed on counter. A large tray of meat", "choices": ["is placed onto a baked potato.", ", ls, and pickles are placed in the oven.", "is poured into a midden.", "is prepared then it is removed from the oven by a helper when done."], "gold": 3} diff --git a/tests/eval/local_data/human_eval_small.jsonl b/tests/eval/local_data/human_eval_small.jsonl new file mode 100644 index 0000000000..850d46e031 --- /dev/null +++ b/tests/eval/local_data/human_eval_small.jsonl @@ -0,0 +1,4 @@ +{"task_id": "HumanEval/0", "prompt": "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", "entry_point": "has_close_elements", "canonical_solution": " for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n\n return False\n", "test": "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(candidate):\n assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\n assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\n assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\n assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\n assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\n assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\n\n", "test_inputs": ["([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3)", "([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05)", "([1.0, 2.0, 5.9, 4.0, 5.0], 0.95)", "([1.0, 2.0, 5.9, 4.0, 5.0], 0.8)", "([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1)", "([1.1, 2.2, 3.1, 4.1, 5.1], 1.0)", "([1.1, 2.2, 3.1, 4.1, 5.1], 0.5)"], "test_outputs": ["True", "False", "True", "False", "True", "True", "False"], "language": "python"} +{"task_id": "HumanEval/1", "prompt": "from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", "entry_point": "separate_paren_groups", "canonical_solution": " result = []\n current_string = []\n current_depth = 0\n\n for c in paren_string:\n if c == '(':\n current_depth += 1\n current_string.append(c)\n elif c == ')':\n current_depth -= 1\n current_string.append(c)\n\n if current_depth == 0:\n result.append(''.join(current_string))\n current_string.clear()\n\n return result\n", "test": "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(candidate):\n assert candidate('(()()) ((())) () ((())()())') == [\n '(()())', '((()))', '()', '((())()())'\n ]\n assert candidate('() (()) ((())) (((())))') == [\n '()', '(())', '((()))', '(((())))'\n ]\n assert candidate('(()(())((())))') == [\n '(()(())((())))'\n ]\n assert candidate('( ) (( )) (( )( ))') == ['()', '(())', '(()())']\n", "test_inputs": ["('(()()) ((())) () ((())()())',)", "('() (()) ((())) (((())))',)", "('(()(())((())))',)", "('( ) (( )) (( )( ))',)"], "test_outputs": ["['(()())', '((()))', '()', '((())()())']", "['()', '(())', '((()))', '(((())))']", "['(()(())((())))']", "['()', '(())', '(()())']"], "language": "python"} +{"task_id": "HumanEval/2", "prompt": "\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", "entry_point": "truncate_number", "canonical_solution": " return number % 1.0\n", "test": "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(candidate):\n assert candidate(3.5) == 0.5\n assert abs(candidate(1.33) - 0.33) < 1e-6\n assert abs(candidate(123.456) - 0.456) < 1e-6\n", "test_inputs": ["(3.5,)", "(1.33,)", "(123.456,)"], "test_outputs": ["0.5", "0.33000000000000007", "0.45600000000000307"], "language": "python"} +{"task_id": "HumanEval/3", "prompt": "from typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n", "entry_point": "below_zero", "canonical_solution": " balance = 0\n\n for op in operations:\n balance += op\n if balance < 0:\n return True\n\n return False\n", "test": "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(candidate):\n assert candidate([]) == False\n assert candidate([1, 2, -3, 1, 2, -3]) == False\n assert candidate([1, 2, -4, 5, 6]) == True\n assert candidate([1, -1, 2, -2, 5, -5, 4, -4]) == False\n assert candidate([1, -1, 2, -2, 5, -5, 4, -5]) == True\n assert candidate([1, -2, 2, -2, 5, -5, 4, -4]) == True\n", "test_inputs": ["([],)", "([1, 2, -3, 1, 2, -3],)", "([1, 2, -4, 5, 6],)", "([1, -1, 2, -2, 5, -5, 4, -4],)", "([1, -1, 2, -2, 5, -5, 4, -5],)", "([1, -2, 2, -2, 5, -5, 4, -4],)"], "test_outputs": ["False", "False", "True", "False", "True", "True"], "language": "python"} diff --git a/tests/eval/local_data/lambada_small.jsonl b/tests/eval/local_data/lambada_small.jsonl new file mode 100644 index 0000000000..5a0dc238ae --- /dev/null +++ b/tests/eval/local_data/lambada_small.jsonl @@ -0,0 +1,4 @@ +{"context": "With Tristran's next step he was standing beside a lake, and the candlelight shone brightly on the water; and then he was walking through the mountains, through lonely crags, where the candlelight was reflected in the eyes of the creatures of the high snows; and then he was walking through the clouds, which, while not entirely substantial, still supported his weight in comfort; and then, holding tightly to his candle, he was underground, and the candlelight glinted back at him from the wet cave walls; now he was in the mountains once more; and then he was on a road through wild forest, and he glimpsed a chariot being pulled by two goats, being driven by a woman in a red dress who looked, for the glimpse he got of her, the way Boadicea was drawn in his history books; and another step and he was in a leafy glen, and he could hear the chuckle of water as it splashed and sang its way into a small brook.\n\nHe took another step, but he was still in the", "continuation": "glen"} +{"context": "Todd replied: No I thought you looked familiar but I can’t recall. The stranger told Todd: I’m Enoch; we met in your dream. Todd looked back again, this time he realized it really was Enoch; Todd stopped on the side of the road, leaned back and tried to see if he was dreaming. When Enoch said: No Todd you’re not", "continuation": "dreaming"} +{"context": "The Librarian thumbed through the bundle of pages, stopping on the final sheet and began reading, “It is our conclusion that much of the work that is currently done in the Library can be out-sourced to contractors, particularly non-skill specific work such as shelving, stacking...”\nLucy gulped and Gillian began to open her mouth to protest again, but the Librarian carried on regardless, his voice becoming louder in order to drown out any potentially dissenting voices, “... blah, blah, blah. It is our recommendation that a downsizing of the non-essential and part-time members of staff would bring instant economy of scale benefits and would allow for the implementation of a new middle management structure.”\n“You mean sacrifice the troops to pay for the generals,” said", "continuation": "Gillian"} +{"context": "He was small, even for a dwarf, and his poor taste in sorcerous robes contrasted awkwardly with D’jebee’s elegant attire; her long, diaphanous gown and his chemical-stained, star-spangled robe clashed almost as much as her vacuous expression alongside his own visage, alive as it was with cunning and a twisted intelligence.\n\nD’jebee sighed with boredom.\n\n‘What is it, my love?’ Poldanyelz oozed with ersatz concern.\n\n‘I’m bored,’ D’jebee complained undiplomatically. ‘No one ever comes here. I never see anyone except you.’\n\nA shuffling from the main arch alerted her to the inaccuracy of her", "continuation": "statement"} diff --git a/tests/eval/local_data/mmlu_small.jsonl b/tests/eval/local_data/mmlu_small.jsonl new file mode 100644 index 0000000000..90eb402607 --- /dev/null +++ b/tests/eval/local_data/mmlu_small.jsonl @@ -0,0 +1,4 @@ +{"query": "Question: How is IP address spoofing detected?\n(A) Installing and configuring a IDS that can read the IP header (B) Comparing the TTL values of the actual and spoofed addresses (C) Implementing a firewall to the network (D) Identify all TCP sessions that are initiated but does not complete successfully\n", "gold": 1, "choices": ["A", "B", "C", "D"], "category": "computer_security"} +{"query": "Question: Which of the following is not an example of presentation layer issues?\n(A) Poor handling of unexpected input can lead to the execution of arbitrary instructions (B) Unintentional or ill-directed use of superficially supplied input (C) Cryptographic flaws in the system may get exploited to evade privacy (D) Weak or non-existent authentication mechanisms\n", "gold": 3, "choices": ["A", "B", "C", "D"], "category": "computer_security"} +{"query": "Question: Suppose Unix did not provide a way of passing file descriptors between processes, but still allowed inheriting file descriptors from a parent on fork and exec. What aspects of the OKWS design would break without file descriptor passing?\n1. It would be impossible for services to send messages to oklogd.\n2. It would be impossible for services to get a TCP connection to a database proxy.\n(A) True, True (B) False, False (C) True, False (D) False, True\n", "gold": 1, "choices": ["A", "B", "C", "D"], "category": "computer_security"} +{"query": "Question: Why would a ping sweep be used?\n(A) To identify live systems (B) To locate live systems (C) To identify open ports (D) To locate firewalls\n", "gold": 0, "choices": ["A", "B", "C", "D"], "category": "computer_security"} diff --git a/tests/eval/local_data/piqa_small.jsonl b/tests/eval/local_data/piqa_small.jsonl new file mode 100644 index 0000000000..07b1b27509 --- /dev/null +++ b/tests/eval/local_data/piqa_small.jsonl @@ -0,0 +1,4 @@ +{"choices": ["Pour it onto a plate", "Pour it into a jar"], "gold": 1, "query": "When boiling butter, when it's ready, you can"} +{"choices": ["Weld the metal together to get it to stay firmly in place", "Nail the metal together to get it to stay firmly in place"], "gold": 0, "query": "To permanently attach metal legs to a chair, you can"} +{"choices": ["leave a space before starting the writing", "press the spacebar"], "gold": 0, "query": "how do you indent something?"} +{"choices": ["move it up and down and side to side quickly.", "stir it very quickly."], "gold": 0, "query": "how do you shake something?"} diff --git a/tests/eval/local_data/pubmed_sm.jsonl b/tests/eval/local_data/pubmed_sm.jsonl new file mode 100644 index 0000000000..c39bab0b04 --- /dev/null +++ b/tests/eval/local_data/pubmed_sm.jsonl @@ -0,0 +1,4 @@ +{"context": "Context: PURPOSE. To assess whether eligibility to an adjuvant chemotherapy protocol in itself represents a good prognostic factor after radical cystectomy for bladder cancer.\nPATIENTS AND METHODS. Between April 1984 and May 1989, our institution entered 35 patients with invasive bladder cancer into the Swiss Group for Clinical and Epidemiological Cancer Research (SAKK) study 09/84. They were randomly assigned to either observation or three postoperative courses of cisplatin monotherapy after cystectomy. This study had a negative result. The outcome of these 35 patients (protocol group) was compared with an age- and tumor-stage-matched cohort (matched group; n = 35) who also underwent cystectomy during the same period, but were not entered into the SAKK study, as well as the remaining 57 patients treated during the study period for the same indication (remaining group).\nRESULTS. Median overall survival decreased from 76.3 months in the protocol group to 52.1 months in the matched group and to 20.3 months in the remaining group. The respective times of median recurrence-free survival were 67.2, 16.0, and 9.4 months. Tumor progression occurred in 46% of the protocol group compared with 69% in the matched group and 65% in the remaining group (P<.05). Cancer-related death was noted in 40% of the protocol group, 57% in the matched group, and 56% in the remaining group.\nQuestion: Is eligibility for a chemotherapy protocol a good prognostic factor for invasive bladder cancer after radical cystectomy?\nA. yes\nB. no\nC. maybe\nAnswer: ", "continuation": "yes"} +{"context": "Context: BACKGROUND. This study was performed to describe the treatment plan modifications after a geriatric oncology clinic. Assessment of health and functional status and cancer assessment was performed in older cancer patients referred to a cancer center.\nPATIENTS AND METHODS. Between June 2004 and May 2005, 105 patients 70 years old or older referred to a geriatric oncology consultation at the Institut Curie cancer center were included. Functional status, nutritional status, mood, mobility, comorbidity, medication, social support, and place of residence were assessed. Oncology data and treatment decisions were recorded before and after this consultation. Data were analyzed for a possible correlation between one domain of the assessment and modification of the treatment plan.\nRESULTS. Patient characteristics included a median age of 79 years and a predominance of women with breast cancer. About one half of patients had an independent functional status. Nearly 15% presented severe undernourishment. Depression was suspected in 53.1% of cases. One third of these patients had>2 chronic diseases, and 74% of patients took>or =3 medications. Of the 93 patients with an initial treatment decision, the treatment plan was modified for 38.7% of cases after this assessment. Only body mass index and the absence of depressive symptoms were associated with a modification of the treatment plan.\nQuestion: Does a geriatric oncology consultation modify the cancer treatment plan for elderly patients?\nA. yes\nB. no\nC. maybe\nAnswer: ", "continuation": "yes"} +{"context": "Context: BACKGROUND. The alterations of echocardiography and electrocardiogram (ECG) in patients received left atrial appendage LAA occlusion therapy are still unclear. The present study was to evaluate the influence of LAA occlusion device on echocardiography and ECG changes in patients with atrial fibrillation (AF).\nMETHODS. Seventy-three patients who had undergone Watchman, LAmbre and Lefort were enrolled in this study. Echocardiography and ECG results at pre- and post-operation were collected. Besides, echocardiography was also performed during follow-up visits at 1, 6 and 12months after discharge.\nRESULTS. After LAA occlusion, a slight and measureable movement of QRS electric axis was observed in most patients. The significant differences were also observed in heart rate (HR) and the mean-mean QT interval between pre- and post-operation for all patients. There existed no significant difference in echocardiographic parameters between before and after device implantation. However, a larger left atrial (LA) diameter was detected by echocardiography during follow-up visit at 6months when compared with pre-operation parameters. Similarly, aortic root diameter (ARD) was also larger during follow-up at 12months than the baseline dimension in pre-operation.\nQuestion: Does left atrial appendage (LAA) occlusion device alter the echocardiography and electrocardiogram parameters in patients with atrial fibrillation?\nA. yes\nB. no\nC. maybe\nAnswer: ", "continuation": "yes"} +{"context": "Context: BACKGROUND. Currently the choice of breast cancer therapy is based on prognostic factors. The proliferation marker Ki-67 is used increasingly to determine the method of therapy. The current study analyses the predictive value of Ki-67 in foreseeing breast cancer patients' responses to neoadjuvant chemotherapy.\nMETHODS. This study includes patients with invasive breast cancer treated between 2008 and 2013. The clinical response was assessed by correlating Ki-67 to histological examination, mammography, and ultrasonography findings.\nRESULTS. The average Ki-67 value in our patients collectively (n = 77) is 34.9 ± 24.6%. The average Ki-67 value is the highest with 37.4 ± 24.0% in patients with a pCR. The Ki-67 values do not differ significantly among the 3 groups: pCR versus partial pathological response versus stable disease/progress (P = 0.896). However, Ki-67 values of patients with luminal, Her2 enriched, and basal-like cancers differed significantly from each other. Furthermore, within the group of luminal tumors Ki-67 values of patients with versus without pCR also differed significantly.\nQuestion: Can ki-67 play a role in prediction of breast cancer patients' response to neoadjuvant chemotherapy?\nA. yes\nB. no\nC. maybe\nAnswer: ", "continuation": "yes"} diff --git a/tests/eval/local_data/triviaqa_small.jsonl b/tests/eval/local_data/triviaqa_small.jsonl new file mode 100644 index 0000000000..ae5e0783d9 --- /dev/null +++ b/tests/eval/local_data/triviaqa_small.jsonl @@ -0,0 +1,4 @@ +{"context": "Who was the man behind The Chipmunks?", "answer": "David Seville", "aliases": ["David Seville"]} +{"context": "What star sign is Jamie Lee Curtis?", "answer": "Scorpio", "aliases": ["Scorpio", "Skorpio"]} +{"context": "Which Lloyd Webber musical premiered in the US on 10th December 1993?", "answer": "Sunset Boulevard", "aliases": ["Sunset Blvd", "Sunset Boulevard", "Sunset Bulevard", "West Sunset Boulevard"]} +{"context": "Who was the next British Prime Minister after Arthur Balfour?", "answer": "Campbell-Bannerman", "aliases": ["Campbell Bannerman", "Campbell-Bannerman", "Henry Campbell Bannerman", "Henry Campbell-Bannerman", "Sir Henry Campbell Bannerman", "Sir Henry Campbell-Bannerman"]} diff --git a/tests/eval/local_data/winograd_small.jsonl b/tests/eval/local_data/winograd_small.jsonl new file mode 100644 index 0000000000..8f84cd27e5 --- /dev/null +++ b/tests/eval/local_data/winograd_small.jsonl @@ -0,0 +1,4 @@ +{"context_options": ["The city councilmen refused the demonstrators a permit because the city councilmen", "The city councilmen refused the demonstrators a permit because the demonstrators"], "continuation": "feared violence.", "gold": 0} +{"context_options": ["The city councilmen refused the demonstrators a permit because the city councilmen", "The city councilmen refused the demonstrators a permit because the demonstrators"], "continuation": "advocated violence.", "gold": 1} +{"context_options": ["The trophy doesn't fit into the brown suitcase because the trophy", "The trophy doesn't fit into the brown suitcase because the suitcase"], "continuation": "is too large.", "gold": 0} +{"context_options": ["The trophy doesn't fit into the brown suitcase because the trophy", "The trophy doesn't fit into the brown suitcase because the suitcase"], "continuation": "is too small.", "gold": 1} diff --git a/tests/eval/test_in_context_learning_datasets.py b/tests/eval/test_in_context_learning_datasets.py new file mode 100644 index 0000000000..33a041aaea --- /dev/null +++ b/tests/eval/test_in_context_learning_datasets.py @@ -0,0 +1,2841 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import contextlib +import os +import random +import types +from pathlib import Path +from typing import Dict, List, Optional + +import pytest +import torch +from composer import Evaluator +from composer.core import DataSpec +from torch.utils.data import DataLoader + +# isort: off +from llmfoundry.eval.datasets import ( + InContextLearningDataset, InContextLearningCodeEvalDataset, + InContextLearningMultipleChoiceTaskDataset, + InContextLearningGenerationTaskWithAnswersDataset, + InContextLearningSchemaTaskDataset, get_icl_task_dataloader, strip_data, + tokenizer_needs_prefix_space, trim_context, get_continuation_span, + get_fewshot_sample_idxs, make_padded_input) +# isort: on +import transformers +from composer.datasets.utils import MultiTokenEOSCriteria +from composer.loggers import InMemoryLogger +from composer.models import HuggingFaceModel +from composer.trainer import Trainer +from composer.utils import dist, reproducibility + +from llmfoundry.eval.metrics import ( + InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningMultipleChoiceAccuracy) + + +def test_strip_data(): + data_to_strip = { + 'strip_data': ' boo! \n', + 'has_space': ' wa hoo!', + 'end_space': 'yoohoo! ' + } + stripped_data = strip_data(data_to_strip) + for k, v in stripped_data.items(): + assert k in data_to_strip + assert not v[0].isspace() + assert not v[-1].isspace() + + +@pytest.mark.skip( + reason="Currently don't have a tokenizer that satisfies this test") +def test_tokenizer_needs_prefix_space_when_space_not_needed( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + assert not tokenizer_needs_prefix_space(tiny_gpt2_tokenizer) + + +def test_tokenizer_needs_prefix_space_when_space_needed(): + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m', + use_fast=False) # type: ignore reportUnboundVariable + assert tokenizer_needs_prefix_space(tokenizer) + + +def test_trim_context(): + context = [0] * 99 + [1] * 2037 + continuation = [2] * 10 + max_seq_len = 2048 + trimmed_context = trim_context(context, + continuation, + max_seq_len=max_seq_len) + assert len(trimmed_context) == 2038 + assert trimmed_context[0] == 0 + assert trimmed_context[1] == 1 + + +def test_trim_context_no_continuation(): + context = [0] * 2048 + max_seq_len = 2048 + trimmed_context = trim_context(context, [], max_seq_len=max_seq_len) + assert len(trimmed_context) == 2048 + context = [0] * 3000 + [1] + max_seq_len = 2048 + trimmed_context = trim_context(context, [], max_seq_len=max_seq_len) + assert len(trimmed_context) == 2048 + assert trimmed_context[-1] == 1 + + +def test_get_continuation_span(): + context = [0] * 200 + continuation = [1] * 3 + cont_span = get_continuation_span(context, continuation) + assert torch.all(torch.eq(cont_span, torch.tensor([200, 201, 202]))) + continuation = [1] + cont_span = get_continuation_span(context, continuation) + assert torch.all(torch.eq(cont_span, torch.tensor([200]))) + + +@pytest.mark.parametrize('padding_side', ['left', 'right', 'middle']) +def test_make_padding(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + padding_side: str): + context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] + padding_id = tiny_gpt2_tokenizer.eos_token_id + + error_context = contextlib.nullcontext() if padding_side in { + 'left', 'right' + } else pytest.raises(ValueError) + + with error_context: + input_ids = make_padded_input(context, [], + 2048, + padding_id, + padding_side=padding_side) + + if padding_side == 'left': + assert input_ids[0] == tiny_gpt2_tokenizer.eos_token_id + assert input_ids[48:].tolist() == context + elif padding_side == 'right': + assert input_ids[-1] == tiny_gpt2_tokenizer.eos_token_id + assert input_ids[:-48].tolist() == context + + +def test_batch_padding_logic_no_padding( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + continuation = tiny_gpt2_tokenizer(' dog' * 2000)['input_ids'] + context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] + max_seq_len = 2048 + trimmed_context = trim_context(context, continuation, max_seq_len) + continuation_spans = get_continuation_span(trimmed_context, continuation) + padded_input = make_padded_input(trimmed_context, + continuation, + max_seq_len, + tiny_gpt2_tokenizer.pad_token_id, + padding_side='right') + assert continuation_spans[0] == 48 and continuation_spans[-1] == 2047 + assert len(padded_input) == 2048 + assert tiny_gpt2_tokenizer.pad_token_id not in padded_input + + +def test_batch_padding_logic_with_padding( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + continuation = tiny_gpt2_tokenizer(' dog' * 200)['input_ids'] + context = tiny_gpt2_tokenizer(' cat' * 200)['input_ids'] + max_seq_len = 2048 + trimmed_context = trim_context(context, continuation, max_seq_len) + continuation_spans = get_continuation_span(trimmed_context, continuation) + padded_input = make_padded_input(trimmed_context, + continuation, + max_seq_len, + tiny_gpt2_tokenizer.pad_token_id, + padding_side='right') + assert continuation_spans[0] == 200 and continuation_spans[-1] == 399 + assert len(padded_input) == 2048 + assert padded_input[-1] == tiny_gpt2_tokenizer.pad_token_id + + +def test_fewshot_sample_idxs(): + rng = random.Random(1234) + + fewshot_idxs = get_fewshot_sample_idxs(dataset_size=5, + num_fewshot=4, + example_idx=4, + rng=rng) + assert fewshot_idxs == {0, 1, 2, 3} + + fewshot_idxs = get_fewshot_sample_idxs(dataset_size=5, + num_fewshot=5, + example_idx=4, + rng=rng) + assert fewshot_idxs == {0, 1, 2, 3} + + fewshot_idxs = get_fewshot_sample_idxs(dataset_size=5, + num_fewshot=500, + example_idx=4, + rng=rng) + assert fewshot_idxs == {0, 1, 2, 3} + + fewshot_idxs = get_fewshot_sample_idxs(dataset_size=10, + num_fewshot=7, + example_idx=4, + rng=rng) + assert len(fewshot_idxs) == 7 and 4 not in fewshot_idxs + + +def test_fewshot_sample_idxs_randomness(): + dataset_size = 10000 + num_fewshot = 5 + + rng_1_seed_1234 = random.Random(1234) + rng_2_seed_1234 = random.Random(1234) + rng_3_seed_11 = random.Random(11) + + rng_1_sample_1 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 1, + rng_1_seed_1234) + rng_2_sample_1 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 1, + rng_2_seed_1234) + rng_3_sample_1 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 1, + rng_3_seed_11) + + assert rng_1_sample_1 == rng_2_sample_1 + assert rng_1_sample_1 != rng_3_sample_1 + + rng_1_sample_2 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 2, + rng_1_seed_1234) + rng_2_sample_2 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 2, + rng_2_seed_1234) + rng_3_sample_2 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 2, + rng_3_seed_11) + + assert rng_1_sample_2 == rng_2_sample_2 + assert rng_1_sample_2 != rng_3_sample_2 + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_update_generation_kwargs( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + gen_kwargs = {'test_arg1': 1, 'test_arg2': 2} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=gen_kwargs) + assert dl.base_batch['generation_kwargs'] == { + 'test_arg1': 1, + 'test_arg2': 2 + } + + +def test_stop_sequences_criteria( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2) + seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids'] + seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] + seq1 = [tiny_gpt2_tokenizer.pad_token_id] * (len(seq2) - len(seq1)) + seq1 + input_ids = torch.LongTensor([seq1, seq2]) + assert not eos_criteria(input_ids, + None) # pyright: ignore[reportGeneralTypeIssues] + + eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2) + seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] + seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] + input_ids = torch.LongTensor([seq1, seq2]) + assert eos_criteria(input_ids, + None) # pyright: ignore[reportGeneralTypeIssues] + + +def test_stop_sequences_criteria_sentencepiece( + tiny_llama_tokenizer: transformers.AutoTokenizer): + + tokenizer = tiny_llama_tokenizer + eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2) + seq1 = tokenizer( + '\n\nDogs' + )['input_ids'] # check to make sure starting with the stop sequence doesnt break it + seq2 = tokenizer('Dogs are furry\n\n')['input_ids'] + seq1 = [tokenizer.eos_token_id] * (len(seq2) - len(seq1)) + seq1 + input_ids = torch.LongTensor([seq1, seq2]) + assert not eos_criteria(input_ids, + None) # pyright: ignore[reportGeneralTypeIssues] + + eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2) + seq1 = tokenizer('Dogs are furry\n\n')['input_ids'] + seq2 = tokenizer('Dogs are furry\n\n')['input_ids'] + input_ids = torch.LongTensor([seq1, seq2]) + assert eos_criteria(input_ids, + None) # pyright: ignore[reportGeneralTypeIssues] + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_update_generation_kwargs_no_kwargs( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + assert not 'generation_kwargs' in dl.base_batch + + +def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generation_kwargs=None) + assert len(dl.base_batch['generation_kwargs']) == 4 + + +def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generation_kwargs={'temperature': 0.9}) + assert 'generation_kwargs' in dl.base_batch + assert dl.base_batch['generation_kwargs']['temperature'] == 0.9 + assert len(dl.base_batch['generation_kwargs']) == 5 + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_construct_context(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell: ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + constructed_context = dl.construct_context({ + 'context': 'quas quas exort', + 'answer': 'ice wall' + }) + assert constructed_context == 'Orbs: quas quas exort\nSpell: ' + constructed_context = dl.construct_context( + { + 'context': 'quas quas exort', + 'answer': 'ice wall' + }, add_answer=True) + assert constructed_context == 'Orbs: quas quas exort\nSpell: ice wall' + constructed_context = dl.construct_context( + { + 'context': 'quas quas exort', + 'answer': 'ice wall' + }, + preceding_text='The harsh White Waste beckons!', + add_answer=True) + assert constructed_context == '\nOrbs: quas quas exort\nSpell: ice wall' + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_get_answer_from_example( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + answer = dl.get_answer_from_example({ + 'context': 'wex exort exort', + 'answer': 'alacrity' + }) + assert answer == ' alacrity' + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_fix_eos_on_preamble(tmp_path: Path): + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m', + use_fast=False) # type: ignore reportUnboundVariable + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + preamble = 'blah blah blah.' + tokenized_preamble = tokenizer.encode(preamble) + tokenized_preamble += [tokenizer.eos_token_id] + fixed_preamble = dl._fix_eos_on_preamble(tokenized_preamble) + assert tokenized_preamble[:-1] == fixed_preamble + assert fixed_preamble[-1] != tokenizer.eos_token_id + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_tokenize_example_with_tokenize_labels( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell: ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + tokenize_labels=True) + tokenized_example = dl.tokenize_example('What spell does this invoke? ', + 'exort exort wex\nSpell: ', + {'answer': ' Meatball'}) + tokenized_input = [ + 2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, + 31221, 25, 19145, 1894 + ] + assert tokenized_example['context'][:len(tokenized_input)].tolist( + ) == tokenized_input + assert tokenized_example['context'][-1] == tokenizer.eos_token_id + assert type(tokenized_example['answer'][0]) == int + assert len(tokenized_example['context']) == seqlen + assert 'continuation_indices' in tokenized_example + + +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_tokenize_example_with_no_tokenize_labels( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + hf_loading_vars = { + 'split': 'test', + 'name': 'invoker', + } + hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']} + + dl = InContextLearningDataset( + dataset_uri='hf://mosaicml/test_dataset', + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Orbs: ', + continuation_delimiter='\nSpell: ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + tokenize_labels=False) + tokenized_example = dl.tokenize_example('What spell does this invoke? ', + 'exort exort wex\nSpell: ', + {'answer': ' Meatball'}) + tokenized_input = [ + 2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, + 31221, 25 + ] + assert tokenized_example['context'][:len(tokenized_input)].tolist( + ) == tokenized_input + assert tokenized_example['context'][-1] == tokenizer.eos_token_id + assert len(tokenized_example['context']) == seqlen + assert type(tokenized_example['answer']) == str + + +def test_qa_set_cot_no_cot(tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + assert not dl.has_cot + + +def test_qa_set_cot_has_cot(tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/gsm8k_small.jsonl' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + assert dl.has_cot + + +def test_qa_get_max_answer_length( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='', + continuation_delimiter='', + cot_delimiter='', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + # empirical number from the small test dataset + assert dl.max_answer_length == 7 + + +def test_qa_get_answer_from_example_with_no_cot( + tmp_path: Path, tiny_gpt2_tokenizer: transformers.AutoTokenizer): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tiny_gpt2_tokenizer, + max_seq_len=1024, + pad_tok_id=tiny_gpt2_tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + cot_delimiter=' ### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + answer = dl.get_answer_from_example({ + 'context': 'empty', + 'answer': 'this is the correct answer', + 'chain_of_thought': "Let's think step by step. " + }) + assert answer == 'this is the correct answer' + + +def test_qa_get_answer_from_example_with_cot( + tmp_path: Path, tiny_gpt2_tokenizer: transformers.AutoTokenizer): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tiny_gpt2_tokenizer, + max_seq_len=1024, + pad_tok_id=tiny_gpt2_tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + cot_delimiter=' ### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + dl.has_cot = True + answer = dl.get_answer_from_example({ + 'context': 'empty', + 'answer': 'this is the correct answer', + 'chain_of_thought': "Let's think step by step. " + }) + assert answer == "Let's think step by step. ### this is the correct answer" + + +def test_qa_tokenize_example(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningGenerationTaskWithAnswersDataset( + dataset_uri=dataset_uri, + tokenizer=tiny_gpt2_tokenizer, + max_seq_len=1024, + pad_tok_id=tiny_gpt2_tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + cot_delimiter=' ### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + dl.has_cot = True + tokenized_example = dl.tokenize_example( + 'starting prompt', 'a context', { + 'context': 'empty', + 'answer': 'this is the correct answer', + 'aliases': ['this is the right answer', 'this is the best answer'], + 'chain_of_thought': "Let's think step by step. " + }) + assert 'aliases' in tokenized_example + assert tokenized_example['aliases'] == [ + 'this is the right answer', 'this is the best answer' + ] + + +def test_code_adjust_padding(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/human_eval_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + gen_kwargs = {'temperature': .9, 'top_p': .95, 'num_beams': 9000} + + dl = InContextLearningCodeEvalDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Code start:', + continuation_delimiter='\nPlease code:', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + generation_kwargs=gen_kwargs, + generations_per_sample=10, + ) + + assert all( + len(data['prompt']) == 148 + for data in dl.dataset) # pyright: ignore [reportGeneralTypeIssues] + + +def test_code_update_gen_kwargs(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/human_eval_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + gen_kwargs = {'temperature': .9, 'top_p': .95, 'num_beams': 9000} + + dl = InContextLearningCodeEvalDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + prelimiter='Code start:', + continuation_delimiter='\nPlease code:', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + generation_kwargs=gen_kwargs, + generations_per_sample=10, + ) + assert dl.base_batch['generation_kwargs']['num_beams'] == 9000 + assert dl.base_batch['generation_kwargs']['top_p'] == .95 + assert dl.base_batch['generation_kwargs']['temperature'] == .9 + assert dl.base_batch['generation_kwargs']['do_sample'] == True + + +def test_mc_tokenize_example(tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/mmlu_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + seqlen = 2048 + dl = InContextLearningMultipleChoiceTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = { + 'context': + "Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: ", + 'choices': ['A', 'B', 'C', 'D'], + 'gold': + 2 + } + tokenized_example = dl.tokenize_example( + prompt_and_fewshot='Answer the following: ', + ctxt=example['context'], + example=example) + unpadded_queries = [ + context[context != tokenizer.eos_token_id] + for context in tokenized_example['query'] + ] + untokenized_inputs = [ + tokenizer.decode(unpadded_input) for unpadded_input in unpadded_queries + ] + correct_output = [ + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: A", + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: B", + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: C", + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: D" + ] + assert untokenized_inputs == correct_output + + +@pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) +def test_schema_construct_context( + prelimiter: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/winograd_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + seqlen = 2048 + dl = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string='', + prelimiter=prelimiter, + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = { + 'context_options': ['cont one', 'cont two'], + 'gold': 0, + 'continuation': 'this is a continuation' + } + constructed_context = dl.construct_context(example) + assert constructed_context == f'{prelimiter}cont one ### this is a continuation' + constructed_context = dl.construct_context(example, preceding_text='text') + assert constructed_context == f'{prelimiter}\ncont one ### this is a continuation' + + +@pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) +def test_schema_construct_multiple_contexts( + prelimiter: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/winograd_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + seqlen = 2048 + dl = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prelimiter=prelimiter, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter=' ### ', + destination_path=str(tmp_path / 'test_human_eval_small.jsonl'), + ) + example = { + 'context_options': [f'cont one', 'cont two'], + 'gold': 0, + 'continuation': 'this is a continuation' + } + constructed_contexts = dl._construct_multiple_contexts(example) + assert constructed_contexts == [ + f'{prelimiter}cont one', f'{prelimiter}cont two' + ] + constructed_contexts = dl._construct_multiple_contexts( + example, preceding_text='some text') + assert constructed_contexts == [ + f'{prelimiter}\ncont one ###', f'{prelimiter}\ncont two ###' + ] + + +def test_schema_tokenize_example( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/winograd_small.jsonl' + tokenizer = tiny_gpt2_tokenizer + seqlen = 2048 + num_fewshot = 0 + prompt_string = '' + seqlen = 2048 + dl = InContextLearningSchemaTaskDataset( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + fewshot_random_seed=1, + prompt_string=prompt_string, # pyright: ignore + example_delimiter='\n', # pyright: ignore + continuation_delimiter=' ### ', + destination_path=str(tmp_path / + 'test_human_eval_small.jsonl'), # pyright: ignore + ) + example = { + 'context_options': ['context one', 'context two'], + 'gold': 0, + 'continuation': 'this is a continuation' + } + tokenized_example = dl.tokenize_example( + prompt_and_fewshot='prompt ', + context_options=example['context_options'], + example=example) + assert all( + tiny_gpt2_tokenizer.decode(cont) == ' this is a continuation' + for cont in tokenized_example['answer']) + unpadded_inputs = [ + context[context != tokenizer.eos_token_id] + for context in tokenized_example['context_options'] + ] + untokenized_inputs = [ + tokenizer.decode(unpadded_input) for unpadded_input in unpadded_inputs + ] + assert untokenized_inputs == [ + 'prompt context one this is a continuation', + 'prompt context two this is a continuation' + ] + + +@pytest.mark.parametrize('dataset_uri', ['mmlu_small.jsonl']) +def test_mc_task_dataloader_subcategories( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 8 + seqlen = 64 + dls = get_icl_task_dataloader( + 'multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=2, + prompt_string= + 'The following are multiple choice questions (with answers).\n', + example_delimiter='\n', + continuation_delimiter='Answer: ', + destination_path=str(tmp_path / 'icl.jsonl'), + has_categories=True) + assert isinstance(dls, dict) + + assert 'computer_security' in dls + dl = dls['computer_security'] + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + assert dl.dataloader.__len__() == 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + 1]) == ' A' + + +@pytest.mark.parametrize('dataset_uri', [ + 'pubmed_sm.jsonl', +]) +def test_lm_task_dataloader_extra_space( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + dl = get_icl_task_dataloader('language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=10, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert ' ' not in tokenizer.decode(batch['input_ids'][0][0:max_idx + 1]) + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' yes' + + +@pytest.mark.parametrize('dataset_uri', [ + 'lambada_small.jsonl', +]) +def test_lm_task_dataloader(dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + dl = get_icl_task_dataloader('language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' glen' + + +@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) +@pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) +def test_schema_task_dataloader(dataset_uri: str, prelimiter: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + dl = get_icl_task_dataloader('schema', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + example_delimiter='\n', + question_prelimiter=prelimiter, + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) + batch = next(dl.dataloader._get_iterator()) + + choices_per_question = 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + assert 'gold_indices' in batch + assert isinstance(batch['gold_indices'], list) and len( + batch['gold_indices']) == batch_size // choices_per_question + assert 'choice_groupings' in batch + assert isinstance(batch['choice_groupings'], list) and len( + batch['choice_groupings']) == batch_size // choices_per_question + + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' feared violence.' + + +@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) +def test_schema_task_dataloader_sentpiece_tokenizer(dataset_uri: str, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'huggyllama/llama-7b', # type: ignore reportUnboundVariable + use_fast=False) + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + dl = get_icl_task_dataloader('schema', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) + batch = next(dl.dataloader._get_iterator()) + + choices_per_question = 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + assert 'gold_indices' in batch + assert isinstance(batch['gold_indices'], list) and len( + batch['gold_indices']) == batch_size // choices_per_question + assert 'choice_groupings' in batch + assert isinstance(batch['choice_groupings'], list) and len( + batch['choice_groupings']) == batch_size // choices_per_question + + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode( + batch['input_ids'][0][0:max_idx + 1] + ) == "The trophy doesn't fit into the brown suitcase because the suitcase is too small. \nThe city councilmen refused the demonstrators a permit because the city councilmen feared violence." + + +@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +def test_lm_task_dataloader_opt_tokenizer( + tiny_opt_tokenizer: transformers.AutoTokenizer, dataset_uri: str, + num_fewshot: int, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_opt_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 512 + dl = get_icl_task_dataloader('language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' glen' + assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).startswith('') + assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).count('') == 1 + + +@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +def test_mc_task_dataloader_opt_tokenizer( + tiny_opt_tokenizer: transformers.AutoTokenizer, dataset_uri: str, + num_fewshot: int, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_opt_tokenizer + + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 64 + dl = get_icl_task_dataloader('multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + choices_per_question = 2 + assert dl.get_num_samples_in_batch(batch) == 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + assert 'gold_indices' in batch + assert isinstance(batch['gold_indices'], list) and len( + batch['gold_indices']) == batch_size // choices_per_question + assert 'choice_groupings' in batch + assert isinstance(batch['choice_groupings'], list) and len( + batch['choice_groupings']) == batch_size // choices_per_question + + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' Pour it onto a plate' + assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).startswith('') + assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).count('') == 1 + + +@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +def test_mc_split_batch(tiny_opt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, num_fewshot: int, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_opt_tokenizer + + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + dl = get_icl_task_dataloader('multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + choices_per_question = 2 + real_microbatch_size = batch_size // 2 + logical_microbatch_size = real_microbatch_size // choices_per_question + microbatches = dl.split_batch(batch, logical_microbatch_size) + assert len(microbatches) == 2 + for i, microbatch in enumerate(microbatches): + assert dl.get_num_samples_in_batch(microbatch) == 1 + assert 'input_ids' in microbatch + assert tuple(microbatch['input_ids'].shape) == (real_microbatch_size, + seqlen) + assert 'attention_mask' in microbatch + assert tuple( + microbatch['attention_mask'].shape) == (real_microbatch_size, + seqlen) + assert 'continuation_indices' in microbatch + assert isinstance(microbatch['continuation_indices'], list) and len( + microbatch['continuation_indices']) == real_microbatch_size + assert 'mode' in microbatch + assert microbatch['mode'] == 'icl_task' + assert 'gold_indices' in microbatch + assert isinstance(microbatch['gold_indices'], list) and len( + microbatch['gold_indices'] + ) == real_microbatch_size // choices_per_question + assert 'choice_groupings' in microbatch + assert isinstance(microbatch['choice_groupings'], list) and len( + microbatch['choice_groupings'] + ) == real_microbatch_size // choices_per_question + + min_idx = min(microbatch['continuation_indices'][0]).item() + max_idx = max(microbatch['continuation_indices'][0]).item() + if i == 0: + assert tokenizer.decode( + microbatch['input_ids'][0][min_idx:max_idx + + 1]) == ' Pour it onto a plate' + elif i == 1: + assert tokenizer.decode( + microbatch['input_ids'][0][min_idx:max_idx + 1] + ) == ' Weld the metal together to get it to stay firmly in place' + assert tokenizer.decode( + microbatch['input_ids'][0][0:min_idx]).startswith('') + assert tokenizer.decode( + microbatch['input_ids'][0][0:min_idx]).count('') == 1 + + +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +def test_qa_split_batch(tiny_opt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) # for dist + dl = get_icl_task_dataloader( + icl_task_type='generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=8, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + assert isinstance(dl, DataSpec) # pyright + + batch = next(iter(dl.dataloader)) + split_batch = dl.split_batch(batch, 3) + + assert len(split_batch) == 2 + split1 = split_batch[0] + split2 = split_batch[1] + + assert split1['input_ids'].shape[0] == 3 + assert split2['input_ids'].shape[0] == 1 + + assert split1['attention_mask'].shape[0] == 3 + assert split2['attention_mask'].shape[0] == 1 + + assert isinstance(split1['mode'], str) + assert isinstance(split2['mode'], str) + + assert len(split1['labels']) == 3 + assert len(split2['labels']) == 1 + assert all(isinstance(v, list) for v in split1['labels'] + split2['labels']) + + assert isinstance(split1['generation_kwargs']['max_new_tokens'], int) + assert isinstance(split2['generation_kwargs']['max_new_tokens'], int) + + assert isinstance(split1['generation_kwargs'], dict) + assert isinstance(split2['generation_kwargs'], dict) + + +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0]) +@pytest.mark.parametrize('prompt_string', ['I am a prompt', '']) +def test_qa_task_dataloader_w_null_eos( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int, prompt_string: str): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + tiny_gpt2_tokenizer.eos_token_id = None + with pytest.raises(ValueError): + _ = get_icl_task_dataloader('generation_task_with_answers', + dataset_uri, + tokenizer, + batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + question_prelimiter='Q: ', + continuation_delimiter='\nA:', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl')) + + +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +@pytest.mark.parametrize('prompt_string', ['I am a prompt', '']) +def test_qa_task_dataloader(dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int, + prompt_string: str): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + # empirical number from the small test dataset + maximum_answer_length = 7 + dl = get_icl_task_dataloader('generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + question_prelimiter='Q: ', + continuation_delimiter='\nA:', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl')) + assert isinstance(dl, DataSpec) + + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert tuple(batch['input_ids'].shape) == (batch_size, + seqlen - maximum_answer_length) + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - + maximum_answer_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + + assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length + assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids']) + + decoded_batch = tokenizer.batch_decode(batch['input_ids']) + assert all(item.count('Q: ') == num_fewshot + 1 for item in decoded_batch) + assert all(item.count('\nA:') == num_fewshot + 1 for item in decoded_batch) + + if len(prompt_string) > 0: + assert all(item.count('I am a prompt') == 1 for item in decoded_batch) + assert all( + set(found) == set(expected) for found, expected in zip( + batch['labels'], [['David Seville'], ['Skorpio', 'Scorpio']])) + assert decoded_batch[0].endswith( + 'Q: Who was the man behind The Chipmunks?\nA:') + assert decoded_batch[1].endswith( + 'Q: What star sign is Jamie Lee Curtis?\nA:') + assert 'eos_token_id' in batch['generation_kwargs'] + + +@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +def test_qa_task_with_cot_dataloader( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 512 + # empirical number from the small test dataset + maximum_answer_length = 132 + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + question_prelimiter='Q: ', + continuation_delimiter="\nA: Let's think step by step. ", + cot_delimiter=' #### ', + destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + assert tuple(batch['input_ids'].shape) == (batch_size, + seqlen - maximum_answer_length) + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - + maximum_answer_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length + assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids']) + decoded_batch = tokenizer.batch_decode(batch['input_ids']) + assert all(item.count('Q: ') == num_fewshot + 1 for item in decoded_batch) + assert all(item.count('\nA:') == num_fewshot + 1 for item in decoded_batch) + + assert batch['labels'] == [['18'], ['3']] + if num_fewshot == 0: + assert decoded_batch[0].endswith( + "Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step." + ) + assert decoded_batch[1].endswith( + "Q: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?\nA: Let's think step by step." + ) + elif num_fewshot == 2: + assert decoded_batch[0].endswith( + "Q: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\nA: Let's think step by step. The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000 #### 70000\nQ: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\nA: Let's think step by step. He sprints 3*3=<<3*3=9>>9 times\nSo he runs 9*60=<<9*60=540>>540 meters #### 540\nQ: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step." + ) + assert decoded_batch[1].endswith( + "Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step. Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market. #### 18\nQ: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\nA: Let's think step by step. The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000 #### 70000\nQ: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?\nA: Let's think step by step." + ) + + +@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl']) +@pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) +def test_mc_task_dataloader(dataset_uri: str, prelimiter: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + example_delimiter = '\n' + dl = get_icl_task_dataloader('multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + question_prelimiter=prelimiter, + example_delimiter=example_delimiter, + continuation_delimiter='\nA: ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + choices_per_question = 2 + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + assert 'gold_indices' in batch + assert isinstance(batch['gold_indices'], list) and len( + batch['gold_indices']) == batch_size // choices_per_question + assert 'choice_groupings' in batch + assert isinstance(batch['choice_groupings'], list) and len( + batch['choice_groupings']) == batch_size // choices_per_question + + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' Pour it onto a plate' + q1 = 'how do you shake something?\nA: ' + a1 = 'move it up and down and side to side quickly.' + q2 = "When boiling butter, when it's ready, you can\nA:" + assert tokenizer.decode( + batch['input_ids'][0][:min_idx] + ) == f'{prelimiter}{q1}{a1}{example_delimiter}{prelimiter}{q2}' + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' Pour it onto a plate' + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +def test_code_eval_split_batch(dataset_uri: str, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'EleutherAI/gpt-neox-20b') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=5, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=2, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generations_per_sample=3, + ) + + assert isinstance(dl, DataSpec) # pyright + batches = list(dl.dataloader) + + for k in ('input_ids', 'attention_mask'): + assert [b[k].shape[0] for b in batches] == [5, 5, 2] + + list_keys = { + 'labels': str, + 'prompts': str, + 'tests': str, + 'entry_points': str, + 'test_inputs': list, + 'test_outputs': list, + 'languages': str, + } + + for batch, size in zip(batches, [5, 5, 2]): + for field, type_ in list_keys.items(): + assert len(batch[field]) == size + assert all(isinstance(val, type_) for val in batch[field]) + + static_keys = {'pass_at_k': (int, list), 'generation_kwargs': dict} + for batch in batches: + assert 'generation_kwargs' in batch + assert 'max_new_tokens' in batch['generation_kwargs'] + assert isinstance(batch['generation_kwargs']['max_new_tokens'], int) + for field, type_ in static_keys.items(): + assert isinstance(batch[field], type_) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +@pytest.mark.parametrize('prompt_string', ['Please code:\n', '']) +@pytest.mark.parametrize('generations_per_sample', [1, 3]) +def test_code_eval_sentpiece_dataloader( + dataset_uri: str, tmp_path: Path, num_fewshot: int, prompt_string: str, + generations_per_sample: int, + tiny_llama_tokenizer: transformers.AutoTokenizer): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_llama_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 5 + seqlen = 2048 + + dl = get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=generations_per_sample) + assert isinstance(dl, DataSpec) + + assert isinstance(dl.dataloader, DataLoader) # pyright + batches = list(dl.dataloader) + dataset_size = len(open(dataset_uri, 'r').read().strip().split('\n')) + dataset_size *= generations_per_sample + + max_prompt_length = 0 + + has_left_padding = [] + for i, batch in enumerate(batches): + if isinstance(dl.dataloader.dataset, InContextLearningCodeEvalDataset): + max_prompt_length = dl.dataloader.dataset.max_prompt_length + N = len(batches) + bs = batch_size if i < N - 1 else dataset_size - (N - 1) * batch_size + assert tuple(batch['input_ids'].shape) == (bs, max_prompt_length) + assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == 129 + has_left_padding.extend( + [item[0] == tokenizer.eos_token_id for item in batch['input_ids']]) + assert not all(has_left_padding) # longest should be pushed left + + decoded_batches = [ + tokenizer.batch_decode(batch['input_ids']) for batch in batches + ] + for decoded_batch in decoded_batches: + assert all( + item.count('Code start: \n') == num_fewshot + 1 + for item in decoded_batch) + + if len(prompt_string) > 0: + assert all( + item.count('Please code:\n') == 1 for item in decoded_batch) + + labels = [ + ' for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n\n return False\n', + " result = []\n current_string = []\n current_depth = 0\n\n for c in paren_string:\n if c == '(':\n current_depth += 1\n current_string.append(c)\n elif c == ')':\n current_depth -= 1\n current_string.append(c)\n\n if current_depth == 0:\n result.append(''.join(current_string))\n current_string.clear()\n\n return result\n", + ' return number % 1.0\n', + ' balance = 0\n\n for op in operations:\n balance += op\n if balance < 0:\n return True\n\n return False\n', + ] + + # assert decoded_batch[0].endswith( + samples = [ + "Code start: \nfrom typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + "Code start: \nfrom typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", + "Code start: \n\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", + "Code start: \nfrom typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n" + ] + for i in range(4): + for j in range(generations_per_sample): + k = i * generations_per_sample + j + b, n = divmod(k, batch_size) + assert batches[b]['labels'][n] == labels[i] + assert decoded_batches[b][n].endswith(samples[i]) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +def test_code_eval_test_cases(dataset_uri: str, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'huggyllama/llama-7b') # type: ignore reportUnboundVariable + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + + dl = get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_.jsonl'), + generations_per_sample=1) + assert isinstance(dl, DataSpec) + + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + max_prompt_length = 0 + if isinstance(dl.dataloader.dataset, InContextLearningCodeEvalDataset): + max_prompt_length = dl.dataloader.dataset.max_prompt_length + assert tuple(batch['input_ids'].shape) == (batch_size, max_prompt_length) + assert tuple(batch['attention_mask'].shape) == (batch_size, + max_prompt_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == 129 + assert any(item[0] != tokenizer.eos_token_id + for item in batch['input_ids']) # longest should be pushed left + + mod = types.ModuleType('test_module') + for prompt, solution, inputs, outputs, entry_point in zip( + batch['prompts'], batch['labels'], batch['test_inputs'], + batch['test_outputs'], batch['entry_points']): + exec(prompt + solution, mod.__dict__) + for test_input, test_output in zip(inputs, outputs): + result = mod.__dict__[entry_point](*eval(test_input)) + assert result == eval(test_output) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +def test_code_eval_pass_at_k_validity(dataset_uri: str, tmp_path: Path): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'huggyllama/llama-7b') # type: ignore reportUnboundVariable + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 64 + + with pytest.raises(ValueError, match=r'.* pass_at_k .*'): + get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_.jsonl'), + pass_at_k=10, + generations_per_sample=1) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +@pytest.mark.parametrize('prompt_string', ['Please code:\n', '']) +@pytest.mark.parametrize('generations_per_sample', [1, 3]) +def test_code_eval_task_dataloader(dataset_uri: str, tmp_path: Path, + num_fewshot: int, prompt_string: str, + generations_per_sample: int): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'mosaicml/mpt-7b') # type: ignore reportUnboundVariable + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 2048 + + dl = get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=generations_per_sample, + generation_kwargs={ + 'temperature': .9, + 'top_k': 40 + }) + assert isinstance(dl, DataSpec) + + assert isinstance(dl.dataloader, DataLoader) # pyright + batches = list(dl.dataloader) + dataset_size = len(open(dataset_uri, 'r').read().strip().split('\n')) + dataset_size *= generations_per_sample + + has_left_padding = [] + for i, batch in enumerate(batches): + max_prompt_length = 0 + if isinstance(dl.dataloader.dataset, InContextLearningCodeEvalDataset): + max_prompt_length = dl.dataloader.dataset.max_prompt_length + N = len(batches) + bs = batch_size if i < N - 1 else dataset_size - (N - 1) * batch_size + assert tuple(batch['input_ids'].shape) == (bs, max_prompt_length) + assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == 122 + has_left_padding.extend( + [item[0] == tokenizer.eos_token_id for item in batch['input_ids']]) + assert not all(has_left_padding) # longest should be pushed left + + decoded_batches = [ + tokenizer.batch_decode(batch['input_ids']) for batch in batches + ] + for decoded_batch in decoded_batches: + assert all( + item.count('Code start: \n') == num_fewshot + 1 + for item in decoded_batch) + + if len(prompt_string) > 0: + assert all( + item.count('Please code:\n') == 1 for item in decoded_batch) + + labels = [ + ' for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n\n return False\n', + " result = []\n current_string = []\n current_depth = 0\n\n for c in paren_string:\n if c == '(':\n current_depth += 1\n current_string.append(c)\n elif c == ')':\n current_depth -= 1\n current_string.append(c)\n\n if current_depth == 0:\n result.append(''.join(current_string))\n current_string.clear()\n\n return result\n", + ' return number % 1.0\n', + ' balance = 0\n\n for op in operations:\n balance += op\n if balance < 0:\n return True\n\n return False\n', + ] + + # assert decoded_batch[0].endswith( + samples = [ + "Code start: \nfrom typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + "Code start: \nfrom typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", + "Code start: \n\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", + "Code start: \nfrom typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n" + ] + for i in range(4): + for j in range(generations_per_sample): + k = i * generations_per_sample + j + b, n = divmod(k, batch_size) + assert batches[b]['labels'][n] == labels[i] + assert decoded_batches[b][n].endswith(samples[i]) + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +def test_eval_split_batch(mpt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, num_fewshot: int, tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + tokenizer = mpt_tokenizer # type: ignore reportUnboundVariable + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 4 + seqlen = 512 + + dl = get_icl_task_dataloader('code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str( + tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=1, + generation_kwargs={ + 'temperature': .9, + 'top_k': 40 + }) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + microbatch_size = 1 + microbatches = dl.split_batch(batch, microbatch_size) + assert len(microbatches) == 4 + for microbatch in microbatches: + assert dl.get_num_samples_in_batch(microbatch) == 1 + assert 'input_ids' in microbatch + # TODO: what should this be? + # assert tuple(microbatch['input_ids'].shape) == (microbatch_size, seqlen) + assert 'attention_mask' in microbatch + # assert tuple(microbatch['attention_mask'].shape) == (microbatch_size, seqlen) + assert isinstance(microbatch['generation_kwargs'], dict) + assert microbatch['generation_kwargs']['temperature'] == .9 + assert microbatch['generation_kwargs']['top_k'] == 40 + assert microbatch['generation_kwargs']['pad_token_id'] == 0 + assert microbatch['generation_kwargs']['num_beams'] == 1 + assert microbatch['generation_kwargs']['do_sample'] == True + assert microbatch['generation_kwargs']['use_cache'] == True + assert microbatch['generation_kwargs']['eos_token_id'] == 0 + + +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) +# @pytest.mark.gpu +# @pytest.mark.world_size(2) +def test_lm_task_evaluation(num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + tiny_gpt2_model: transformers.AutoModelForCausalLM): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl'), + ) + + evaluator = Evaluator(label='lambada', + dataloader=dl, + metric_names=['InContextLearningLMAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningLMAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ep', loggers=in_memory_logger) + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/lambada/InContextLearningLMAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data['metrics/lambada/InContextLearningLMAccuracy'][ + 0][1].item() == 0 + + +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +def test_schema_task_evaluation( + num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path, + tiny_gpt2_model: transformers.AutoModelForCausalLM): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 8 + dl = get_icl_task_dataloader( + 'schema', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(tmp_path / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='winograd', + dataloader=dl, + metric_names=['InContextLearningMultipleChoiceAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningMultipleChoiceAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + trainer.eval(eval_dataloader=evaluator) + assert 'metrics/winograd/InContextLearningMultipleChoiceAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/winograd/InContextLearningMultipleChoiceAccuracy'][0][1].item( + ) > 0 + num_samples = 0 + with open(dataset_uri) as f: + for _ in f: + num_samples += 1 + assert trainer.state.eval_metrics['winograd'][ + 'InContextLearningMultipleChoiceAccuracy'].total == num_samples + + +@pytest.mark.parametrize('dataset_uri', ['mmlu_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +def test_mc_task_evaluation_subcategories( + dataset_uri: str, num_fewshot: int, + tiny_gpt2_model: transformers.AutoModelForCausalLM, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 16 + max_seq_len = 64 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + reproducibility.seed_all(1234) + dls = get_icl_task_dataloader('multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=max_seq_len, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str( + Path(gathered_paths[0]) / 'icl.jsonl'), + has_categories=True) + + assert isinstance(dls, dict) + evaluators = [ + Evaluator(label='mmlu/' + k, + dataloader=dl, + metric_names=['InContextLearningMultipleChoiceAccuracy']) + for k, dl in dls.items() + ] + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningMultipleChoiceAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, loggers=in_memory_logger) + trainer.eval(eval_dataloader=evaluators) + assert 'metrics/mmlu/computer_security/InContextLearningMultipleChoiceAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/mmlu/computer_security/InContextLearningMultipleChoiceAccuracy'][ + 0][1].item() >= 0 + total = trainer.state.eval_metrics['mmlu/computer_security'][ + 'InContextLearningMultipleChoiceAccuracy'].total + dist.all_reduce(total) # type: ignore + assert total.item() == 4 # type: ignore + + +@pytest.mark.parametrize('dataset_uri', + ['piqa_small.jsonl', 'hellaswag_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +@pytest.mark.gpu +@pytest.mark.world_size(2) +def test_mc_task_evaluation(num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + tiny_gpt2_model: transformers.AutoModelForCausalLM): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 8 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + + # seed because the fewshot selection is currently unseeded + reproducibility.seed_all(1234) + dl = get_icl_task_dataloader( + 'multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=64, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='mc', + dataloader=dl, + metric_names=['InContextLearningMultipleChoiceAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningMultipleChoiceAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + trainer.eval(eval_dataloader=evaluator) + assert 'metrics/mc/InContextLearningMultipleChoiceAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/mc/InContextLearningMultipleChoiceAccuracy'][0][1].item() >= 0 + num_samples = 0 + with open(dataset_uri) as f: + for _ in f: + num_samples += 1 + total = trainer.state.eval_metrics['mc'][ + 'InContextLearningMultipleChoiceAccuracy'].total + dist.all_reduce(total) # type: ignore + assert total.item() == num_samples # type: ignore + + +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +@pytest.mark.gpu +@pytest.mark.world_size(2) +def test_qa_task_evaluation_opt_tokenizer( + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, num_fewshot: int, + dataset_uri: str, tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + + batch_size = 4 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='triviaqa', + dataloader=dl, + metric_names=['InContextLearningGenerationExactMatchAccuracy']) + model = HuggingFaceModel( + model=tiny_opt_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningGenerationExactMatchAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/triviaqa/InContextLearningGenerationExactMatchAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/triviaqa/InContextLearningGenerationExactMatchAccuracy'][0][ + 1].item() == 0 + + +@pytest.mark.parametrize('num_fewshot', [5]) +@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') +def test_qa_task_evaluation_with_cot_opt_tokenizer( + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, num_fewshot: int, + dataset_uri: str, tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + + batch_size = 4 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter="A: Let's think step by step. ", + cot_delimiter=' #### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='gsm8k', + dataloader=dl, + metric_names=['InContextLearningGenerationExactMatchAccuracy']) + model = HuggingFaceModel( + model=tiny_opt_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningGenerationExactMatchAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/gsm8k/InContextLearningGenerationExactMatchAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/gsm8k/InContextLearningGenerationExactMatchAccuracy'][0][ + 1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 5]) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +def test_qa_task_evaluation(num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tiny_gpt2_model: transformers.AutoModelForCausalLM, + tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='triviaqa', + dataloader=dl, + metric_names=['InContextLearningGenerationExactMatchAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningGenerationExactMatchAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/triviaqa/InContextLearningGenerationExactMatchAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/triviaqa/InContextLearningGenerationExactMatchAccuracy'][0][ + 1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [5]) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +@pytest.mark.gpu +@pytest.mark.world_size(2) +def test_qa_task_with_cot_evaluation( + num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tiny_gpt2_model: transformers.AutoModelForCausalLM, tmp_path: Path): + + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter="A: Let's think step by step", + cot_delimiter=' #### ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + ) + + evaluator = Evaluator( + label='gsm8k', + dataloader=dl, + metric_names=['InContextLearningGenerationExactMatchAccuracy']) + + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningGenerationExactMatchAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + + trainer.eval(eval_dataloader=evaluator, subset_num_batches=2) + assert 'metrics/gsm8k/InContextLearningGenerationExactMatchAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/gsm8k/InContextLearningGenerationExactMatchAccuracy'][0][ + 1].item() == 0 + + +def test_code_eval_requires_envvar(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('CODE_EVAL_DEVICE', raising=False) + with pytest.raises( + ValueError, + match='Attempting to use InContextLearningCodeEvalAccuracy but.*'): + InContextLearningCodeEvalAccuracy().get_client() + + +def test_code_eval_requires_valid_envvar(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv('CODE_EVAL_DEVICE', 'bigchungus') + with pytest.raises( + ValueError, + match='Environment variable `CODE_EVAL_DEVICE` must be on.*'): + InContextLearningCodeEvalAccuracy().get_client() + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0]) +@pytest.mark.parametrize('generations_per_sample', range(1, 3)) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +def test_code_eval_microbatching( + monkeypatch: pytest.MonkeyPatch, + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, num_fewshot: int, + dataset_uri: str, tmp_path: Path, generations_per_sample: int): + + monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + batch_size = 4 + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=150, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generations_per_sample=generations_per_sample, + ) + + evaluator = Evaluator(label='humaneval', + dataloader=dl, + metric_names=['InContextLearningCodeEvalAccuracy'], + device_eval_microbatch_size=1) + model = HuggingFaceModel( + model=tiny_opt_model, + tokenizer=tokenizer, + eval_metrics=[InContextLearningCodeEvalAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + torch.use_deterministic_algorithms(False) + trainer.eval(eval_dataloader=evaluator) + torch.use_deterministic_algorithms(True) + assert 'metrics/humaneval/InContextLearningCodeEvalAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/humaneval/InContextLearningCodeEvalAccuracy'][0][1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0]) +@pytest.mark.parametrize('generations_per_sample', range(1, 3)) +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +def test_code_eval_sentpiece_evaluation( + monkeypatch: pytest.MonkeyPatch, num_fewshot: int, dataset_uri: str, + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, tmp_path: Path, + generations_per_sample: int): + + monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_opt_tokenizer + batch_size = 2 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=175, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generations_per_sample=generations_per_sample, + ) + + evaluator = Evaluator(label='humaneval', + dataloader=dl, + metric_names=['InContextLearningCodeEvalAccuracy']) + model = HuggingFaceModel( + model=tiny_opt_model, + tokenizer=tiny_opt_tokenizer, + eval_metrics=[InContextLearningCodeEvalAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + torch.use_deterministic_algorithms(False) + trainer.eval(eval_dataloader=evaluator) + torch.use_deterministic_algorithms(True) + assert 'metrics/humaneval/InContextLearningCodeEvalAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/humaneval/InContextLearningCodeEvalAccuracy'][0][1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) +@pytest.mark.parametrize('num_fewshot', [0, 2]) +@pytest.mark.parametrize('generations_per_sample', [1]) +@pytest.mark.filterwarnings(r'ignore: Input length of input_ids is') +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.filterwarnings( + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' +) +def test_code_eval_task_evaluation( + monkeypatch: pytest.MonkeyPatch, num_fewshot: int, dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tiny_gpt2_model: transformers.AutoModelForCausalLM, tmp_path: Path, + generations_per_sample: int): + + monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') + in_memory_logger = InMemoryLogger( + ) # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=64 * num_fewshot, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generations_per_sample=generations_per_sample, + ) + + evaluator = Evaluator(label='humaneval', + dataloader=dl, + metric_names=['InContextLearningCodeEvalAccuracy']) + model = HuggingFaceModel( + model=tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + eval_metrics=[InContextLearningCodeEvalAccuracy()], + use_logits=True, + ) + + trainer = Trainer(model=model, max_duration='1ba', loggers=in_memory_logger) + torch.use_deterministic_algorithms(False) + trainer.eval(eval_dataloader=evaluator) + torch.use_deterministic_algorithms(True) + assert 'metrics/humaneval/InContextLearningCodeEvalAccuracy' in in_memory_logger.data.keys( + ) + assert in_memory_logger.data[ + 'metrics/humaneval/InContextLearningCodeEvalAccuracy'][0][1].item() == 0 + + +@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) +def test_lm_spacing_dataloader(dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path): + + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + batch_size = 2 + seqlen = 512 + dl = get_icl_task_dataloader('language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' UNIQUE ', + destination_path=str(tmp_path / 'icl.jsonl')) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + first_batch = next(dl.dataloader._get_iterator()) + second_batch = next(dl.dataloader._get_iterator()) + + first_batch_text = tokenizer.decode(first_batch['input_ids'][0], + skip_special_tokens=True) + second_batch_text = tokenizer.decode(second_batch['input_ids'][0], + skip_special_tokens=True) + + first_batch_without_last_word = ' '.join(first_batch_text.split(' ')[:-1]) + second_batch_without_last_word = ' '.join(second_batch_text.split(' ')[:-1]) + + assert first_batch_without_last_word.endswith(' UNIQUE') + assert second_batch_without_last_word.endswith(' UNIQUE') + + assert first_batch_without_last_word.count(' UNIQUE ') == 1 + assert second_batch_without_last_word.count(' UNIQUE ') == 1 + + +@pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +@pytest.mark.parametrize('prompt_string', ['Complete the voiceline: ', '']) +@pytest.mark.parametrize('hf_loading_vars', [{ + 'split': 'test', + 'name': 'juggernaut', +}]) +@pytest.mark.parametrize( + 'hf_parsing_map', + [None, { + 'context': ['context'], + 'continuation': ['continuation'] + }]) +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_hf_dataloading_lm_dataloader( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int, prompt_string: str, + hf_loading_vars: Dict[str, + str], hf_parsing_map: Optional[Dict[str, + List[str]]]): + + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + seqlen = 2048 + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' ', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert 'input_ids' in batch + assert tuple(batch['input_ids'].shape) == (batch_size, seqlen) + assert 'attention_mask' in batch + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) + assert 'continuation_indices' in batch + assert isinstance(batch['continuation_indices'], list) and len( + batch['continuation_indices']) == batch_size + assert 'mode' in batch + assert batch['mode'] == 'icl_task' + min_idx = min(batch['continuation_indices'][0]).item() + max_idx = max(batch['continuation_indices'][0]).item() + assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + + 1]) == ' and me.' + + decoded_batch = [ + tokenizer.decode(row[row != tokenizer.eos_token_id]) + for row in batch['input_ids'] + ] + assert decoded_batch[0] == "Looks like it's just you and me." + assert decoded_batch[ + 1] == "There's a fine line between bravery and stupidity." + + +@pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset']) +@pytest.mark.parametrize('num_fewshot', [0, 1]) +@pytest.mark.parametrize('prompt_string', ['What spell does this invoke? ', '']) +@pytest.mark.parametrize('hf_loading_vars', [{ + 'split': 'test', + 'name': 'invoker', +}]) +@pytest.mark.parametrize('hf_parsing_map', [{ + 'context': ['quas', 'wex', 'exort'], + 'answer': ['spell'] +}]) +@pytest.mark.filterwarnings( + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' +) +def test_hf_dataloading_custom_parsing( + dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, num_fewshot: int, prompt_string: str, + hf_loading_vars: Dict[str, str], hf_parsing_map: Dict[str, List[str]]): + + tokenizer = tiny_gpt2_tokenizer + batch_size = 2 + seqlen = 2048 + + # empirical number from the small test dataset + maximum_answer_length = 4 + + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + question_prelimiter='Orbs: ', + continuation_delimiter='\nSpell:', + destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map) + assert isinstance(dl, DataSpec) + assert isinstance(dl.dataloader, DataLoader) # pyright + batch = next(dl.dataloader._get_iterator()) + + assert tuple(batch['input_ids'].shape) == (batch_size, + seqlen - maximum_answer_length) + assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - + maximum_answer_length) + assert batch['mode'] == 'generate' + # the maximum generation length from the small test data + assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length + assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids']) + + decoded_batch = tokenizer.batch_decode(batch['input_ids']) + assert all( + item.count('Orbs: ') == num_fewshot + 1 for item in decoded_batch) + assert all( + item.count('\nSpell:') == num_fewshot + 1 for item in decoded_batch) + + if len(prompt_string) > 0: + assert all( + item.count('What spell does this invoke? ') == 1 + for item in decoded_batch) + assert all( + set(found) == set(expected) for found, expected in zip( + batch['labels'], [['defeaning blast'], ['cold snap']])) + assert decoded_batch[0].endswith('Orbs: quas wex exort\nSpell:') + assert decoded_batch[1].endswith('Orbs: quas quas quas\nSpell:') diff --git a/tests/eval/test_nlp_metrics.py b/tests/eval/test_nlp_metrics.py new file mode 100644 index 0000000000..344d642715 --- /dev/null +++ b/tests/eval/test_nlp_metrics.py @@ -0,0 +1,196 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, List + +import pytest +import torch +import transformers + +from llmfoundry.eval.metrics import ( + InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningMultipleChoiceAccuracy) + + +def test_in_context_learning_lm_accuracy( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + contexts = ['The dog is', 'I love to eat', 'I hate', 'The weather is'] + continuations = [' furry', ' pie', ' long lines', ' snowy'] + pad = tiny_gpt2_tokenizer.pad_token_id + inputs = [ + tiny_gpt2_tokenizer(context)['input_ids'] + + tiny_gpt2_tokenizer(continuation)['input_ids'] + for context, continuation in zip(contexts, continuations) + ] + inputs = torch.tensor( + [input + [pad] * (2048 - len(input)) for input in inputs]) + + cont_idxs = [] + for context, continuation in zip(contexts, continuations): + start = len(tiny_gpt2_tokenizer(context)['input_ids']) + end = start + len(tiny_gpt2_tokenizer(continuation)['input_ids']) + cont_idxs.append(torch.tensor(list(range(start, end)))) + + batch = { + 'continuation_indices': cont_idxs, + 'labels': inputs.roll(-1), + 'input_ids': inputs + } + logits = torch.nn.functional.one_hot(inputs.roll(-1), + num_classes=pad + 1).float() * 100 + start, end = cont_idxs[1].tolist()[0] - 1, cont_idxs[1].tolist()[-1] + logits[1][start:end] = logits[0][start:end].clone( + ) # make one of the answer's continuations incorrect + metric = InContextLearningLMAccuracy() + metric.update(batch, logits, batch['labels']) + + assert metric.compute() == 0.75 + + +def test_in_context_learning_qa_accuracy(): + outputs = [ + 'Correct but then some more text', 'Incorrect', + ' the CORREct with weird casing and spacing' + ] + labels = [['Correct'], ['blah', 'blah2'], ['blah', 'correct']] + batch = {'cot_delimiter': '', 'labels': labels} + metric = InContextLearningGenerationExactMatchAccuracy() + metric.update(batch, outputs, labels) + + assert metric.compute() == (2 / 3) + + +def test_in_context_learning_qa_cot_accuracy(): + outputs = [ + 'chain of thought ### Correct but then some more text\n\nanother chain of thought ### Incorrect answer this time', + 'Incorrect', + 'chain of thought ### the CORREct with weird casing and spacing', + 'incorrect chain of thought delimiter ## Correct but wrong delimiter' + ] + labels = [['Correct'], ['blah', 'blah2'], ['blah', 'correct'], ['correct']] + batch = { + 'cot_delimiter': ' ### ', + 'labels': labels, + 'do_normalization': True, + 'stopping_criteria': '\n\n' + } + metric = InContextLearningGenerationExactMatchAccuracy() + metric.update(batch, outputs, labels) + + assert metric.compute() == (2 / 4) + + +def test_in_context_learning_code_eval_accuracy( + monkeypatch: pytest.MonkeyPatch): + outputs = [ + ' return 1 if n <= 1 else fib(n - 1) + fib(n - 1)', # incorrect + ' if n <= 1:\n return 1\n return fib(n-1) + fib(n-2)', # incorrect spacing + ' return n * 2', # correct + ' return 2*n', # correct + ' return n + 2', # incorrect + ' return n + 1' + ] # correct + labels = [] + prompts = [ + 'def fib(n):\n', 'def multiply_by_two(n):\n', 'def add_one(n):\n' + ] + entry_points = ['fib', 'multiply_by_two', 'add_one'] + test_inputs = [['(1,)', '(2,)', '(4,)'], ['(1,)', '(2,)', '(4,)'], + ['(1,)', '(2,)', '(4,)']] + test_outputs = [['1', '2', '5'], ['2', '4', '8'], ['2', '3', '5']] + sample_ids = [0, 1, 2] + languages = ['python', 'python', 'python'] + monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') + generations_per_sample = 2 + + def repeat(values: List[Any]): + return [val for val in values for _ in range(generations_per_sample)] + + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'mosaicml/mpt-7b') # type: ignore reportUnboundVariable + tokenizer.pad_token = tokenizer.eos_token + input_ids = tokenizer.batch_encode_plus(repeat(prompts), + return_tensors='pt', + padding=True)['input_ids'] + batch = { + # This tests deterministic beam search rather than sampling + 'input_ids': input_ids, + 'generation_kwargs': { + 'num_beams': 1, + }, + 'prompts': repeat(prompts), + 'pass_at_k': [1], + 'entry_points': repeat(entry_points), + 'test_inputs': repeat(test_inputs), + 'test_outputs': repeat(test_outputs), + 'languages': repeat(languages), + 'dataset_size': len(prompts), + 'generations_per_sample': generations_per_sample, + 'sample_id': repeat(sample_ids), + } + metric = InContextLearningCodeEvalAccuracy() + metric.update(batch, outputs, labels) + + # pass@1 values + # program 1: 0 + # program 2: 1 + # program 3: .5 + # mean: 0.5 + assert metric.compute() == 0.5 + + +def test_in_context_learning_mc_accuracy( + tiny_gpt2_tokenizer: transformers.AutoTokenizer): + contexts = [ + 'Q: How do you cook a cake?', 'Q: How do you cook a cake?', + 'Q: How old is the earth?', 'Q: How old is the earth?' + ] + continuations = [ + ' A: turn on the oven', ' A: do a backflip', ' A: 2 minutes', + ' A: 4.5 billion years' + ] + gold_indices = [0, 1] + choice_groupings = [(0, 2), (2, 4)] + pad = tiny_gpt2_tokenizer.pad_token_id + inputs = [ + tiny_gpt2_tokenizer(context)['input_ids'] + + tiny_gpt2_tokenizer(continuation)['input_ids'] + for context, continuation in zip(contexts, continuations) + ] + inputs = torch.tensor( + [input + [pad] * (2048 - len(input)) for input in inputs]) + attention_mask = ~(inputs == pad) + + cont_idxs = [] + for context, continuation in zip(contexts, continuations): + start = len(tiny_gpt2_tokenizer(context)['input_ids']) + end = start + len(tiny_gpt2_tokenizer(continuation)['input_ids']) + cont_idxs.append(torch.tensor(list(range(start, end)))) + + batch = { + 'continuation_indices': cont_idxs, + 'labels': inputs.roll(-1), + 'input_ids': inputs, + 'attention_mask': attention_mask, + 'gold_indices': gold_indices, + 'choice_groupings': choice_groupings + } + logits = torch.nn.functional.one_hot(inputs.roll(-1), + num_classes=pad + 1).float() + + # for the first two, the correct answer is continuation 0 + # make the answer correct by making continuation 0 more likely for both answers + start, end = cont_idxs[1].tolist()[0] - 1, cont_idxs[1].tolist()[-1] + logits[1][start:end] = logits[0][start:end].clone() + + # for the last two, the correct answer is continuation 3 + # make the answer incorrect by making continuation 2 more likely for both answers + start, end = cont_idxs[3].tolist()[0], cont_idxs[3].tolist()[-1] + logits[3][start:end] = logits[2][start:end].clone() + + metric = InContextLearningMultipleChoiceAccuracy() + + metric.update(batch, logits, batch['labels']) + assert metric.compute() == 0.5 diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index ccbe1b69f7..16e3f8ad6f 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -9,11 +9,19 @@ import torch from composer.utils import dist, get_device, reproducibility +from llmfoundry.utils.registry_utils import save_registry + # Add llm-foundry repo root to path so we can import scripts in the tests REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) +@pytest.fixture(autouse=True) +def save_registry_fixture(): + with save_registry(): + yield + + @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): """Initialize the default PyTorch distributed process group for tests.""" diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index e4e6892fe3..616d66085c 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -1,8 +1,10 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy from typing import Any, Callable +import pytest from omegaconf import DictConfig from pytest import fixture from transformers import PreTrainedTokenizerBase @@ -71,3 +73,125 @@ def build(**kwargs: Any) -> ComposerHFCausalLM: return model return build + + +def tiny_gpt2_model_helper(config): # type: ignore + transformers = pytest.importorskip('transformers') + + return transformers.AutoModelForCausalLM.from_config(config) + + +@pytest.fixture(scope='session') +def _session_tiny_gpt2_model(_session_tiny_gpt2_config): # type: ignore + return tiny_gpt2_model_helper(_session_tiny_gpt2_config) + + +def tiny_gpt2_config_helper(): + transformers = pytest.importorskip('transformers') + + tiny_overrides = { + 'n_embd': 2, + 'n_head': 2, + 'n_layer': 2, + 'vocab_size': 50258 # 50257 + 1 for pad token + } + return transformers.AutoConfig.from_pretrained('gpt2', **tiny_overrides) + + +@pytest.fixture(scope='session') +def _session_tiny_gpt2_config(): # type: ignore + return tiny_gpt2_config_helper() + + +def tiny_gpt2_tokenizer_helper(): + transformers = pytest.importorskip('transformers') + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') + hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + return hf_tokenizer + + +@pytest.fixture +def tiny_gpt2_model(_session_tiny_gpt2_model): # type: ignore + return copy.deepcopy(_session_tiny_gpt2_model) + + +@pytest.fixture(scope='session') +def _session_tiny_gpt2_tokenizer(): # type: ignore + return tiny_gpt2_tokenizer_helper() + + +@pytest.fixture +def tiny_gpt2_tokenizer(_session_tiny_gpt2_tokenizer): # type: ignore + return copy.deepcopy(_session_tiny_gpt2_tokenizer) + + +def tiny_llama_tokenizer_helper(): + transformers = pytest.importorskip('transformers') + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained( + 'huggyllama/llama-7b', use_fast=False) + return hf_tokenizer + + +@pytest.fixture(scope='session') +def _session_tiny_llama_tokenizer(): # type: ignore + return tiny_llama_tokenizer_helper() + + +@pytest.fixture +def tiny_llama_tokenizer(_session_tiny_llama_tokenizer): # type: ignore + return copy.deepcopy(_session_tiny_llama_tokenizer) + + +def tiny_opt_tokenizer_helper(): + transformers = pytest.importorskip('transformers') + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained( + 'facebook/opt-125m') + hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + return hf_tokenizer + + +def tiny_opt_model_helper(config): # type: ignore + transformers = pytest.importorskip('transformers') + + return transformers.AutoModelForCausalLM.from_config(config) + + +@pytest.fixture(scope='session') +def _session_tiny_opt_tokenizer(): # type: ignore + return tiny_opt_tokenizer_helper() + + +@pytest.fixture(scope='session') +def _session_tiny_opt_config(): # type: ignore + return tiny_opt_config_helper() + + +@pytest.fixture(scope='session') +def _session_tiny_opt_model(_session_tiny_opt_config): # type: ignore + return tiny_opt_model_helper(_session_tiny_opt_config) + + +def tiny_opt_config_helper(): + transformers = pytest.importorskip('transformers') + + tiny_overrides = { + 'n_embd': 2, + 'n_head': 2, + 'n_layer': 2, + 'vocab_size': 50272 + } + return transformers.AutoConfig.from_pretrained('facebook/opt-125m', + **tiny_overrides) + + +@pytest.fixture +def tiny_opt_tokenizer(_session_tiny_opt_tokenizer): # type: ignore + return copy.deepcopy(_session_tiny_opt_tokenizer) + + +@pytest.fixture +def tiny_opt_model(_session_tiny_opt_model): # type: ignore + return copy.deepcopy(_session_tiny_opt_model) diff --git a/tests/models/hf/test_fsdp_weight_tying.py b/tests/models/hf/test_fsdp_weight_tying.py index 6e7838e7ba..712e515653 100644 --- a/tests/models/hf/test_fsdp_weight_tying.py +++ b/tests/models/hf/test_fsdp_weight_tying.py @@ -33,7 +33,7 @@ def test_fsdp_weight_tying(peft_config: Optional[dict], tmp_path: pathlib.Path, init_device: str): model_cfg = { 'name': 'hf_causal_lm', - 'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1', + 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', 'config_overrides': { 'num_hidden_layers': 2, 'hidden_size': 32, @@ -43,7 +43,7 @@ def test_fsdp_weight_tying(peft_config: Optional[dict], tmp_path: pathlib.Path, 'pretrained': False, 'init_device': init_device, } - tokenizer_name = 'mistralai/Mistral-7B-v0.1' + tokenizer_name = 'codellama/CodeLlama-7b-hf' assert model_cfg is not None assert tokenizer_name is not None diff --git a/tests/models/hf/test_hf_fsdp.py b/tests/models/hf/test_hf_fsdp.py deleted file mode 100644 index 0f49a4d43b..0000000000 --- a/tests/models/hf/test_hf_fsdp.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -from composer.models.huggingface import maybe_get_underlying_model -from omegaconf import DictConfig - -from llmfoundry.models.hf import ComposerHFCausalLM - - -def test_olmo_wraps(): - conf: dict = { - 'model': { - 'name': 'hf_causal_lm', - 'pretrained_model_name_or_path': 'allenai/OLMo-7B', - 'pretrained': False, - 'config_overrides': { - 'n_layers': 2, - } - }, - } - - config = DictConfig(conf) - - model = ComposerHFCausalLM(config.model, None) - - # check that all the modules we except are blocked from FSDP wrapping - underlying_model = maybe_get_underlying_model(model.model) - assert not underlying_model.model._fsdp_wrap - assert not underlying_model.model.transformer._fsdp_wrap - assert not underlying_model.model.transformer.wte._fsdp_wrap - assert not underlying_model.model.transformer.ff_out._fsdp_wrap diff --git a/tests/models/hf/test_hf_peft_wrapping.py b/tests/models/hf/test_hf_peft_wrapping.py index d8bea33dd4..7fe886ffe3 100644 --- a/tests/models/hf/test_hf_peft_wrapping.py +++ b/tests/models/hf/test_hf_peft_wrapping.py @@ -17,13 +17,15 @@ def test_peft_wraps(): - mistral_cfg = transformers.AutoConfig.from_pretrained( - 'mistralai/Mistral-7B-v0.1', num_hidden_layers=2) - mistral = transformers.AutoModelForCausalLM.from_config(mistral_cfg) - mistral = get_peft_model(mistral, LoraConfig()) - prepare_hf_model_for_fsdp(mistral, 'cpu') + mpt_cfg = transformers.AutoConfig.from_pretrained('mosaicml/mpt-7b', + n_layers=2, + trust_remote_code=True) + mpt = transformers.AutoModelForCausalLM.from_config(mpt_cfg, + trust_remote_code=True) + mpt = get_peft_model(mpt, LoraConfig()) + prepare_hf_model_for_fsdp(mpt, 'cpu') - for n, m in mistral.named_modules(): + for n, m in mpt.named_modules(): if 'lora' in n and 'default' in n: has_parameters = any(True for _ in m.parameters()) has_buffers = any(True for _ in m.buffers()) @@ -51,7 +53,7 @@ def test_lora_mixed_init(peft_config: Optional[dict], tmp_path: pathlib.Path, init_device: str): model_cfg = { 'name': 'hf_causal_lm', - 'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1', + 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', 'config_overrides': { 'num_hidden_layers': 2, 'hidden_size': 32, @@ -60,7 +62,7 @@ def test_lora_mixed_init(peft_config: Optional[dict], tmp_path: pathlib.Path, 'pretrained': False, 'init_device': init_device, } - tokenizer_name = 'mistralai/Mistral-7B-v0.1' + tokenizer_name = 'codellama/CodeLlama-7b-hf' assert model_cfg is not None assert tokenizer_name is not None diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 9c15745793..c8e7ec3e67 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -239,6 +239,10 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): torch_dmoe_config = copy.deepcopy(mb_dmoe_config) torch_dmoe_config.ffn_config['ffn_type'] = 'torch_dmoe' + del torch_dmoe_config.ffn_config['moe_world_size'] + del torch_dmoe_config.ffn_config['fc_type'] + del torch_dmoe_config.ffn_config['moe_loss_weight'] + del torch_dmoe_config.ffn_config['return_bias'] mb_dmoe_model = MPTForCausalLM(mb_dmoe_config).to(device=device, dtype=dtype) diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py index 1e8ec2383d..08891d5199 100644 --- a/tests/models/layers/test_huggingface_flash.py +++ b/tests/models/layers/test_huggingface_flash.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib -import os import pytest from composer.core.precision import get_precision_context @@ -15,53 +14,29 @@ @pytest.mark.gpu @pytest.mark.world_size(2) -@pytest.mark.parametrize('model_name', ['llama2', 'mistral']) +@pytest.mark.parametrize('model_name', ['codellama']) @pytest.mark.parametrize('use_flash_attention_2', [True, False]) @pytest.mark.parametrize('init_device', ['cpu', 'mixed', 'meta']) def test_flash2(model_name: str, use_flash_attention_2: bool, init_device: str): - if model_name == 'llama2': - if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: - pytest.skip( - 'The CI cluster does not have access to the Llama models, so skip this test.' - ) + if model_name == 'codellama': model_cfg = { 'name': 'hf_causal_lm', - 'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf', + 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', 'config_overrides': { 'num_hidden_layers': 2, 'intermediate_size': 64, 'hidden_size': 64, }, - 'use_auth_token': True, 'pretrained': False, 'init_device': init_device, } - tokenizer_name = 'meta-llama/Llama-2-7b-hf' + tokenizer_name = 'codellama/CodeLlama-7b-hf' from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaFlashAttention2) flash_attn_class = LlamaFlashAttention2 if use_flash_attention_2 else LlamaAttention attention_layers_attr = 'model.model.layers' attention_attr = 'self_attn' - elif model_name == 'mistral': - model_cfg = { - 'name': 'hf_causal_lm', - 'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1', - 'config_overrides': { - 'num_hidden_layers': 2, - 'intermediate_size': 64, - 'hidden_size': 64, - }, - 'pretrained': False, - 'init_device': 'cpu', - } - - tokenizer_name = 'mistralai/Mistral-7B-v0.1' - from transformers.models.mistral.modeling_mistral import ( - MistralAttention, MistralFlashAttention2) - flash_attn_class = MistralFlashAttention2 if use_flash_attention_2 else MistralAttention - attention_layers_attr = 'model.model.layers' - attention_attr = 'self_attn' else: raise ValueError(f'Unknown model: {model_name}') diff --git a/tests/models/utils/test_param_init_fns.py b/tests/models/utils/test_param_init_fns.py index 6be2c5ca42..0efc245602 100644 --- a/tests/models/utils/test_param_init_fns.py +++ b/tests/models/utils/test_param_init_fns.py @@ -12,7 +12,8 @@ from omegaconf import OmegaConf as om from torch import nn -from llmfoundry.models.utils import MODEL_INIT_REGISTRY, generic_param_init_fn_ +from llmfoundry.layers_registry import param_init_fns +from llmfoundry.models.utils import generic_param_init_fn_ class MLP(nn.Module): @@ -150,7 +151,7 @@ def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]): bias=True)), ])) - model.apply(partial(MODEL_INIT_REGISTRY['kaiming_normal_'], **dict_cfg)) + model.apply(partial(param_init_fns.get('kaiming_normal_'), **dict_cfg)) assert isinstance(model.emb, torch.nn.Embedding) diff --git a/tests/test_registry.py b/tests/test_registry.py index 29d8e137f3..d7a1fc7dfe 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -31,6 +31,11 @@ def test_expected_registries_exist(): 'metrics', 'models', 'norms', + 'param_init_fns', + 'module_init_fns', + 'ffns', + 'ffns_with_norm', + 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', 'fcs',