diff --git a/composer/callbacks/__init__.py b/composer/callbacks/__init__.py
index 16a50a31a9..b876826e3c 100644
--- a/composer/callbacks/__init__.py
+++ b/composer/callbacks/__init__.py
@@ -9,6 +9,7 @@
 from composer.callbacks.activation_monitor import ActivationMonitor
 from composer.callbacks.checkpoint_saver import CheckpointSaver
 from composer.callbacks.early_stopper import EarlyStopper
+from composer.callbacks.eval_output_logging_callback import EvalOutputLogging
 from composer.callbacks.export_for_inference import ExportForInferenceCallback
 from composer.callbacks.free_outputs import FreeOutputs
 from composer.callbacks.generate import Generate
@@ -35,6 +36,7 @@
     'CheckpointSaver',
     'MLPerfCallback',
     'EarlyStopper',
+    'EvalOutputLogging',
     'ExportForInferenceCallback',
     'ThresholdStopper',
     'ImageVisualizer',
diff --git a/composer/callbacks/eval_output_logging_callback.py b/composer/callbacks/eval_output_logging_callback.py
new file mode 100644
index 0000000000..6334572410
--- /dev/null
+++ b/composer/callbacks/eval_output_logging_callback.py
@@ -0,0 +1,115 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Log model outputs and expected outputs during ICL evaluation."""
+
+import warnings
+from copy import deepcopy
+from typing import Any, Dict, List, Sequence, Union
+
+import torch
+
+from composer.core import Callback, State
+from composer.loggers import ConsoleLogger, Logger
+from composer.utils.dist import all_gather_object
+
+
+class EvalOutputLogging(Callback):
+    """Logs eval outputs for each sample of each ICL evaluation dataset.
+
+    ICL metrics are required to support caching the model's responses, including information on whether the model was correct.
+    Metrics are responsible for returning the results of individual datapoints in a dictionary of lists.
+    The callback will log the metric name, the depadded and detokenized input, any data stored in state.metric_outputs, and
+    any keys from the batch passed into `batch_keys_to_log`. It will do so after every eval batch.
+    """
+
+    def __init__(self, log_tokens=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.log_tokens = log_tokens
+        self.columns = None
+        self.name = None
+        self.rows = []
+
+    def eval_batch_end(self, state: State, logger: Logger) -> None:
+        if not isinstance(state.batch, Dict):
+            warnings.warn(
+                f'''EvalOutputLogging only supports batches that are dictionaries. \
+                Found batch of type {type(state.batch)}. \
+                Not logging eval outputs.''',
+            )
+            return
+
+        assert state.outputs is not None
+        assert state.metric_outputs is not None
+        logging_dict: Dict[str, Union[List[Any], torch.Tensor, Sequence[torch.Tensor]]] = deepcopy(state.metric_outputs)
+
+        # If batch mode is not generate, outputs will be logits
+        if state.batch['mode'] == 'generate':
+            # Outputs are already detokenized
+            logging_dict['outputs'] = state.outputs
+
+        input_ids = state.batch['input_ids']
+        logged_input = []
+        assert state.dataloader is not None
+
+        # Depad and decode input_ids
+        for input_list in input_ids.tolist():
+            dataset = state.dataloader.dataset  # pyright: ignore[reportGeneralTypeIssues]
+            depadded_input = [tok for tok in input_list if tok != dataset.pad_tok_id]
+            logged_input.append(dataset.tokenizer.decode(depadded_input))
+        logging_dict['input'] = logged_input
+
+        # Log token indices if toggled
+        if self.log_tokens:
+            logging_dict['input_tokens'] = input_ids.tolist()
+            if not state.batch['mode'] == 'generate':
+                if isinstance(state.outputs, torch.Tensor):  # pyright
+                    logging_dict['label_tokens'] = state.outputs.tolist()
+
+        # Add run_name as a column
+        run_name_list = [state.run_name for _ in range(0, len(logging_dict['input']))]
+        logging_dict['run_name'] = run_name_list
+
+        # NOTE: This assumes _any_ tensor logged contains tokens to be decoded.
+        #       This might not be true if, for example, logits are logged.
+
+        # Detokenize data in rows
+        for key, value in logging_dict.items():
+            # All types in list are the same
+            if isinstance(value[0], torch.Tensor):
+                logging_dict[key] = [
+                    state.dataloader.dataset.tokenizer.decode(t)  # pyright: ignore[reportGeneralTypeIssues]
+                    for t in value
+                ]
+            elif isinstance(value[0], list):
+                if isinstance(value[0][0], torch.Tensor):
+                    tokenizer = state.dataloader.dataset.tokenizer  # pyright: ignore[reportGeneralTypeIssues]
+                    logging_dict[key] = [[tokenizer.decode(choice) for choice in t] for t in value]
+
+        # Convert logging_dict from kv pairs of column name and column values to a list of rows
+        # Example:
+        # logging_dict = {"a": ["1a", "2a"], "b": ["1b", "2b"]}
+        # will become
+        # columns = ["a", "b"], rows = [["1a", "1b"], ["2a", "2b"]]
+        columns = list(logging_dict.keys())
+        rows = [list(item) for item in zip(*logging_dict.values())]
+
+        assert state.dataloader_label is not None
+        if not self.name:
+            # If only running eval, step will be 0
+            # If running training, step will be the current training step
+            step = state.timestamp.batch.value
+            self.name = f'{state.dataloader_label}_step_{step}'
+            self.columns = columns
+        self.rows.extend(rows)
+
+    def eval_end(self, state: State, logger: Logger) -> None:
+        list_of_rows = all_gather_object(self.rows)
+        rows = [row for rows in list_of_rows for row in rows]
+        for dest_logger in logger.destinations:
+            if not isinstance(dest_logger, ConsoleLogger):
+                dest_logger.log_table(self.columns, rows, name=self.name, step=state.timestamp.batch.value)
+
+        self.rows = []
+        self.name = None
+        self.columns = None
diff --git a/composer/core/state.py b/composer/core/state.py
index e23fad6f0d..10790a25ce 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -549,6 +549,8 @@ def __init__(
         self.eval_metric_values: Dict[str, float] = {}
         self.total_loss_dict: Dict[str, float] = {}
 
+        self.metric_outputs: Dict[str, Any] = {}
+
     def _dataset_of(self, dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]:
         """Get the dataset contained by the given dataloader-like object.
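As an aside on the `eval_batch_end` logic above: the column-to-row conversion is what turns per-sample metric outputs into `log_table` rows. A minimal, self-contained sketch of just that step, mirroring the example in the code comment:

```python
# Reproduces the column -> row conversion used by EvalOutputLogging.eval_batch_end.
logging_dict = {'a': ['1a', '2a'], 'b': ['1b', '2b']}

columns = list(logging_dict.keys())                          # ['a', 'b']
rows = [list(item) for item in zip(*logging_dict.values())]  # one row per sample

assert rows == [['1a', '1b'], ['2a', '2b']]
```

Each key becomes a table column and each index across the lists becomes one logged row, which is why every list placed in `logging_dict` must have the same length (one entry per datapoint in the batch).
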
diff --git a/composer/loggers/in_memory_logger.py b/composer/loggers/in_memory_logger.py index f75c34001b..bd445f3cca 100644 --- a/composer/loggers/in_memory_logger.py +++ b/composer/loggers/in_memory_logger.py @@ -87,8 +87,11 @@ def log_table( conda_package='pandas', conda_channel='conda-forge', ) from e - table = pd.DataFrame.from_records(data=rows, columns=columns).to_json(orient='split', index=False) - assert isinstance(table, str) + table = pd.DataFrame.from_records(data=rows, + columns=columns).to_json(orient='split', index=False, force_ascii=False) + assert table is not None + # Merged assert is different + # assert isinstance(table, str) self.tables[name] = table def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: diff --git a/composer/loggers/wandb_logger.py b/composer/loggers/wandb_logger.py index ea7d9ed0b6..ad9af75aab 100644 --- a/composer/loggers/wandb_logger.py +++ b/composer/loggers/wandb_logger.py @@ -112,6 +112,8 @@ def __init__( self.run_dir: Optional[str] = None self.run_url: Optional[str] = None + self.table_dict = {} + def _set_is_in_atexit(self): self._is_in_atexit = True @@ -130,7 +132,7 @@ def log_table( if self._enabled: import wandb table = wandb.Table(columns=columns, rows=rows) - wandb.log({name: table}, step) + wandb.log({name: table}, step=step) def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: if self._enabled: diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index 1af2e81ab8..4ac0fb9ad8 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -3,12 +3,14 @@ """A collection of common torchmetrics for NLP tasks.""" +import copy +import functools import logging import os import re import string import warnings -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import numpy as np import torch @@ -203,6 +205,38 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.needs_batch = True + def _wrap_update(self, update: Callable) -> Callable: + """Overwrite default _wrap_update to return result of update(). + + Torch metrics wraps update with following wrapped_func but explicitly does not return the value. + In general, torchmetrics update() does not return a value, but we want to in order to pass it on + to state.metric_outputs. + """ + + @functools.wraps(update) + def wrapped_func(*args: Any, **kwargs: Any) -> None: + self._computed = None + self._update_count += 1 + with torch.set_grad_enabled(self._enable_grad): + try: + update_result = update(*args, **kwargs) + except RuntimeError as err: + if 'Expected all tensors to be on' in str(err): + raise RuntimeError( + 'Encountered different devices in metric calculation (see stacktrace for details).' + ' This could be due to the metric class not being on the same device as input.' 
+ f' Instead of `metric={self.__class__.__name__}(...)` try to do' + f' `metric={self.__class__.__name__}(...).to(device)` where' + ' device corresponds to the device of the input.', + ) from err + raise err + + if self.compute_on_cpu: + self._move_list_states_to_cpu() + return update_result + + return wrapped_func + def update( self, batch: dict, @@ -280,6 +314,12 @@ def __init__(self, dist_sync_on_step: bool = False): super().__init__(dist_sync_on_step=dist_sync_on_step) self.add_state('correct', default=torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum') + self.metric_result_dict = { + 'cleaned_output': [], + 'original_label': [], + 'cleaned_label': [], + 'result': [], + } def normalize_answer(self, answer: str): """Lower text and remove punctuation, articles and extra whitespace. @@ -309,6 +349,7 @@ def update(self, outputs: List[str], labels: List[List[str]], batch: Dict[str, A cot_delimiter = batch.get('cot_delimiter', '') do_normalization = batch.get('do_normalization', True) stopping_criteria = batch.get('stopping_criteria', None) + metric_result_dict = copy.deepcopy(self.metric_result_dict) for sample_output, sample_labels in zip(outputs, labels): final_answer = sample_output @@ -326,10 +367,20 @@ def update(self, outputs: List[str], labels: List[List[str]], batch: Dict[str, A cleaned_final_answer = final_answer.strip() cleaned_sample_labels = {sample_label.strip() for sample_label in sample_labels} + metric_result_dict['original_label'].append(sample_labels) + metric_result_dict['cleaned_output'].append(cleaned_final_answer) + metric_result_dict['cleaned_label'].append(cleaned_sample_labels) + if any(cleaned_final_answer.startswith(label) for label in cleaned_sample_labels): self.correct += torch.tensor(1.0) + metric_result_dict['result'].append(1) + else: + metric_result_dict['result'].append(0) + self.total += torch.tensor(1.0) + return metric_result_dict + def compute(self): assert isinstance(self.correct, Tensor) assert isinstance(self.total, Tensor) @@ -365,6 +416,7 @@ def __init__(self, dist_sync_on_step: bool = False): super().__init__(dist_sync_on_step=dist_sync_on_step) self.add_state('correct', default=torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum') + self.metric_result_dict = {'context': [], 'label': [], 'output': [], 'result': []} def update( self, @@ -380,13 +432,23 @@ def update( outputs=outputs, ) + metric_result_dict = copy.deepcopy(self.metric_result_dict) for batch_idx, cont_idx in enumerate(batch['continuation_indices']): cont_tok_pred = outputs[batch_idx].index_select(dim=0, index=cont_idx - 1).argmax(dim=-1) cont_tok_targ = labels[batch_idx].index_select(dim=0, index=cont_idx - 1) - self.correct += (cont_tok_pred == cont_tok_targ).all().int() + metric_result_dict['context'].append(batch['input_ids'][batch_idx][:cont_idx[0]]) + metric_result_dict['label'].append(cont_tok_targ) + metric_result_dict['output'].append(cont_tok_pred) + + correct = (cont_tok_pred == cont_tok_targ).all().int() + self.correct += correct + metric_result_dict['result'].append(int(correct)) + self.total += torch.tensor(1.0) + return metric_result_dict + def compute(self): assert isinstance(self.correct, Tensor) assert isinstance(self.total, Tensor) @@ -420,6 +482,15 @@ def __init__(self, dist_sync_on_step: bool = False): super().__init__(dist_sync_on_step=dist_sync_on_step) self.add_state('correct', default=torch.tensor(0.0), dist_reduce_fx='sum') 
self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.metric_result_dict = { + 'context': [], + 'correct_choice': [], + 'correct_choice_idx': [], + 'selected_choice': [], + 'selected_choice_idx': [], + 'all_choices': [], + 'result': [], + } def update( self, @@ -445,14 +516,41 @@ def update( perplexity = torch.exp(cross_entropy) perplexities.append(perplexity) + metric_result_dict = copy.deepcopy(self.metric_result_dict) for (start, end), gold_idx in zip(batch['choice_groupings'], batch['gold_indices']): subset = perplexities[start:end] idx_min = subset.index(min(subset)) - if idx_min == gold_idx: self.correct += torch.tensor(1.0) + metric_result_dict['result'].append(1) + else: + metric_result_dict['result'].append(0) + + question = batch['input_ids'][start][:batch['continuation_indices'][start][0]] + + correct_choice = batch['input_ids'][start:end][gold_idx][batch['continuation_indices'][start:end][gold_idx][ + 0]:batch['continuation_indices'][start:end][gold_idx][-1] + 1] + selected_choice = batch['input_ids'][start:end][idx_min][batch['continuation_indices'][start:end][idx_min][ + 0]:batch['continuation_indices'][start:end][idx_min][-1] + 1] + metric_result_dict['context'].append(question) + metric_result_dict['correct_choice'].append(correct_choice) + metric_result_dict['correct_choice_idx'].append(gold_idx) + metric_result_dict['selected_choice'].append(selected_choice) + metric_result_dict['selected_choice_idx'].append(idx_min) + all_choices = batch['input_ids'][start:end] + # Unpads the choices. Necessary in case different choices have different token lengths. + if 'attention_mask' in batch: + all_choices_list = [choice[batch['attention_mask'][i]] for i, choice in enumerate(all_choices)] + metric_result_dict['all_choices'].append(all_choices_list) + self.total += torch.tensor(1.0) + # Don't return all_choices if we didn't fill it up (i.e. 
didn't use causal lms) + if metric_result_dict['all_choices'] == []: + metric_result_dict.pop('all_choices') + + return metric_result_dict + def compute(self): assert isinstance(self.correct, Tensor) assert isinstance(self.total, Tensor) @@ -632,6 +730,8 @@ def __init__(self, dist_sync_on_step: bool = False): if self.eval_device is not None: self.eval_device = self.eval_device.upper() + self.metric_result_dict = {'context': [], 'output': [], 'result': [], 'sample_id': []} + def get_client(self) -> EvalClient: """Returns a client for the appropriate remote platform.""" client = None @@ -716,6 +816,7 @@ def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]): del labels # never used client = self.get_client() + metric_result_dict = copy.deepcopy(self.metric_result_dict) for sample_id, code_gen, sample_prompt, test_inputs, test_outputs, entry_point, language in zip( batch['sample_id'], outputs, @@ -728,9 +829,12 @@ def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]): idx = sample_id self.total[idx] += 1.0 + metric_result_dict['sample_id'].append(sample_id) code_gen = re.split(r'\n[A-Za-z0-9#`]', code_gen)[0] # remove everything after function ends final_code = sample_prompt + code_gen # combine prompt with the code generation + metric_result_dict['context'].append(sample_prompt) + metric_result_dict['output'].append(code_gen) test_results = [] for test_input, test_output in zip(test_inputs, test_outputs): @@ -747,8 +851,12 @@ def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]): if all(test_results): self.correct[idx] += 1.0 + metric_result_dict['result'].append(1) + else: + metric_result_dict['result'].append(0) client.close() # pyright: ignore [reportOptionalMemberAccess] + return metric_result_dict def compute(self): assert isinstance(self.correct, Tensor) diff --git a/composer/models/base.py b/composer/models/base.py index 47a3a3c3a5..0acc6ff539 100644 --- a/composer/models/base.py +++ b/composer/models/base.py @@ -185,13 +185,16 @@ def update_metric( batch: Any, outputs: Any, metric: Metric, - ) -> None: + ) -> Optional[Dict]: """Update the given metric. Args: batch: The dataloader batch outputs: The output from :meth:`eval_forward` metric (Metric): The metric to update. + + Returns: + Optional[Dict]: Optionally return metric results to be stored in state. 
""" raise NotImplementedError() diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py index cec9830955..32d5d50902 100644 --- a/composer/models/huggingface.py +++ b/composer/models/huggingface.py @@ -589,11 +589,17 @@ def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]: return metrics if metrics else {} - def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None: + def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> Dict: if getattr(metric, 'needs_batch', False): - metric.update(batch=batch, outputs=outputs, labels=self.labels) + metric_result = metric.update(batch=batch, outputs=outputs, labels=self.labels) else: - metric.update(outputs, self.labels) + metric_result = metric.update(outputs, self.labels) + if metric_result is not None: + # Add the metric name once for each datapoint in the batch + metric_result['metric_name'] = [metric.__class__.__name__ for _ in range(0, batch['input_ids'].shape[0])] + else: + metric_result = {} + return metric_result def get_metadata(self): model_output = {} diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index e6147644f8..47d381d067 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -3325,11 +3325,12 @@ def _eval_loop( outputs = self.state.outputs for metric in metrics.values(): - self._original_model.update_metric( + metric_outputs = self._original_model.update_metric( self.state.batch, outputs, metric, ) + self.state.metric_outputs = metric_outputs or {} except RuntimeError as e: if evaluator.auto_microbatching and _is_cuda_oom(e): diff --git a/tests/callbacks/test_eval_output_logging_callback.py b/tests/callbacks/test_eval_output_logging_callback.py new file mode 100644 index 0000000000..936ac7bd01 --- /dev/null +++ b/tests/callbacks/test_eval_output_logging_callback.py @@ -0,0 +1,278 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import json + +import torch +from torch.utils.data import DataLoader + +from composer.callbacks import EvalOutputLogging +from composer.core.state import State +from composer.core.time import Timestamp +from composer.datasets.in_context_learning_evaluation import InContextLearningMultipleChoiceTaskDataset +from composer.loggers import InMemoryLogger, Logger +from composer.metrics.nlp import InContextLearningLMAccuracy, InContextLearningMultipleChoiceAccuracy +from tests.common import device + + +class MockDataset(InContextLearningMultipleChoiceTaskDataset): + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + self.pad_tok_id = tokenizer.pad_token_id + + +class MockDataLoader(DataLoader): + + def __init__(self, tokenizer): + self.dataset = MockDataset(tokenizer) + + +class MockState(State): + + def __init__(self) -> None: + self.eval_metrics = {} + self.metric_outputs = {} + self.run_name = 'mock_name' + self.timestamp = Timestamp() + + def add_metric(self, metric_name, metric): + self.eval_metrics[metric_name] = {} + self.eval_metrics[metric_name][str(metric)] = metric + + def update_curr_eval(self, dataloader, dataloader_label): + self._dataloader = dataloader + self._dataloader_label = dataloader_label + + +def mock_lm_computation(metric, tokenizer, state): + contexts = ['The dog is', 'I love to eat', 'I hate', 'The weather is'] + continuations = [' furry', ' pie', ' long lines', ' snowy'] + pad = tokenizer.pad_token_id + inputs = [ + tokenizer(context)['input_ids'] + tokenizer(continuation)['input_ids'] + for context, continuation in zip(contexts, 
continuations) + ] + inputs = torch.tensor([input + [pad] * (2048 - len(input)) for input in inputs]) + + cont_idxs = [] + for context, continuation in zip(contexts, continuations): + start = len(tokenizer(context)['input_ids']) + end = start + len(tokenizer(continuation)['input_ids']) + cont_idxs.append(torch.tensor(list(range(start, end)))) + + batch = {'mode': 'icl_task', 'continuation_indices': cont_idxs, 'labels': inputs.roll(-1), 'input_ids': inputs} + logits = torch.nn.functional.one_hot(inputs.roll(-1), num_classes=pad + 1).float() * 100 + start, end = cont_idxs[1].tolist()[0] - 1, cont_idxs[1].tolist()[-1] + logits[1][start:end] = logits[0][start:end].clone() # make one of the answer's continuations incorrect + + state.metric_outputs = metric.update(batch, logits, batch['labels']) + metric.compute() + state.batch = batch + state.outputs = logits + return state + + +def mock_mc_computation(metric, tokenizer, state): + contexts = [ + 'Q: How do you cook a cake?', + 'Q: How do you cook a cake?', + 'Q: How old is the earth?', + 'Q: How old is the earth?', + ] + continuations = [' A: turn on the oven', ' A: do a backflip', ' A: 2 minutes', ' A: 4.5 billion years'] + gold_indices = [0, 1] + choice_groupings = [(0, 2), (2, 4)] + pad = tokenizer.pad_token_id + inputs = [ + tokenizer(context)['input_ids'] + tokenizer(continuation)['input_ids'] + for context, continuation in zip(contexts, continuations) + ] + inputs = torch.tensor([input + [pad] * (2048 - len(input)) for input in inputs]) + attention_mask = ~(inputs == pad) + + cont_idxs = [] + for context, continuation in zip(contexts, continuations): + start = len(tokenizer(context)['input_ids']) + end = start + len(tokenizer(continuation)['input_ids']) + cont_idxs.append(torch.tensor(list(range(start, end)))) + + batch = { + 'mode': 'icl_task', + 'continuation_indices': cont_idxs, + 'labels': inputs.roll(-1), + 'input_ids': inputs, + 'attention_mask': attention_mask, + 'gold_indices': gold_indices, + 'choice_groupings': choice_groupings, + } + logits = torch.nn.functional.one_hot(inputs.roll(-1), num_classes=pad + 1).float() + + # for the first two, the correct answer is continuation 0 + # make the answer correct by making continuation 0 more likely for both answers + start, end = cont_idxs[1].tolist()[0] - 1, cont_idxs[1].tolist()[-1] + logits[1][start:end] = logits[0][start:end].clone() + + # for the last two, the correct answer is continuation 3 + # make the answer incorrect by making continuation 2 more likely for both answers + start, end = cont_idxs[3].tolist()[0], cont_idxs[3].tolist()[-1] + logits[3][start:end] = logits[2][start:end].clone() + + state.metric_outputs = metric.update(batch=batch, output_logits=logits, labels=batch['labels']) + state.batch = batch + state.outputs = logits + metric.compute() + + +@device('cpu') +def test_eval_output_logging_lm(device, tiny_gpt2_tokenizer): + # this test simulates an unrolled version of the eval loop occurring twice + state = MockState() + in_memory_logger = InMemoryLogger() + logger = Logger(state, in_memory_logger) + lm_metric = InContextLearningLMAccuracy() + + state.add_metric('lm_acc', lm_metric) + + # Construct the callback + eval_output_logging = EvalOutputLogging(loggers_to_use=['InMemoryLogger']) + + for _ in range(2): + state.update_curr_eval( + MockDataLoader(tiny_gpt2_tokenizer), + 'lm_acc', + ) + mock_lm_computation(state.eval_metrics['lm_acc']['InContextLearningLMAccuracy()'], tiny_gpt2_tokenizer, state) + state.metric_outputs['metric_name'] = [ + 
lm_metric.__class__.__name__ for _ in range(0, state.batch['input_ids'].shape[0]) + ] + eval_output_logging.eval_batch_end(state, logger) + state.timestamp = Timestamp(batch=state.timestamp.batch.value + 1) + eval_output_logging.eval_end(state, logger) + + assert f'lm_acc_step_0' in in_memory_logger.tables + # Only want one table - we log once to a single step value during eval_end() + assert len(in_memory_logger.tables) == 1 + assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['columns'] == [ + 'context', + 'label', + 'output', + 'result', + 'metric_name', + 'input', + 'run_name', + ] + # We use the same data in each batch + assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['data'] == [ + ['The dog is', ' furry', ' furry', 1, 'InContextLearningLMAccuracy', 'The dog is furry', 'mock_name'], + ['I love to eat', ' pie', '[PAD]', 0, 'InContextLearningLMAccuracy', 'I love to eat pie', 'mock_name'], + ['I hate', ' long lines', ' long lines', 1, 'InContextLearningLMAccuracy', 'I hate long lines', 'mock_name'], + ['The weather is', ' snowy', ' snowy', 1, 'InContextLearningLMAccuracy', 'The weather is snowy', 'mock_name'], + ['The dog is', ' furry', ' furry', 1, 'InContextLearningLMAccuracy', 'The dog is furry', 'mock_name'], + ['I love to eat', ' pie', '[PAD]', 0, 'InContextLearningLMAccuracy', 'I love to eat pie', 'mock_name'], + ['I hate', ' long lines', ' long lines', 1, 'InContextLearningLMAccuracy', 'I hate long lines', 'mock_name'], + ['The weather is', ' snowy', ' snowy', 1, 'InContextLearningLMAccuracy', 'The weather is snowy', 'mock_name'], + ] + + +@device('cpu') +def test_eval_output_logging_mc(device, tiny_gpt2_tokenizer): + # this test simulates an unrolled version of the eval loop occurring twice + state = MockState() + in_memory_logger = InMemoryLogger() + logger = Logger(state, in_memory_logger) + mc_metric = InContextLearningMultipleChoiceAccuracy() + + state.add_metric('mc_acc', mc_metric) + + # Construct the callback + eval_output_logging = EvalOutputLogging(loggers_to_use=['InMemoryLogger']) + for _ in range(2): + state.update_curr_eval( + MockDataLoader(tiny_gpt2_tokenizer), + 'mc_acc', + ) + mock_mc_computation( + state.eval_metrics['mc_acc']['InContextLearningMultipleChoiceAccuracy()'], + tiny_gpt2_tokenizer, + state, + ) + state.metric_outputs['metric_name'] = [ + mc_metric.__class__.__name__ for _ in range(0, state.batch['input_ids'].shape[0]) + ] + eval_output_logging.eval_batch_end(state, logger) + state.timestamp = Timestamp(batch=state.timestamp.batch.value + 1) + eval_output_logging.eval_end(state, logger) + + assert f'mc_acc_step_0' in in_memory_logger.tables + # Only want one table - we log once to a single step value during eval_end() + assert len(in_memory_logger.tables) == 1 + assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['columns'] == [ + 'context', + 'correct_choice', + 'correct_choice_idx', + 'selected_choice', + 'selected_choice_idx', + 'all_choices', + 'result', + 'metric_name', + 'input', + 'run_name', + ] + # We use the same data for each batch + assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['data'] == [ + [ + 'Q: How do you cook a cake?', + ' A: turn on the oven', + 0, + ' A: turn on the oven', + 0, + ['Q: How do you cook a cake? A: turn on the oven', 'Q: How do you cook a cake? A: do a backflip'], + 1, + 'InContextLearningMultipleChoiceAccuracy', + 'Q: How do you cook a cake? 
A: turn on the oven', + 'mock_name', + ], + [ + 'Q: How old is the earth?', + ' A: 4.5 billion years', + 1, + ' A: 2 minutes', + 0, + [ + 'Q: How old is the earth? A: 2 minutes[PAD][PAD][PAD]', + 'Q: How old is the earth? A: 4.5 billion years[PAD]', + ], + 0, + 'InContextLearningMultipleChoiceAccuracy', + 'Q: How do you cook a cake? A: do a backflip', + 'mock_name', + ], + [ + 'Q: How do you cook a cake?', + ' A: turn on the oven', + 0, + ' A: turn on the oven', + 0, + ['Q: How do you cook a cake? A: turn on the oven', 'Q: How do you cook a cake? A: do a backflip'], + 1, + 'InContextLearningMultipleChoiceAccuracy', + 'Q: How do you cook a cake? A: turn on the oven', + 'mock_name', + ], + [ + 'Q: How old is the earth?', + ' A: 4.5 billion years', + 1, + ' A: 2 minutes', + 0, + [ + 'Q: How old is the earth? A: 2 minutes[PAD][PAD][PAD]', + 'Q: How old is the earth? A: 4.5 billion years[PAD]', + ], + 0, + 'InContextLearningMultipleChoiceAccuracy', + 'Q: How do you cook a cake? A: do a backflip', + 'mock_name', + ], + ] diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index 0a206135b3..151b846bec 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -387,6 +387,7 @@ def test_in_context_learning_mc_accuracy(tiny_gpt2_tokenizer): for context, continuation in zip(contexts, continuations) ] inputs = torch.tensor([input + [pad] * (2048 - len(input)) for input in inputs]) + attention_mask = ~(inputs == pad) cont_idxs = [] for context, continuation in zip(contexts, continuations): @@ -398,6 +399,7 @@ def test_in_context_learning_mc_accuracy(tiny_gpt2_tokenizer): 'continuation_indices': cont_idxs, 'labels': inputs.roll(-1), 'input_ids': inputs, + 'attention_mask': attention_mask, 'gold_indices': gold_indices, 'choice_groupings': choice_groupings, }
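
The tests above exercise the callback against mocked state; for an end-to-end run, a rough usage sketch follows. This is a sketch only: `model` and `icl_evaluator` are hypothetical placeholders for an existing `ComposerModel` (e.g. a `HuggingFaceModel` wrapping a causal LM) and an ICL `Evaluator`, and only `EvalOutputLogging` plus the `log_table` destinations are what this change actually adds.

```python
# Sketch only: `model` and `icl_evaluator` are assumed to be defined elsewhere.
from composer import Trainer
from composer.callbacks import EvalOutputLogging
from composer.loggers import InMemoryLogger

in_memory_logger = InMemoryLogger()

trainer = Trainer(
    model=model,  # assumed: a ComposerModel whose ICL metrics return per-sample dicts
    callbacks=[EvalOutputLogging(log_tokens=False)],
    loggers=[in_memory_logger],  # any non-ConsoleLogger destination receives log_table()
)
trainer.eval(eval_dataloader=icl_evaluator)  # assumed: an ICL Evaluator / DataSpec

# One table per dataloader label, keyed as '<dataloader_label>_step_<batch>'.
print(list(in_memory_logger.tables.keys()))
```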