diff --git a/.gitignore b/.gitignore index 0e5028fb11..aff34b70f6 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ temp # IPython profile_default/ ipython_config.py +wandb +examples/wandb diff --git a/README.md b/README.md index 5ae743d91b..79455bfe60 100644 --- a/README.md +++ b/README.md @@ -245,6 +245,10 @@ For a full list of supported arguments, check out the [interface](https://github ## Visualizing Results +You can seamlessly visualize and analyze the results of your evaluation harness runs using both Weights & Biases (W&B) and Zeno. + +### Zeno + You can use [Zeno](https://zenoml.com) to visualize the results of your eval harness runs. First, head to [hub.zenoml.com](https://hub.zenoml.com) to create an account and get an API key [on your account page](https://hub.zenoml.com/account). @@ -284,6 +288,41 @@ If you run the eval harness on multiple tasks, the `project_name` will be used a You can find an example of this workflow in [examples/visualize-zeno.ipynb](examples/visualize-zeno.ipynb). +### Weights and Biases + +With the [Weights and Biases](https://wandb.ai/site) integration, you can now spend more time extracting deeper insights into your evaluation results. The integration is designed to streamline the process of logging and visualizing experiment results using the Weights & Biases (W&B) platform. + +The integration provides functionality to: + +- automatically log the evaluation results, +- log the samples as W&B Tables for easy visualization, +- log the per-task `_eval_samples.json` files if the samples are logged, +- log the `results.json` file as an artifact for version control, +- generate a comprehensive report for analysis and visualization with all the important metrics, +- log task and CLI-specific configs, +- and more out of the box, such as the command used to run the evaluation, GPU/CPU counts, timestamp, etc. + +First, you'll need to install the `lm_eval[wandb]` package extra: `pip install lm_eval[wandb]`. + +Authenticate your machine with your unique W&B token. Visit https://wandb.ai/authorize to get one, then run `wandb login` in your terminal. + +Run the eval harness as usual, adding the `wandb_args` flag. Use this flag to provide arguments for initializing a wandb run ([wandb.init](https://docs.wandb.ai/ref/python/init)) as a comma-separated string of arguments. + +```bash +lm_eval \ + --model hf \ + --model_args pretrained=microsoft/phi-2,trust_remote_code=True \ + --tasks hellaswag,mmlu_abstract_algebra \ + --device cuda:0 \ + --batch_size 8 \ + --output_path output/phi-2 \ + --limit 10 \ + --wandb_args project=lm-eval-harness-integration \ + --log_samples +``` + +In the stdout, you will find a link to the W&B run page as well as a link to the generated report. You can find an example of this workflow in [examples/visualize-wandb.ipynb](examples/visualize-wandb.ipynb). + ## How to Contribute or Learn More? For more information on the library and how everything fits together, check out all of our [documentation pages](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs)! We plan to post a larger roadmap of desired + planned library improvements soon, with more information on how contributors can help.
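For context on the `--wandb_args` flag documented in the README section above: the harness parses the comma-separated string into keyword arguments for `wandb.init` (via `utils.simple_parse_args_string`, used in the new `lm_eval/logging_utils.py` later in this diff). The following is a minimal sketch of what that mapping amounts to, not part of the diff itself; the hand-rolled parsing is a simplification, and `job_type=eval` is an illustrative extra key rather than something the README example passes.

```python
# Rough sketch: how a --wandb_args string maps onto wandb.init keyword arguments.
# The real parsing is done by lm_eval.utils.simple_parse_args_string, which also
# coerces value types; this simplified version keeps every value as a string.
import wandb

wandb_args = "project=lm-eval-harness-integration,job_type=eval"  # job_type is illustrative
init_kwargs = dict(kv.split("=", 1) for kv in wandb_args.split(","))
# init_kwargs == {"project": "lm-eval-harness-integration", "job_type": "eval"}

run = wandb.init(**init_kwargs)  # any wandb.init argument can be supplied this way
run.finish()
```

Any keyword accepted by `wandb.init` (e.g. `entity`, `name`, `group`) can be passed through the same flag.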
diff --git a/examples/visualize-wandb.ipynb b/examples/visualize-wandb.ipynb new file mode 100644 index 0000000000..ed8df37741 --- /dev/null +++ b/examples/visualize-wandb.ipynb @@ -0,0 +1,130 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fc477b96-adee-4829-a9d7-a5eb990df358", + "metadata": {}, + "source": [ + "# Visualizing Results in Weights and Biases\n", + "\n", + "With the Weights and Biases integration, you can now spend more time extracting deeper insights into your evaluation results. The integration is designed to streamline the process of logging and visualizing experiment results using the Weights & Biases (W&B) platform.\n", + "\n", + "The integration provides functionality to:\n", + "\n", + "- automatically log the evaluation results,\n", + "- log the samples as W&B Tables for easy visualization,\n", + "- log the per-task `_eval_samples.json` files if the samples are logged,\n", + "- log the `results.json` file as an artifact for version control,\n", + "- generate a comprehensive report for analysis and visualization with all the important metrics,\n", + "- log task and CLI configs,\n", + "- and more out of the box, such as the command used to run the evaluation, GPU/CPU counts, timestamp, etc.\n", + "\n", + "The integration is super easy to use with the eval harness. Let's see how!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3851439a-bff4-41f2-bf21-1b3d8704913b", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Install this project if you do not already have it.\n", + "# This is all you need to install to start using Weights and Biases\n", + "\n", + "!pip -qq install -e ..[wandb]" + ] + }, + { + "cell_type": "markdown", + "id": "8507fd7e-3b99-4a92-89fa-9eaada74ba91", + "metadata": {}, + "source": [ + "# Run the Eval Harness\n", + "\n", + "Run the eval harness as usual with a `wandb_args` flag. This flag is used to provide arguments for initializing a wandb run ([wandb.init](https://docs.wandb.ai/ref/python/init)) as a comma-separated string of arguments.\n", + "\n", + "If the `wandb_args` flag is used, the metrics and other metadata will be automatically logged to Weights and Biases. In the stdout, you will find a link to the W&B run page as well as a link to the generated report." + ] + }, + { + "cell_type": "markdown", + "id": "eec5866e-f01e-42f8-8803-9d77472ef991", + "metadata": {}, + "source": [ + "## Set your API Key\n", + "\n", + "Before you can use W&B, you need to authenticate your machine with an API key. Visit https://wandb.ai/authorize to get one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d824d163-71a9-4313-935d-f1d56397841c", + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "wandb.login()" + ] + }, + { + "cell_type": "markdown", + "id": "124e4a34-1547-4bed-bc09-db012bacbda6", + "metadata": {}, + "source": [ + "> Note that if you are using the command line, you can simply authenticate your machine by running `wandb login` in your terminal. For more info, check out the [documentation](https://docs.wandb.ai/quickstart#2-log-in-to-wb)."
+ ] + }, + { + "cell_type": "markdown", + "id": "abc6f6b6-179a-4aff-ada9-f380fb74df6e", + "metadata": {}, + "source": [ + "## Run and log to W&B" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd0a8130-a97b-451a-acd2-3f9885b88643", + "metadata": {}, + "outputs": [], + "source": [ + "!lm_eval \\\n", + " --model hf \\\n", + " --model_args pretrained=microsoft/phi-2,trust_remote_code=True \\\n", + " --tasks hellaswag,mmlu_abstract_algebra \\\n", + " --device cuda:0 \\\n", + " --batch_size 8 \\\n", + " --output_path output/phi-2 \\\n", + " --limit 10 \\\n", + " --wandb_args project=lm-eval-harness-integration \\\n", + " --log_samples" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/lm_eval/__main__.py b/lm_eval/__main__.py index fc3c5857a4..8b02446d3b 100644 --- a/lm_eval/__main__.py +++ b/lm_eval/__main__.py @@ -11,6 +11,7 @@ import numpy as np from lm_eval import evaluator, utils +from lm_eval.logging_utils import WandbLogger from lm_eval.tasks import TaskManager, include_path, initialize_tasks from lm_eval.utils import make_table @@ -167,6 +168,11 @@ def parse_eval_args() -> argparse.Namespace: metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG", help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.", ) + parser.add_argument( + "--wandb_args", + default="", + help="Comma-separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`", + ) parser.add_argument( "--predict_only", "-x", @@ -195,6 +201,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: # we allow for args to be passed externally, else we parse them ourselves args = parse_eval_args() + if args.wandb_args: + wandb_logger = WandbLogger(args) + eval_logger = utils.eval_logger eval_logger.setLevel(getattr(logging, f"{args.verbosity}")) eval_logger.info(f"Verbosity set to {args.verbosity}") @@ -309,6 +318,16 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) + # Add W&B logging + if args.wandb_args: + try: + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + if args.log_samples: + wandb_logger.log_eval_samples(samples) + except Exception as e: + eval_logger.info(f"Logging to Weights and Biases failed due to {e}") + if args.output_path: output_path_file.open("w", encoding="utf-8").write(dumped) @@ -334,6 +353,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: if "groups" in results: print(make_table(results, "groups")) + if args.wandb_args: + # Tear down wandb run once all the logging is done.
+ wandb_logger.run.finish() + + if __name__ == "__main__": cli_evaluate() diff --git a/lm_eval/logging_utils.py b/lm_eval/logging_utils.py new file mode 100644 index 0000000000..464ba9f732 --- /dev/null +++ b/lm_eval/logging_utils.py @@ -0,0 +1,386 @@ +import copy +import json +import logging +import re +from typing import Any, Dict, List, Literal, Tuple, Union + +import numpy as np +import pandas as pd +from packaging.version import Version + +from lm_eval import utils + + +logger = logging.getLogger(__name__) + +try: + import wandb + + # Report editing requires wandb>=0.13.6; the AssertionError below is caught and surfaced as a warning. + assert Version(wandb.__version__) >= Version("0.13.6") +except Exception as e: + logger.warning( + "To use the wandb reporting functionality please install wandb>=0.13.6.\n" + "To install the latest version of wandb run `pip install wandb --upgrade`\n" + f"{e}" + ) + + +def remove_none_pattern(input_string: str) -> Tuple[str, bool]: + """Remove the ',none' substring from the input_string if it exists at the end. + + Args: + input_string (str): The input string from which to remove the ',none' substring. + + Returns: + Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed + and a boolean indicating whether the modification was made (True) or not (False). + """ + # Define the pattern to match ',none' at the end of the string + pattern = re.compile(r",none$") + + # Use sub() to replace ',none' with an empty string + result = re.sub(pattern, "", input_string) + + # check if the input_string changed + removed = result != input_string + + return result, removed + + +def _handle_non_serializable(o: Any) -> Union[int, str, list]: + """Handle non-serializable objects by converting them to serializable types. + + Args: + o (Any): The object to be handled. + + Returns: + Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32, + it will be converted to int. If the object is of type set, it will be converted + to a list. Otherwise, it will be converted to str. + """ + if isinstance(o, np.int64) or isinstance(o, np.int32): + return int(o) + elif isinstance(o, set): + return list(o) + else: + return str(o) + + +def get_wandb_printer() -> Literal["Printer"]: + """Returns a wandb printer instance for pretty stdout.""" + from wandb.sdk.lib.printer import get_printer + from wandb.sdk.wandb_settings import Settings + + printer = get_printer(Settings()._jupyter) + return printer + + +class WandbLogger: + def __init__(self, args: Any) -> None: + """Initialize the WandbLogger. + + Args: + args (Any): Parsed CLI arguments; only `args.wandb_args` is used here. + Evaluation results are attached later via `post_init`.
+ """ + self.wandb_args: Dict[str, Any] = utils.simple_parse_args_string( + args.wandb_args + ) + + # initialize a W&B run + if wandb.run is None: + self.run = wandb.init(**self.wandb_args) + else: + self.run = wandb.run + + self.printer = get_wandb_printer() + + def post_init(self, results: Dict[str, Any]) -> None: + self.results: Dict[str, Any] = copy.deepcopy(results) + self.task_names: List[str] = list(results.get("results", {}).keys()) + self.group_names: List[str] = list(results.get("groups", {}).keys()) + + def _get_config(self) -> Dict[str, Any]: + """Get configuration parameters.""" + self.task_configs = self.results.get("configs", {}) + cli_configs = self.results.get("config", {}) + configs = { + "task_configs": self.task_configs, + "cli_configs": cli_configs, + } + + return configs + + def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]: + """Sanitize the results dictionary.""" + _results = copy.deepcopy(self.results.get("results", dict())) + + # Remove None from the metric string name + tmp_results = copy.deepcopy(_results) + for task_name in self.task_names: + task_result = tmp_results.get(task_name, dict()) + for metric_name, metric_value in task_result.items(): + _metric_name, removed = remove_none_pattern(metric_name) + if removed: + _results[task_name][_metric_name] = metric_value + _results[task_name].pop(metric_name) + + # remove string valued keys from the results dict + wandb_summary = {} + for task in self.task_names: + task_result = _results.get(task, dict()) + for metric_name, metric_value in task_result.items(): + if isinstance(metric_value, str): + wandb_summary[f"{task}/{metric_name}"] = metric_value + + for summary_metric, summary_value in wandb_summary.items(): + _task, _summary_metric = summary_metric.split("/") + _results[_task].pop(_summary_metric) + + tmp_results = copy.deepcopy(_results) + for task_name, task_results in tmp_results.items(): + for metric_name, metric_value in task_results.items(): + _results[f"{task_name}/{metric_name}"] = metric_value + _results[task_name].pop(metric_name) + for task in self.task_names: + _results.pop(task) + + return wandb_summary, _results + + def _log_results_as_table(self) -> None: + """Generate and log evaluation results as a table to W&B.""" + columns = [ + "Version", + "Filter", + "num_fewshot", + "Metric", + "Value", + "Stderr", + ] + + def make_table(columns: List[str], key: str = "results"): + table = wandb.Table(columns=columns) + results = copy.deepcopy(self.results) + + for k, dic in results.get(key).items(): + if k in self.group_names and not key == "groups": + continue + version = results.get("versions").get(k) + if version == "N/A": + version = None + n = results.get("n-shot").get(k) + + for (mf), v in dic.items(): + m, _, f = mf.partition(",") + if m.endswith("_stderr"): + continue + if m == "alias": + continue + + if m + "_stderr" + "," + f in dic: + se = dic[m + "_stderr" + "," + f] + if se != "N/A": + se = "%.4f" % se + table.add_data(*[k, version, f, n, m, str(v), str(se)]) + else: + table.add_data(*[k, version, f, n, m, str(v), ""]) + + return table + + # log the complete eval result to W&B Table + table = make_table(["Tasks"] + columns, "results") + self.run.log({"evaluation/eval_results": table}) + + if "groups" in self.results.keys(): + table = make_table(["Groups"] + columns, "groups") + self.run.log({"evaluation/group_eval_results": table}) + + def _log_results_as_artifact(self) -> None: + """Log results as JSON artifact to W&B.""" + dumped = json.dumps( + self.results, 
indent=2, default=_handle_non_serializable, ensure_ascii=False + ) + artifact = wandb.Artifact("results", type="eval_results") + with artifact.new_file("results.json", mode="w", encoding="utf-8") as f: + f.write(dumped) + self.run.log_artifact(artifact) + + def log_eval_result(self) -> None: + """Log evaluation results to W&B.""" + # Log configs to wandb + configs = self._get_config() + self.run.config.update(configs) + + wandb_summary, self.wandb_results = self._sanitize_results_dict() + # update wandb.run.summary with items that were removed + self.run.summary.update(wandb_summary) + # Log the evaluation metrics to wandb + self.run.log(self.wandb_results) + # Log the evaluation metrics as W&B Table + self._log_results_as_table() + # Log the results dict as json to W&B Artifacts + self._log_results_as_artifact() + + def _generate_dataset( + self, data: List[Dict[str, Any]], config: Dict[str, Any] + ) -> pd.DataFrame: + """Generate a dataset from evaluation data. + + Args: + data (List[Dict[str, Any]]): The data to generate a dataset for. + config (Dict[str, Any]): The configuration of the task. + + Returns: + pd.DataFrame: A dataframe that is ready to be uploaded to W&B. + """ + ids = [x["doc_id"] for x in data] + labels = [x["target"] for x in data] + instance = [""] * len(ids) + resps = [""] * len(ids) + filtered_resps = [""] * len(ids) + model_outputs = {} + + metrics_list = config["metric_list"] + metrics = {} + for metric in metrics_list: + metric = metric.get("metric") + if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]: + metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data] + if metric in ["byte_perplexity", "bits_per_byte"]: + metrics[f"{metric}_bytes"] = [x[metric][1] for x in data] + else: + metrics[f"{metric}_words"] = [x[metric][1] for x in data] + else: + metrics[metric] = [x[metric] for x in data] + + if config["output_type"] == "loglikelihood": + instance = [x["arguments"][0][0] for x in data] + labels = [x["arguments"][0][1] for x in data] + resps = [ + f'log probability of continuation is {x["resps"][0][0][0]} ' + + "\n\n" + + "continuation will {} generated with greedy sampling".format( + "not be" if not x["resps"][0][0][1] else "be" + ) + for x in data + ] + filtered_resps = [ + f'log probability of continuation is {x["filtered_resps"][0][0]} ' + + "\n\n" + + "continuation will {} generated with greedy sampling".format( + "not be" if not x["filtered_resps"][0][1] else "be" + ) + for x in data + ] + elif config["output_type"] == "multiple_choice": + instance = [x["arguments"][0][0] for x in data] + choices = [ + "\n".join([f"{idx}. 
{y[1]}" for idx, y in enumerate(x["arguments"])]) + for x in data + ] + resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data] + filtered_resps = [ + np.argmax([n[0] for n in x["filtered_resps"]]) for x in data + ] + elif config["output_type"] == "loglikelihood_rolling": + instance = [x["arguments"][0][0] for x in data] + resps = [x["resps"][0][0] for x in data] + filtered_resps = [x["filtered_resps"][0] for x in data] + elif config["output_type"] == "generate_until": + instance = [x["arguments"][0][0] for x in data] + resps = [x["resps"][0][0] for x in data] + filtered_resps = [x["filtered_resps"][0] for x in data] + + model_outputs["raw_predictions"] = resps + model_outputs["filtered_predictions"] = filtered_resps + + df_data = { + "id": ids, + "data": instance, + } + if config["output_type"] == "multiple_choice": + df_data["choices"] = choices + + tmp_data = { + "input_len": [len(x) for x in instance], + "labels": labels, + "output_type": config["output_type"], + } + df_data.update(tmp_data) + df_data.update(model_outputs) + df_data.update(metrics) + + return pd.DataFrame(df_data) + + def _log_samples_as_artifact( + self, data: List[Dict[str, Any]], task_name: str + ) -> None: + # log the samples as an artifact + dumped = json.dumps( + data, + indent=2, + default=_handle_non_serializable, + ensure_ascii=False, + ) + artifact = wandb.Artifact(f"{task_name}", type="samples_by_task") + with artifact.new_file( + f"{task_name}_eval_samples.json", mode="w", encoding="utf-8" + ) as f: + f.write(dumped) + self.run.log_artifact(artifact) + # artifact.wait() + + def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None: + """Log evaluation samples to W&B. + + Args: + samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task. 
+ """ + task_names: List[str] = [ + x for x in self.task_names if x not in self.group_names + ] + + ungrouped_tasks = [] + tasks_by_groups = {} + + for task_name in task_names: + group_names = self.task_configs[task_name].get("group", None) + if group_names: + if isinstance(group_names, str): + group_names = [group_names] + + for group_name in group_names: + if not tasks_by_groups.get(group_name): + tasks_by_groups[group_name] = [task_name] + else: + tasks_by_groups[group_name].append(task_name) + else: + ungrouped_tasks.append(task_name) + + for task_name in ungrouped_tasks: + eval_preds = samples[task_name] + + # log the samples as a W&B Table + df = self._generate_dataset(eval_preds, self.task_configs.get(task_name)) + self.run.log({f"{task_name}_eval_results": df}) + + # log the samples as a json file as W&B Artifact + self._log_samples_as_artifact(eval_preds, task_name) + + for group, grouped_tasks in tasks_by_groups.items(): + grouped_df = pd.DataFrame() + for task_name in grouped_tasks: + eval_preds = samples[task_name] + df = self._generate_dataset( + eval_preds, self.task_configs.get(task_name) + ) + df["group"] = group + df["task"] = task_name + grouped_df = pd.concat([grouped_df, df], ignore_index=True) + + # log the samples as a json file as W&B Artifact + self._log_samples_as_artifact(eval_preds, task_name) + + self.run.log({f"{group}_eval_results": grouped_df}) diff --git a/pyproject.toml b/pyproject.toml index ca66f8547c..63fd49be67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"] testing = ["pytest", "pytest-cov", "pytest-xdist"] vllm = ["vllm<=0.2.5"] zeno = ["pandas", "zeno-client"] +wandb = ["wandb>=0.16.3", "pandas", "numpy"] all = [ "lm_eval[anthropic]", "lm_eval[dev]", @@ -86,6 +87,7 @@ all = [ "lm_eval[testing]", "lm_eval[vllm]", "lm_eval[zeno]", + "lm_eval[wandb]", ] [tool.ruff]