Skip to content

Commit

Permalink
Plot updates (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 authored Oct 10, 2024
1 parent 8ffc172 commit 454c66a
Showing 1 changed file with 95 additions and 2 deletions.
97 changes: 95 additions & 2 deletions analysis/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,44 @@

logging.basicConfig(level=logging.INFO)

# Short display names for figures, keyed by the model identifier that appears
# in the "Model" column / index of the results data.
MODEL_STANDARDIZATION = dict(
    [
        ("openai/gpt-4-turbo-2024-04-09", "GPT-4 Turbo"),
        ("openai/gpt-4o-2024-05-13", "GPT-4o"),
        ("google/gemma-2-9b-it", "Gemma 2 9B"),
        ("LxzGordon/URM-LLaMa-3.1-8B", "URM LlaMa 3.1 8B"),
        ("meta-llama/Meta-Llama-3.1-70B-Instruct", "Llama 3.1 70B"),
        ("meta-llama/Meta-Llama-3-70B-Instruct", "Llama 3 70B"),
        ("CIR-AMS/BTRM_Qwen2_7b_0613", "BTRM Qwen 2 7B"),
        ("cohere/command-r-plus-08-2024", "Command R+"),
        ("allenai/tulu-2-dpo-13b", "Tulu 2 13B DPO"),
        ("cohere/c4ai-aya-23-35b", "Aya 23 35B"),
    ]
)

# Maps the three-letter language prefix of a results column (the part before
# the "_" in names such as "zho_Hans") to the two-letter ISO 639-1 language
# code shown as the axis label in figures.
# NOTE: Japanese and Korean use the language codes "ja" / "ko"; "jp" / "kr"
# are country codes and were inconsistent with every other entry here.
LANG_STANDARDIZATION = {
    "arb": "ar",
    "ces": "cs",
    "deu": "de",
    "ell": "el",
    "fra": "fr",
    "heb": "he",
    "hin": "hi",
    "ind": "id",
    "ita": "it",
    "jpn": "ja",
    "kor": "ko",
    "nld": "nl",
    "pes": "fa",
    "pol": "pl",
    "por": "pt",
    "ron": "ro",
    "rus": "ru",
    "spa": "es",
    "tur": "tr",
    "ukr": "uk",
    "vie": "vi",
    "zho": "zh",
}


def get_args():
# fmt: off
Expand All @@ -45,6 +83,11 @@ def get_args():
parser_eng_drop = subparsers.add_parser("eng_drop_line", help="Plot english drop as a line chart.", parents=[shared_args])
parser_eng_drop.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
parser_eng_drop.add_argument("--top_n", default=None, type=int, help="If set, will only show the .")

parser_ling_dims = subparsers.add_parser("ling_dims", help="Plot performance with respect to linguistic dimensions.", parents=[shared_args])
parser_ling_dims.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
parser_ling_dims.add_argument("--langdata", type=Path, required=True, help="Path to the language data file.")
parser_ling_dims.add_argument("--top_n", type=int, required=False, default=None, help="Aggregate only the scores for top-n.")
# fmt: on
return parser.parse_args()

Expand All @@ -55,6 +98,7 @@ def main():
cmd_map = {
"main_heatmap": plot_main_heatmap,
"eng_drop_line": plot_eng_drop_line,
"ling_dims": plot_ling_dims,
}

def _filter_args(func, kwargs):
Expand All @@ -81,16 +125,18 @@ def plot_main_heatmap(

df = df.sort_values(by="Avg_Multilingual", ascending=False).head(10).reset_index(drop=True)
data = df[[col for col in df.columns if col not in ["Model_Type"]]].rename(columns={"Avg_Multilingual": "Avg"})
data["Model"] = data["Model"].replace(MODEL_STANDARDIZATION)
data = data.set_index("Model")
data = data * 100
data["zho"] = data[["zho_Hans", "zho_Hant"]].mean(axis=1)
data.pop("zho_Hans")
data.pop("zho_Hant")
data = data[sorted(data.columns)]
data.columns = [col.split("_")[0] for col in data.columns]
data = data.rename(columns=LANG_STANDARDIZATION)

fig, ax = plt.subplots(1, 1, figsize=figsize)
sns.heatmap(data, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 14}, fmt=".2f", cbar=False)
sns.heatmap(data, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 16}, fmt=".2f", cbar=False)
ax.xaxis.set_ticks_position("top")
ax.tick_params(axis="x")
ax.set_ylabel("")
Expand Down Expand Up @@ -141,7 +187,7 @@ def plot_eng_drop_line(
ax.set_aspect("equal")
ax.legend(frameon=False, handletextpad=0.2, fontsize=12)

model_names = [model.split("/")[1] for model in data.index]
model_names = [MODEL_STANDARDIZATION[model] for model in data.index]
texts = [
ax.text(
rewardbench_scores[idx],
Expand Down Expand Up @@ -182,5 +228,52 @@ def plot_eng_drop_line(
print(delta_df.to_latex())


def plot_ling_dims(
    input_path: Path,
    langdata: Path,
    output_path: Path,
    top_n: Optional[int] = None,
    figsize: Optional[tuple[int, int]] = (18, 5),
    xlim: Optional[tuple[float, float]] = (60, 70),
):
    """Plot average scores grouped by linguistic dimension as bar charts.

    Reads a results CSV (one row per model, indexed by its "Model" column,
    one column per language) and a language-data CSV (indexed by its
    "Language" column). Per-language scores are averaged across models,
    scaled by 100, joined with the language data, and then averaged again
    within each value of "Resource_Type", "Family" and "Script". One
    horizontal bar chart per dimension is drawn side by side and the figure
    is written to ``output_path``.

    Args:
        input_path: Path to the results CSV.
        langdata: Path to the language-data CSV. It must provide the
            "Language", "Resource_Type", "Family" and "Script" columns.
        output_path: Where the figure is saved.
        top_n: If set (and non-zero), only the ``top_n`` models are
            aggregated. Models are ranked by "Avg_Multilingual" when that
            column exists; otherwise the first ``top_n`` rows of the file
            are used as-is.
        figsize: Size of the whole figure, passed to ``plt.subplots``.
        xlim: Lower and upper bound of the shared score axis. Pass ``None``
            to let matplotlib choose the range from the data.
    """
    scores = pd.read_csv(input_path).set_index("Model")
    if top_n:
        # Rank before truncating so "top-n" does not depend on the row order
        # of the input file. A stable sort keeps the file order for ties, so
        # an already-sorted file yields exactly the same selection as before.
        if "Avg_Multilingual" in scores.columns:
            scores = scores.sort_values(by="Avg_Multilingual", ascending=False, kind="stable")
        scores = scores.head(top_n)

    # Keep only the per-language columns of the multilingual evaluation.
    excluded = ("Model_Type", "eng_Latn", "Avg_Multilingual")
    scores = scores[[col for col in scores.columns if col not in excluded]]
    # After transposing: one row per language, one column per model.
    per_language = scores.T

    lang_info = pd.read_csv(langdata).set_index("Language")
    combined = per_language.merge(lang_info, left_index=True, right_index=True)

    # The merge is an inner join, so languages without metadata vanish from
    # the plot. Make that visible instead of dropping them silently.
    unmatched = per_language.index.difference(combined.index)
    if len(unmatched) > 0:
        logging.warning(
            "Languages missing from %s are excluded from the plot: %s",
            langdata,
            ", ".join(str(lang) for lang in unmatched),
        )

    combined["Avg"] = per_language.mean(axis=1) * 100
    combined["Std"] = per_language.std(axis=1) * 100

    combined = combined.rename(columns={"Resource_Type": "Resource Availability"})
    linguistic_dims = [
        "Resource Availability",
        "Family",
        "Script",
    ]
    fig, axs = plt.subplots(1, len(linguistic_dims), figsize=figsize, sharex=True)
    for ax, dim in zip(axs, linguistic_dims):
        # "Std" is aggregated alongside "Avg" but is not drawn.
        lingdf = combined.groupby(dim).agg({"Avg": "mean", "Std": "mean"}).reset_index()

        sns.barplot(
            x="Avg",
            y=dim,
            data=lingdf,
            ax=ax,
            color="green",
            width=0.5 if dim == "Resource Availability" else 0.7,
        )
        ax.set_title(dim)
        if xlim is not None:
            ax.set_xlim(list(xlim))
        ax.set_ylabel("")
        ax.set_xlabel("M-RewardBench Score")

        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

    plt.tight_layout()
    fig.savefig(output_path, bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

0 comments on commit 454c66a

Please sign in to comment.