diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py
index 582c6467e0..eef3877758 100644
--- a/llmfoundry/callbacks/__init__.py
+++ b/llmfoundry/callbacks/__init__.py
@@ -17,6 +17,11 @@
     ) from e
 
 __all__ = [
-    'FDiffMetrics', 'Generate', 'MonolithicCheckpointSaver', 'GlobalLRScaling',
-    'LayerFreezing', 'ScheduledGarbageCollector', 'ModelGauntlet'
+    'FDiffMetrics',
+    'Generate',
+    'MonolithicCheckpointSaver',
+    'GlobalLRScaling',
+    'LayerFreezing',
+    'ScheduledGarbageCollector',
+    'ModelGauntlet',
 ]
diff --git a/llmfoundry/callbacks/model_gauntlet_callback.py b/llmfoundry/callbacks/model_gauntlet_callback.py
index f36f1c9871..1ba7f6e66a 100644
--- a/llmfoundry/callbacks/model_gauntlet_callback.py
+++ b/llmfoundry/callbacks/model_gauntlet_callback.py
@@ -1,10 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Monitor gradients during training."""
+"""Aggregate ICL evals into composite scores."""
 
 import math
 import re
@@ -24,20 +21,45 @@ class Weighting(Enum):
 
 
 class ModelGauntlet(Callback):
+    """The ModelGauntlet aggregates ICL eval results.
+
+    After `eval_end`, this callback inspects the logger for the individual ICL
+    metrics and aggregates them into composite scores according to the
+    aggregation specification provided in the constructor.
+
+    Args:
+        logger_keys (dict): The exact keys under which the individual benchmark
+            metrics are logged after eval.
+        categories (dict): The list of categories, the subtasks within each
+            category, the random baseline accuracy of each subtask, and the
+            number of fewshot examples used for the task. See
+            `llmfoundry/scripts/eval/yamls/model_gauntlet.yaml` for the expected
+            structure.
+        weighting (Weighting): The weighting scheme used to balance the
+            different tasks within each category: equal weight, weight
+            proportional to the dataset size, or weight proportional to the
+            log2 of the dataset size.
+        subtract_random_baseline (bool): Whether to subtract the random baseline
+            accuracy from the performance on each individual benchmark before
+            aggregating.
+        rescale_accuracy (bool): Whether to rescale the accuracy on each
+            benchmark by (1 - random_baseline_accuracy) before aggregating, so
+            that every benchmark maxes out at 1.0.
+        benchmark_sizes (Optional[dict]): Optional data on benchmark sizes, used
+            when not relying on equal weighting.
+    """
 
     def __init__(self,
                  logger_keys: dict,
-                 tasks: dict,
+                 categories: dict,
                  weighting: Weighting = Weighting.EQUAL,
                  subtract_random_baseline: bool = True,
                  rescale_accuracy: bool = True,
                  benchmark_sizes: Optional[dict] = None):
-        self.tasks = tasks
+        if weighting != Weighting.EQUAL and benchmark_sizes is None:
+            raise Exception(
+                'When not using equal weighting, you must provide the benchmark sizes.'
+            )
+
+        if rescale_accuracy and not subtract_random_baseline:
+            raise Exception(
+                'Only use accuracy rescaling in conjunction with subtracting random baseline accuracy.'
+            )
+
+        self.categories = categories
         self.weighting = Weighting[weighting]
         self.subtract_random_baseline = subtract_random_baseline
         self.rescale_accuracy = rescale_accuracy
         self.logger_keys = logger_keys
-        for category in self.tasks:
+
+        for category in self.categories:
             for benchmark in category['benchmarks']:
                 bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
@@ -83,7 +105,7 @@ def compute_averages(self, logger_data):
     def eval_end(self, state: State, logger: Logger):
         new_metrics = self.compute_averages(logger)
         composite_scores = {}
-        for category in self.tasks:
+        for category in self.categories:
             composite_scores[category['name']] = []
             for benchmark in category['benchmarks']:
                 key_pat = re.compile(
diff --git a/mcli/mcli-hf-eval.yaml b/mcli/mcli-hf-eval.yaml
index 735edd8597..c2395896f8 100644
--- a/mcli/mcli-hf-eval.yaml
+++ b/mcli/mcli-hf-eval.yaml
@@ -8,22 +8,19 @@ integrations:
 
 command: |
   cd llm-foundry/scripts
-  pip install mosaicml@git+https://github.com/mosaicml/composer.git@dev
   composer eval/eval.py /mnt/config/parameters.yaml
 
 # Mosaic Cloud will use run_name (with a unique suffix) to populate the env var $RUN_NAME
 run_name: all-eval
-gpu_num: 64
-gpu_type: a100_40gb
-cluster: # replace with your cluster here!
+gpu_num: 8
+gpu_type: a100_80gb
+cluster: r1z1 # replace with your cluster here!
 image: mosaicml/llm-foundry:2.0.1_cu118-latest
-scheduling:
-  priority: high
 
 # The below is injected as a YAML file: /mnt/config/parameters.yaml
 parameters:
-  dist_timeout: 60000
+  dist_timeout: 6000
   seed: 1
   max_seq_len: 1024
   device_eval_batch_size: 4
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index db65f8658c..5956f7cd32 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -11,7 +11,6 @@
 from composer.loggers import InMemoryLogger, LoggerDestination
 from composer.trainer import Trainer
 from composer.utils import dist, get_device, reproducibility
-from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 
 from llmfoundry.callbacks import ModelGauntlet
@@ -20,22 +19,89 @@
                                 build_tokenizer)
 
 
-def load_model(model_cfg, tokenizer):
+def load_model(model_cfg, tokenizer, num_retries):
     retries = 0
-    while retries < 3:
+    while retries < num_retries:
         try:
             composer_model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg,
                                                                      tokenizer)
             return composer_model
         except Exception as e:
             retries += 1
-            if retries >= 3:
+            if retries >= num_retries:
                 raise e
             else:
                 print(
-                    f'Got exception {str(e)} while loading model {model_cfg.name}. {3-retries} retries remaining'
+                    f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
                 )
-    return None
+
+
+def evaluate_model(model_cfg, cfg, model_gauntlet_df):
+    print(f'Evaluating model: {model_cfg.model_name}', flush=True)
+    # Build tokenizer and evaluators
+    tokenizer = build_tokenizer(model_cfg.tokenizer)
+
+    evaluators, logger_keys = build_icl_evaluators(cfg.icl_tasks, tokenizer,
+                                                   cfg.max_seq_len,
+                                                   cfg.device_eval_batch_size)
+
+    if hasattr(cfg, 'model_gauntlet'):
+        if isinstance(cfg.model_gauntlet, str):
+            with open(cfg.model_gauntlet, 'r') as icl_f:
+                model_gauntlet_cfg = om.load(icl_f)
+            model_gauntlet = model_gauntlet_cfg.model_gauntlet
+        else:
+            model_gauntlet = cfg.model_gauntlet
+        model_gauntlet.logger_keys = logger_keys
+        model_gauntlet.benchmark_sizes = {
+            e.label: e.dataloader.num_samples for e in evaluators
+        }
+        model_gauntlet_callback = ModelGauntlet(**model_gauntlet)
+    else:
+        model_gauntlet = None
+        model_gauntlet_callback = None
+
+    composer_model = load_model(model_cfg.model, tokenizer,
+                                cfg.get('num_retries', 3))
+
+    if model_gauntlet_df is None and model_gauntlet is not None:
+        model_gauntlet_df = pd.DataFrame(
+            columns=['model_name', 'average'] +
+            [t.name for t in model_gauntlet.categories])
+
+    in_memory_logger = InMemoryLogger()  # track metrics in the in_memory_logger
+    loggers: List[LoggerDestination] = [
+        build_logger(name, logger_cfg)
+        for name, logger_cfg in (cfg.get('loggers') or {}).items()
+    ]
+    loggers.append(in_memory_logger)
+
+    fsdp_config = cfg.get('fsdp_config', None)
+    fsdp_config = om.to_container(
+        fsdp_config, resolve=True) if fsdp_config is not None else None
+
+    load_path = model_cfg.get('load_path', None)
+
+    trainer = Trainer(
+        model=composer_model,
+        loggers=loggers,
+        precision=cfg.precision,
+        fsdp_config=fsdp_config,  # type: ignore
+        load_path=load_path,
+        load_weights_only=True,
+        progress_bar=False,
+        log_to_console=True,
+        dist_timeout=cfg.dist_timeout,
+    )
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    a = time.time()
+    trainer.eval(eval_dataloader=evaluators)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    b = time.time()
+    print(f'Ran {model_cfg.model_name} eval in: {b-a} seconds')
+
+    return (in_memory_logger, logger_keys, model_gauntlet_callback,
+            model_gauntlet, model_gauntlet_df)
 
 
 def main(cfg):
@@ -47,72 +113,11 @@ def main(cfg):
     model_gauntlet_df = None
     models_df = None
     for model_cfg in cfg.models:
-        print(f'Evaluating model: {model_cfg.model_name}', flush=True)
-        # Build tokenizer and model
-        try:
-            tokenizer = build_tokenizer(model_cfg.tokenizer)
 
-            evaluators, logger_keys = build_icl_evaluators(
-                cfg.icl_tasks, tokenizer, cfg.max_seq_len,
-                cfg.device_eval_batch_size)
+        try:
+            (in_memory_logger, logger_keys, model_gauntlet_callback,
+             model_gauntlet, model_gauntlet_df) = evaluate_model(
+                 model_cfg, cfg, model_gauntlet_df)
 
-            if hasattr(cfg, 'model_gauntlet'):
-                if isinstance(cfg.model_gauntlet, str):
-                    with open(cfg.model_gauntlet, 'r') as icl_f:
-                        model_gauntlet_cfg = om.load(icl_f)
-                    model_gauntlet = model_gauntlet_cfg.model_gauntlet
-                else:
-                    model_gauntlet = cfg.model_gauntlet
-                model_gauntlet.logger_keys = logger_keys
-                model_gauntlet.benchmark_sizes = {
-                    e.label: e.dataloader.num_samples for e in evaluators
-                }
-                model_gauntlet_callback = ModelGauntlet(**model_gauntlet)
-            else:
-                model_gauntlet = None
-
-            composer_model = load_model(model_cfg.model, tokenizer)
-
-            if model_gauntlet_df is None and model_gauntlet is not None:
-                model_gauntlet_df = pd.DataFrame(
-                    columns=['model_name', 'average'] +
-                    [t.name for t in model_gauntlet.tasks])
-
-            in_memory_logger = InMemoryLogger(
-            )  # track metrics in the in_memory_logger
-            loggers: List[LoggerDestination] = [
-                build_logger(name, logger_cfg)
-                for name, logger_cfg in (cfg.get('loggers') or {}).items()
-            ]
-            loggers.append(in_memory_logger)
-
-            fsdp_config = cfg.get('fsdp_config', None)
-            fsdp_config = om.to_container(
-                fsdp_config, resolve=True) if fsdp_config is not None else None
-
-            load_path = model_cfg.get('load_path', None)
-
-            trainer = Trainer(
-                model=composer_model,
-                loggers=loggers,
-                precision=cfg.precision,
-                fsdp_config=fsdp_config,  # type: ignore
-                load_path=load_path,
-                load_weights_only=True,
-                progress_bar=False,
-                log_to_console=True,
-                dist_timeout=cfg.dist_timeout,
-            )
-
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-            a = time.time()
-            trainer.eval(eval_dataloader=evaluators)
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-            b = time.time()
-
-            print(f'Ran {model_cfg.model_name} eval in: {b-a} seconds')
 
             composite_scores = model_gauntlet_callback.eval_end(
                 None, in_memory_logger)
@@ -157,7 +162,6 @@ def main(cfg):
             print(
                 f'Got exception: {str(e)} while evaluating {model_cfg}. Continuing to next model.',
                 flush=True)
-            raise e
 
 
 def calculate_markdown_results(logger_keys, logger_data, benchmark_to_taxonomy,
diff --git a/scripts/eval/yamls/hf_eval.yaml b/scripts/eval/yamls/hf_eval.yaml
index 2e183b82cc..b7ae11b94a 100644
--- a/scripts/eval/yamls/hf_eval.yaml
+++ b/scripts/eval/yamls/hf_eval.yaml
@@ -1,6 +1,6 @@
 max_seq_len: 2048
 seed: 1
-precision: fp32
+precision: amp_fp16
 
 models:
 
@@ -21,9 +21,9 @@ models:
 device_eval_batch_size: 4
 
 # FSDP config for model sharding
-# fsdp_config:
-#   sharding_strategy: FULL_SHARD
-#   mixed_precision: FULL
+fsdp_config:
+  sharding_strategy: FULL_SHARD
+  mixed_precision: FULL
 
 icl_tasks: 'eval/yamls/tasks_light.yaml'
 model_gauntlet: 'eval/yamls/model_gauntlet.yaml'
diff --git a/scripts/eval/yamls/model_gauntlet.yaml b/scripts/eval/yamls/model_gauntlet.yaml
index 1ac07ca6b9..08eb902405 100644
--- a/scripts/eval/yamls/model_gauntlet.yaml
+++ b/scripts/eval/yamls/model_gauntlet.yaml
@@ -8,10 +8,6 @@ model_gauntlet:
     - name: jeopardy
       num_fewshot: 10
      random_baseline: 0
-    # - name: triviaqa # not used in model gauntlet v0
-    #   num_fewshot: 0
-    #   scorecard:
-    #     random_baseline: 0
     - name: bigbench_qa_wikidata
       num_fewshot: 10
       random_baseline: 0
@@ -110,13 +106,6 @@ model_gauntlet:
     - name: squad
       num_fewshot: 10
       random_baseline: 0
-    # - name: coqa # not used in model gauntlet v0
-    #   num_fewshot: 2
-    #   scorecard:
-    #     size: 4
-    #     quality: 3
-    #     diversity: 4
-    #     random_baseline: 0
     - name: bigbench_understanding_fables
       num_fewshot: 10
       random_baseline: 0.25
diff --git a/scripts/eval/yamls/mpt_eval.yaml b/scripts/eval/yamls/mpt_eval.yaml
index f32626e86e..658e547804 100644
--- a/scripts/eval/yamls/mpt_eval.yaml
+++ b/scripts/eval/yamls/mpt_eval.yaml
@@ -29,9 +29,9 @@ load_path: # Add your (optional) Composer checkpoint path here!
 device_eval_batch_size: 16
 
 # FSDP config for model sharding
-# fsdp_config:
-#   sharding_strategy: FULL_SHARD
-#   mixed_precision: FULL
+fsdp_config:
+  sharding_strategy: FULL_SHARD
+  mixed_precision: FULL
 
 icl_tasks:
 -
@@ -39,5 +39,5 @@ icl_tasks:
   dataset_uri: eval/local_data/jeopardy_all.jsonl # ADD YOUR OWN DATASET URI
   num_fewshot: [0]
   icl_task_type: language_modeling
-  continuation_delimiter: '\nAnswer: ' # this separates questions from answers
+  continuation_delimiter: "\nAnswer: " # this separates questions from answers
   has_categories: true
diff --git a/scripts/eval/yamls/tasks.yaml b/scripts/eval/yamls/tasks.yaml
index 2353d87b99..70ef2ca667 100644
--- a/scripts/eval/yamls/tasks.yaml
+++ b/scripts/eval/yamls/tasks.yaml
@@ -6,11 +6,6 @@ icl_tasks:
   icl_task_type: language_modeling
   continuation_delimiter: "\nAnswer: " # this separates questions from answers
   has_categories: true
-# -  # not used in model gauntlet v0
-#   label: triviaqa
-#   dataset_uri: eval/local_data/world_knowledge/triviaqa_sm.jsonl # ADD YOUR OWN DATASET URI
-#   num_fewshot: [0]
-#   icl_task_type: question_answering
 -
   label: bigbench_qa_wikidata
   dataset_uri: eval/local_data/world_knowledge/bigbench_qa_wikidata.jsonl # ADD YOUR OWN DATASET URI
@@ -94,7 +89,7 @@ icl_tasks:
 -
   label: bigbench_conlang_translation
   dataset_uri: eval/local_data/language_understanding/bigbench_conlang_translation.jsonl
-  num_fewshot: [10]
+  num_fewshot: [0]
   icl_task_type: language_modeling
 -
   label: bigbench_language_identification
@@ -167,11 +162,6 @@ icl_tasks:
   dataset_uri: eval/local_data/reading_comprehension/squad.jsonl # ADD YOUR OWN DATASET URI
   num_fewshot: [10]
   icl_task_type: language_modeling
-# -  # not used in model gauntlet v0
-#   label: coqa
-#   dataset_uri: eval/local_data/reading_comprehension/coqa.jsonl # ADD YOUR OWN DATASET URI
-#   num_fewshot: [2]
-#   icl_task_type: language_modeling
 -
   label: bigbench_understanding_fables
   dataset_uri: eval/local_data/reading_comprehension/bigbench_understanding_fables.jsonl # ADD YOUR OWN DATASET URI
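
Reviewer note (not part of the patch): to make the callback change easier to review without running it, below is a minimal standalone sketch of the per-category aggregation implied by the new `subtract_random_baseline` and `rescale_accuracy` flags under EQUAL weighting. The function name and dictionary keys are illustrative assumptions, not APIs from this repo.

# Illustrative sketch only. Shows how a category composite score could be
# computed when the random baseline is subtracted and accuracies are rescaled
# by (1 - baseline), with EQUAL weighting across the category's benchmarks.
from typing import Dict, List


def composite_score(benchmarks: List[Dict[str, float]],
                    subtract_random_baseline: bool = True,
                    rescale_accuracy: bool = True) -> float:
    """Average benchmark accuracies into one category score.

    Each entry in `benchmarks` is assumed to look like
    {'accuracy': 0.42, 'random_baseline': 0.25} (hypothetical keys).
    """
    scores = []
    for bench in benchmarks:
        acc = bench['accuracy']
        baseline = bench['random_baseline']
        if subtract_random_baseline:
            acc -= baseline
        if rescale_accuracy and subtract_random_baseline:
            # Rescale so a perfect model scores 1.0 regardless of its baseline.
            acc /= (1.0 - baseline)
        scores.append(acc)
    # EQUAL weighting: simple mean over the category's benchmarks.
    return sum(scores) / len(scores)


# Example: a 4-way multiple-choice task at 50% accuracy maps to
# (0.5 - 0.25) / (1 - 0.25) = 1/3, while chance-level 25% maps to 0.0.
print(composite_score([{'accuracy': 0.5, 'random_baseline': 0.25}]))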