From 454c66a627b4869767209fd8197003fde8a58724 Mon Sep 17 00:00:00 2001
From: Lj Miranda <12949683+ljvmiranda921@users.noreply.github.com>
Date: Wed, 9 Oct 2024 21:02:43 -0700
Subject: [PATCH] Plot updates (#44)

---
 analysis/plot_results.py | 97 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 95 insertions(+), 2 deletions(-)

diff --git a/analysis/plot_results.py b/analysis/plot_results.py
index 6dfec9c..2f97f9f 100644
--- a/analysis/plot_results.py
+++ b/analysis/plot_results.py
@@ -28,6 +28,44 @@
 
 logging.basicConfig(level=logging.INFO)
 
+MODEL_STANDARDIZATION = {
+    "openai/gpt-4-turbo-2024-04-09": "GPT-4 Turbo",
+    "openai/gpt-4o-2024-05-13": "GPT-4o",
+    "google/gemma-2-9b-it": "Gemma 2 9B",
+    "LxzGordon/URM-LLaMa-3.1-8B": "URM LlaMa 3.1 8B",
+    "meta-llama/Meta-Llama-3.1-70B-Instruct": "Llama 3.1 70B",
+    "meta-llama/Meta-Llama-3-70B-Instruct": "Llama 3 70B",
+    "CIR-AMS/BTRM_Qwen2_7b_0613": "BTRM Qwen 2 7B",
+    "cohere/command-r-plus-08-2024": "Command R+",
+    "allenai/tulu-2-dpo-13b": "Tulu 2 13B DPO",
+    "cohere/c4ai-aya-23-35b": "Aya 23 35B",
+}
+
+LANG_STANDARDIZATION = {
+    "arb": "ar",
+    "ces": "cs",
+    "deu": "de",
+    "ell": "el",
+    "fra": "fr",
+    "heb": "he",
+    "hin": "hi",
+    "ind": "id",
+    "ita": "it",
+    "jpn": "jp",
+    "kor": "kr",
+    "nld": "nl",
+    "pes": "fa",
+    "pol": "pl",
+    "por": "pt",
+    "ron": "ro",
+    "rus": "ru",
+    "spa": "es",
+    "tur": "tr",
+    "ukr": "uk",
+    "vie": "vi",
+    "zho": "zh",
+}
+
 
 def get_args():
     # fmt: off
@@ -45,6 +83,11 @@ def get_args():
     parser_eng_drop = subparsers.add_parser("eng_drop_line", help="Plot english drop as a line chart.", parents=[shared_args])
     parser_eng_drop.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
     parser_eng_drop.add_argument("--top_n", default=None, type=int, help="If set, will only show the .")
+
+    parser_ling_dims = subparsers.add_parser("ling_dims", help="Plot performance with respect to linguistic dimensions.", parents=[shared_args])
+    parser_ling_dims.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
+    parser_ling_dims.add_argument("--langdata", type=Path, required=True, help="Path to the language data file.")
+    parser_ling_dims.add_argument("--top_n", type=int, required=False, default=None, help="Aggregate only the scores for top-n.")
     # fmt: on
     return parser.parse_args()
 
@@ -55,6 +98,7 @@ def main():
     cmd_map = {
         "main_heatmap": plot_main_heatmap,
         "eng_drop_line": plot_eng_drop_line,
+        "ling_dims": plot_ling_dims,
     }
 
     def _filter_args(func, kwargs):
@@ -81,6 +125,7 @@ def plot_main_heatmap(
 
     df = df.sort_values(by="Avg_Multilingual", ascending=False).head(10).reset_index(drop=True)
     data = df[[col for col in df.columns if col not in ["Model_Type"]]].rename(columns={"Avg_Multilingual": "Avg"})
+    data["Model"] = data["Model"].replace(MODEL_STANDARDIZATION)
     data = data.set_index("Model")
     data = data * 100
     data["zho"] = data[["zho_Hans", "zho_Hant"]].mean(axis=1)
@@ -88,9 +133,10 @@ def plot_main_heatmap(
     data.pop("zho_Hant")
     data = data[sorted(data.columns)]
     data.columns = [col.split("_")[0] for col in data.columns]
+    data = data.rename(columns=LANG_STANDARDIZATION)
 
     fig, ax = plt.subplots(1, 1, figsize=figsize)
-    sns.heatmap(data, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 14}, fmt=".2f", cbar=False)
+    sns.heatmap(data, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 16}, fmt=".2f", cbar=False)
     ax.xaxis.set_ticks_position("top")
     ax.tick_params(axis="x")
     ax.set_ylabel("")
@@ -141,7 +187,7 @@ def plot_eng_drop_line(
     ax.set_aspect("equal")
     ax.legend(frameon=False, handletextpad=0.2, fontsize=12)
 
-    model_names = [model.split("/")[1] for model in data.index]
+    model_names = [MODEL_STANDARDIZATION[model] for model in data.index]
     texts = [
         ax.text(
             rewardbench_scores[idx],
@@ -182,5 +228,52 @@
     print(delta_df.to_latex())
 
 
+def plot_ling_dims(
+    input_path: Path,
+    langdata: Path,
+    output_path: Path,
+    top_n: Optional[int] = None,
+    figsize: Optional[tuple[int, int]] = (18, 5),
+):
+    raw = pd.read_csv(input_path).set_index("Model")
+    if top_n:
+        raw = raw.head(top_n)
+    raw = raw[[col for col in raw.columns if col not in ("Model_Type", "eng_Latn", "Avg_Multilingual")]]
+    raw = raw.T
+    langdata = pd.read_csv(langdata).set_index("Language")
+    combined = raw.merge(langdata, left_index=True, right_index=True)
+    combined["Avg"] = raw.mean(axis=1) * 100
+    combined["Std"] = raw.std(axis=1) * 100
+
+    combined = combined.rename(columns={"Resource_Type": "Resource Availability"})
+    linguistic_dims = [
+        "Resource Availability",
+        "Family",
+        "Script",
+    ]
+    fig, axs = plt.subplots(1, len(linguistic_dims), figsize=figsize, sharex=True)
+    for ax, dim in zip(axs, linguistic_dims):
+        lingdf = combined.groupby(dim).agg({"Avg": "mean", "Std": "mean"}).reset_index()
+
+        sns.barplot(
+            x="Avg",
+            y=dim,
+            data=lingdf,
+            ax=ax,
+            color="green",
+            width=0.5 if dim == "Resource Availability" else 0.7,
+        )
+        ax.set_title(dim)
+        ax.set_xlim([60, 70])
+        ax.set_ylabel("")
+        ax.set_xlabel("M-RewardBench Score")
+
+        ax.spines["right"].set_visible(False)
+        ax.spines["top"].set_visible(False)
+
+    plt.tight_layout()
+    fig.savefig(output_path, bbox_inches="tight")
+
+
 if __name__ == "__main__":
     main()
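
Editor's note (not part of the patch): below is a minimal usage sketch of the plot_ling_dims function this patch adds, for readers who want to call it directly rather than through the new ling_dims subcommand. The import path and CSV file names are hypothetical placeholders; the expected column layout (a "Model" column in the results file, a "Language" index with "Resource_Type", "Family", and "Script" columns in the language data file) is inferred from the function body above.

    from pathlib import Path

    # Hypothetical import path; adjust to how analysis/plot_results.py is exposed in your setup.
    from plot_results import plot_ling_dims

    # Hypothetical file names; the results CSV needs a "Model" column plus per-language scores,
    # and the language data CSV needs "Language", "Resource_Type", "Family", and "Script" columns.
    plot_ling_dims(
        input_path=Path("results/multilingual_scores.csv"),
        langdata=Path("data/language_metadata.csv"),
        output_path=Path("plots/ling_dims.pdf"),
        top_n=10,  # aggregate only the top-10 models, mirroring the --top_n flag
    )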