Skip to content

Commit

Permalink
Update plot with some fixes (#43)
Browse files: browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 authored Oct 9, 2024
1 parent f2a9c96 commit 8ffc172
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
10 changes: 10 additions & 0 deletions analysis/_plot_leaderboard.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -61,6 +61,16 @@ def main():
category="Reasoning",
)

dfs = {
"chat": chat_leaderboard_df,
"chat-hard": chat_hard_leaderboard_df,
"safety": safety_leaderboard_df,
"reasoning": reasoning_leaderboard_df,
}

for k, v in dfs.items():
v.to_csv(f"{k}.csv")

# Save per model type
model_types = leaderboard_df["Type"].unique().tolist()
for model_type in model_types:
Expand Down
44 changes: 27 additions & 17 deletions analysis/plot_results.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -83,12 +83,16 @@ def plot_main_heatmap(
data = df[[col for col in df.columns if col not in ["Model_Type"]]].rename(columns={"Avg_Multilingual": "Avg"})
data = data.set_index("Model")
data = data * 100
data["zho"] = data[["zho_Hans", "zho_Hant"]].mean(axis=1)
data.pop("zho_Hans")
data.pop("zho_Hant")
data = data[sorted(data.columns)]
data.columns = [col.split("_")[0] for col in data.columns]

fig, ax = plt.subplots(1, 1, figsize=figsize)
sns.heatmap(data, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 14}, fmt=".2f", cbar=False)
ax.xaxis.set_ticks_position("top")
ax.tick_params(axis="x", rotation=45)
ax.tick_params(axis="x")
ax.set_ylabel("")
ax.set_yticklabels([f"{model} " for model in data.index])

Expand All @@ -108,30 +112,34 @@ def plot_eng_drop_line(
df = df[["Model", "Model_Type", "eng_Latn", "Avg_Multilingual"]]
df = df.sort_values(by="Avg_Multilingual", ascending=False).reset_index(drop=True)
data = df.set_index("Model").dropna()
data = data[["eng_Latn", "Avg_Multilingual"]] * 100
model_types = df.dropna().pop("Model_Type")
data[data.select_dtypes(include="number").columns] = data.select_dtypes(include="number") * 100
data["Model_Type"] = data["Model_Type"].replace({"DPO": "Implicit RM", "Sequence Classifier": "Classifier RM"})
if top_n:
logging.info(f"Showing top {top_n}")
data = data.head(top_n)
model_types = model_types[:top_n]

fig, ax = plt.subplots(figsize=figsize)

colors = ["red", "green", "blue"]
for (label, group), color in zip(data.groupby("Model_Type"), colors):
mrewardbench_scores = group["Avg_Multilingual"]
rewardbench_scores = group["eng_Latn"]
ax.scatter(rewardbench_scores, mrewardbench_scores, marker="o", s=30, label=label, color=color)

mrewardbench_scores = data["Avg_Multilingual"]
rewardbench_scores = data["eng_Latn"]
r, _ = pearsonr(mrewardbench_scores, rewardbench_scores)
res = spearmanr(mrewardbench_scores, rewardbench_scores)

colormap = {"Generative RM": "green", "Sequence Classifier": "blue", "DPO": "red"}
colors = [colormap[model_type] for model_type in model_types]

ax.scatter(rewardbench_scores, mrewardbench_scores, marker="o", s=30, color=colors)
# ax.scatter(rewardbench_scores, mrewardbench_scores, marker="o", s=30, color=colors, label=model_types)

min_val = min(rewardbench_scores.min(), mrewardbench_scores.min())
max_val = max(rewardbench_scores.max(), mrewardbench_scores.max())
ax.plot([min_val, max_val], [min_val, max_val], linestyle="--", color="black", alpha=0.25)
ax.set_xlabel("RewardBench (Lambert et al., 2024)")
ax.set_ylabel("M-RewardBench")
ax.set_aspect("equal")
ax.legend(frameon=False, handletextpad=0.2, fontsize=12)

model_names = [model.split("/")[1] for model in data.index]
texts = [
Expand All @@ -150,15 +158,17 @@ def plot_eng_drop_line(
arrowprops=dict(arrowstyle="->", color="gray"),
)

ax.text(
0.1,
0.9,
s=f"Pearson-r: {r:.2f}\nSpearman-r: {res.statistic:.2f}",
fontsize=14,
transform=ax.transAxes,
verticalalignment="top",
bbox=dict(facecolor="white", edgecolor="black", boxstyle="round,pad=0.5"),
)
# ax.text(
# 0.6,
# 0.8,
# s=f"Pearson-r: {r:.2f} Spearman-r: {res.statistic:.2f}",
# fontsize=14,
# transform=ax.transAxes,
# verticalalignment="top",
# rotation=45,
# color="gray",
# # bbox=dict(facecolor="white", edgecolor="black", boxstyle="round,pad=0.5"),
# )

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
Expand Down

0 comments on commit 8ffc172

Please sign in to comment.