Add plot for NLLB vs Google Translate (#48)
ljvmiranda921 authored Oct 13, 2024
1 parent 8b8fa99 · commit 64c72c0
Showing 4 changed files with 131 additions and 66 deletions.
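
The substantive change is a new "translate" subcommand in analysis/plot_results.py (last diff below): it reads two leaderboard CSVs, presumably the ones cached by analysis/_plot_leaderboard.py, and draws a paired-line ("slope") chart comparing each model's average score under NLLB-3.3B translations against Google Translate translations. A minimal sketch of how the new plot might be produced directly from Python, assuming the module imports cleanly and using illustrative file paths that are not taken from this commit:

from pathlib import Path

from analysis.plot_results import plot_translate  # function added in this commit

# Illustrative paths: the CSVs are assumed to be the per-model-type leaderboards
# cached by analysis/_plot_leaderboard.py (they must contain the columns
# Model, Model_Type, and Avg_Multilingual).
plot_translate(
    gtrans=Path("data/leaderboard-gtrans.csv"),
    nllb=Path("data/leaderboard-nllb.csv"),
    output_path=Path("plots/nllb-vs-gtrans.pdf"),
)

The same plot should also be reachable from the command line via the new subparser (translate --gtrans ... --nllb ...), with the output location supplied by the shared arguments that are not shown in this diff.
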
7 changes: 4 additions & 3 deletions analysis/_plot_leaderboard.py
@@ -3,12 +3,12 @@
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import snapshot_download

from analysis.plot_utils import get_scores, PLOT_PARAMS
from analysis.plot_utils import PLOT_PARAMS, get_scores

logging.basicConfig(level=logging.INFO)

@@ -99,7 +99,8 @@ def main():
output_file = output_dir / f"leaderboard-{model_type.replace(' ', '_')}.png"
csv_output_file = output_dir / f"leaderboard-{model_type.replace(' ', '_')}.csv"
data_to_cache = data.copy(deep=True)
data_to_cache["eng_Latn"] = model_type_df["eng_Latn"]
if "eng_Latn" in model_type_df.columns:
data_to_cache["eng_Latn"] = model_type_df["eng_Latn"]
data_to_cache = data_to_cache.rename(columns={"Avg": "Avg_Multilingual"})
data_to_cache.to_csv(csv_output_file)
fig.savefig(output_file, dpi=120)
2 changes: 1 addition & 1 deletion analysis/avg_agreement_final.py
@@ -1,6 +1,6 @@
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

FONT_SIZES = {"small": 12, "medium": 16, "large": 18}
COLORS = {"green": "#355145", "purple": "#d8a6e5", "orange": "#fe7759"}
105 changes: 49 additions & 56 deletions analysis/maple_results.py
@@ -1,92 +1,87 @@
import json
from pathlib import Path

import argparse
import json
import logging
from collections import defaultdict
from itertools import combinations
from pathlib import Path
from typing import Optional

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import snapshot_download
import datasets
import json

import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations
from collections import defaultdict


FONT_SIZES = {"small": 12, "medium": 16, "large": 18}

PLOT_PARAMS = {
"font.family": "serif",
"font.serif": ["Times New Roman", "STIX"],
"font.size": FONT_SIZES.get("medium"),
"axes.titlesize": FONT_SIZES.get("large"),
"axes.labelsize": FONT_SIZES.get("large"),
"xtick.labelsize": FONT_SIZES.get("large"),
"ytick.labelsize": FONT_SIZES.get("small"),
"legend.fontsize": FONT_SIZES.get("medium"),
"figure.titlesize": FONT_SIZES.get("medium"),
"text.usetex": False,
"font.family": "serif",
"font.serif": ["Times New Roman", "STIX"],
"font.size": FONT_SIZES.get("medium"),
"axes.titlesize": FONT_SIZES.get("large"),
"axes.labelsize": FONT_SIZES.get("large"),
"xtick.labelsize": FONT_SIZES.get("large"),
"ytick.labelsize": FONT_SIZES.get("small"),
"legend.fontsize": FONT_SIZES.get("medium"),
"figure.titlesize": FONT_SIZES.get("medium"),
"text.usetex": False,
}

logging.basicConfig(level=logging.INFO)

plt.rcParams.update(PLOT_PARAMS)


def load_json(json_file_path):
with open(json_file_path, "r") as file:
json_data = json.load(file)
return json_data
with open(json_file_path, "r") as file:
json_data = json.load(file)
return json_data

results_dir = 'data/eval-results-maple'

results_dir = "data/eval-results-maple"
results_path = Path(results_dir)

results_all = []
for result_file in results_path.glob("*.json"):
raw_results = load_json(result_file)
if "leaderboard" in raw_results.keys():
model_id = raw_results["model"]
subset_results = raw_results['subset']
overall = raw_results['scores']['accuracy']
remove_key = ['model', 'model_type', 'chat_template']
for key in remove_key:
del subset_results[key]
elif "subset_results" in raw_results.keys():
model_id = raw_results["model"]
subset_results = raw_results['subset_results']
overall = raw_results['accuracy']
else:
model_id = raw_results["model"]
subset_results = raw_results['extra_results']
overall = raw_results['accuracy']
# print(model_id, overall)
# print("\t", subset_results)
# results_all.append([model_id, overall, subset_results])
results_all.append({'Model': model_id, 'Avg': overall, **subset_results})
# import ipdb; ipdb.set_trace()

TOP = 10
raw_results = load_json(result_file)
if "leaderboard" in raw_results.keys():
model_id = raw_results["model"]
subset_results = raw_results["subset"]
overall = raw_results["scores"]["accuracy"]
remove_key = ["model", "model_type", "chat_template"]
for key in remove_key:
del subset_results[key]
elif "subset_results" in raw_results.keys():
model_id = raw_results["model"]
subset_results = raw_results["subset_results"]
overall = raw_results["accuracy"]
else:
model_id = raw_results["model"]
subset_results = raw_results["extra_results"]
overall = raw_results["accuracy"]
# print(model_id, overall)
# print("\t", subset_results)
# results_all.append([model_id, overall, subset_results])
results_all.append({"Model": model_id, "Avg": overall, **subset_results})

# import ipdb; ipdb.set_trace()

TOP = 10
# results_all.sort(key=lambda x: x[1], reverse=True)
# results_all = results_all[:TOP]
# print(results_all)

df_results = pd.DataFrame(results_all)
df_results = df_results.sort_values(by='Avg', ascending=False).reset_index(drop=True)
df_results = df_results.sort_values(by="Avg", ascending=False).reset_index(drop=True)
df_results = df_results.head(10).reset_index(drop=True)

df_results.columns = df_results.columns.str.replace('^maple-', '', regex=True)
df_results.columns = df_results.columns.str.replace("^maple-", "", regex=True)
df_results = df_results.set_index("Model")
df_results = df_results * 100
fig, ax = plt.subplots(1, 1, figsize=(18, 5))

sns.heatmap(df_results, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 16},
fmt=".1f", cbar=False)
sns.heatmap(df_results, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 16}, fmt=".1f", cbar=False)

ax.xaxis.set_ticks_position("top")
ax.tick_params(axis="x", labelrotation=45)
@@ -97,5 +92,3 @@ def load_json(json_file_path):

plt.savefig("plots/maple.pdf", bbox_inches="tight")
# import ipdb; ipdb.set_trace()


83 changes: 77 additions & 6 deletions analysis/plot_results.py
@@ -1,12 +1,12 @@
import argparse
import logging
from pathlib import Path
from inspect import signature
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from adjustText import adjust_text

FONT_SIZES = {"small": 12, "medium": 16, "large": 18}
@@ -90,6 +90,11 @@ def get_args():
parser_ling_dims.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
parser_ling_dims.add_argument("--langdata", type=Path, required=True, help="Path to the language data file.")
parser_ling_dims.add_argument("--top_n", type=int, required=False, default=None, help="Aggregate only the scores for top-n.")

parser_translate = subparsers.add_parser("translate", help="Plot translation quality.", parents=[shared_args])
parser_translate.add_argument("--gtrans", type=Path, required=True, help="Path to the Google Translate results file.")
parser_translate.add_argument("--nllb", type=Path, required=True, help="Path to the NLLB-3.3B results file.")

# fmt: on
return parser.parse_args()

@@ -101,6 +106,7 @@ def main():
"main_heatmap": plot_main_heatmap,
"eng_drop_line": plot_eng_drop_line,
"ling_dims": plot_ling_dims,
"translate": plot_translate,
}

def _filter_args(func, kwargs):
@@ -254,8 +260,6 @@ def plot_eng_drop_line(
# # bbox=dict(facecolor="white", edgecolor="black", boxstyle="round,pad=0.5"),
# )

# ax.spines["right"].set_visible(False)
# ax.spines["top"].set_visible(False)
plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")

@@ -316,8 +320,75 @@ def plot_ling_dims(
ax.set_ylabel("")
ax.set_xlabel("M-RewardBench Score")

# ax.spines["right"].set_visible(False)
# ax.spines["top"].set_visible(False)
plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")


def plot_translate(
gtrans: Path,
nllb: Path,
output_path: Path,
figsize: Optional[tuple[int, int]] = (18, 5),
):
columns = ["Model", "Model_Type", "Avg_Multilingual"]
gtrans_df = pd.read_csv(gtrans)[columns].rename(columns={"Avg_Multilingual": "Avg_Gtrans"})
nllb_df = pd.read_csv(nllb)[columns].rename(columns={"Avg_Multilingual": "Avg_NLLB"})

combined = nllb_df.merge(gtrans_df, how="left", on="Model")
combined = combined[["Model", "Avg_NLLB", "Avg_Gtrans", "Model_Type_x"]].rename(
columns={"Model_Type_x": "Model_Type"}
)

print(combined.sort_values(by="Avg_NLLB", ascending=False))

colors = {
"Sequence Classifier": COLORS.get("green"),
"Generative RM": COLORS.get("purple"),
"DPO": COLORS.get("orange"),
}

labels = {
"Sequence Classifier": "Classifier RM",
"Generative RM": "Generative RM",
"DPO": "Implicit RM",
}

fig, ax = plt.subplots(figsize=figsize)
for _, row in combined.iterrows():
ax.plot(
[1, 2],
[row["Avg_NLLB"], row["Avg_Gtrans"]],
marker="o",
color=colors[row["Model_Type"]],
label=labels[row["Model_Type"]],
)

# Avoid duplicate labels in the legend
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(
by_label.values(),
by_label.keys(),
frameon=False,
ncols=3,
loc="lower center",
bbox_to_anchor=(0.5, -0.2),
)

# ax.grid(color="gray", alpha=0.2, which="both", axis="x")
# ax.set_ylabel("M-RewardBench Overall Score")

ax.set_xticks([1, 2])
ax.set_xticklabels(["NLLB", "Google Translate"])
ax.yaxis.set_visible(False)

ax.spines[["top", "bottom", "left", "right"]].set_visible(False)
ax.vlines(
[1, 2],
ymin=combined[["Avg_NLLB", "Avg_Gtrans"]].min().min(),
ymax=combined[["Avg_NLLB", "Avg_Gtrans"]].max().max(),
colors="gray",
)

plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")
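
One detail worth noting in plot_translate above: because ax.plot is called once per model row, each model-type label would be repeated in the legend, so the code collapses the handles and labels into a dict before calling ax.legend. A small self-contained sketch of that pattern with synthetic data (group names, colors, and scores below are illustrative, not results from this repository):

import matplotlib.pyplot as plt

# Synthetic slope-chart rows: (left score, right score, group label).
rows = [
    (0.62, 0.66, "Classifier RM"),
    (0.58, 0.61, "Classifier RM"),
    (0.55, 0.54, "Implicit RM"),
]
colors = {"Classifier RM": "tab:green", "Implicit RM": "tab:orange"}

fig, ax = plt.subplots()
for left, right, group in rows:
    # One plot call per row, so the group label is emitted once per row.
    ax.plot([1, 2], [left, right], marker="o", color=colors[group], label=group)

# Collapse duplicate legend entries: one handle per unique label.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(by_label.values(), by_label.keys(), frameon=False, loc="lower center")

ax.set_xticks([1, 2])
ax.set_xticklabels(["NLLB", "Google Translate"])
fig.savefig("slope_chart_demo.png", bbox_inches="tight")
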