Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor fixes: loading models from wandb, update some deps #198

Merged
merged 7 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import wandb
from maze_dataset import MazeDatasetConfig
from muutils.misc import shorten_numerical_to_str
from muutils.misc import sanitize_fname, shorten_numerical_to_str
from transformer_lens import HookedTransformer
from wandb.sdk.wandb_run import Artifact, Run

Expand All @@ -16,14 +16,20 @@
)


def get_step(artifact: Artifact) -> int:
# Find the alias beginning with "step="
def get_step(
artifact: Artifact, step_prefix: str = "step=", except_if_invalid: bool = False
) -> int:
step_alias: list[str] = [
alias for alias in artifact.aliases if alias.startswith("step=")
alias for alias in artifact.aliases if alias.startswith(step_prefix)
]
if len(step_alias) != 1: # if we have multiple, skip as well
return -1
return int(step_alias[0].split("=")[-1])
if except_if_invalid:
raise KeyError(
f"Could not find step alias in {artifact.name} " f"{artifact.aliases}",
)
else:
return -1
return int(step_alias[0].replace(step_prefix, ""))


def load_model(
Expand All @@ -44,11 +50,43 @@ def load_model(
return model


def match_checkpoint(
checkpoint: int | None,
run: Run,
step_prefix: str = "step=",
) -> Artifact:
# Match checkpoint
# available_checkpoints = [
# artifact for artifact in run.logged_artifacts() if artifact.type == "model"
# ]
available_checkpoints: list[Artifact] = list(run.logged_artifacts())
artifact: list[Artifact] = [
aft for aft in available_checkpoints if get_step(aft, step_prefix) == checkpoint
]
if len(artifact) != 1:
raise KeyError(
f"Could not find checkpoint {checkpoint} in {run.name} "
f"Available checkpoints: ",
str(
[
f"{artifact.name} | steps: {get_step(artifact, step_prefix)}"
for artifact in available_checkpoints
]
),
"\n",
str([(x.name, x.aliases) for x in available_checkpoints]),
)

artifact = artifact[0]
print(f"Loading checkpoint {checkpoint}")
return artifact


def load_wandb_run(
project="aisc-search/alex",
run_id="sa973hyn",
output_path="./downloaded_models",
checkpoint=None,
project: str = "aisc-search/alex",
run_id: str = "sa973hyn",
output_path: str = "./downloaded_models",
checkpoint: int | None = None,
) -> tuple[HookedTransformer, ConfigHolder]:
api: wandb.Api = wandb.Api()

Expand All @@ -58,24 +96,9 @@ def load_wandb_run(
wandb_cfg: wandb.config.Config = run.config # Get run configuration

# -- Get / Match checkpoint --
artifact: Artifact
if checkpoint is not None:
# Match checkpoint
available_checkpoints = [
artifact for artifact in run.logged_artifacts() if artifact.type == "model"
]
available_checkpoints = list(run.logged_artifacts())
artifact = [aft for aft in available_checkpoints if get_step(aft) == checkpoint]
if len(artifact) != 1:
print(f"Could not find checkpoint {checkpoint} in {artifact_name}")
print("Available checkpoints:")
[
print(artifact.name, "| Steps: ", get_step(artifact))
for artifact in available_checkpoints
]
return

artifact = artifact[0]
print("Loading checkpoint", checkpoint)
artifact = match_checkpoint(checkpoint, run)
else:
# Get latest checkpoint
print("Loading latest checkpoint")
Expand Down Expand Up @@ -170,3 +193,70 @@ def load_wandb_pt_model_as_zanj(
print(f"Saved model to {model_zanj_save_path.as_posix()}")

return model_zanj


def load_wandb_zanj(
run_id: str,
project: str = "aisc-search/alex",
checkpoint: int | None = None,
output_path: str | Path = "./downloaded_models",
model: bool = True,
) -> ZanjHookedTransformer:
output_path = Path(output_path)
api: wandb.Api = wandb.Api()
artifact_name: str = f"{project.rstrip('/')}/{run_id}"
run: Run = api.run(artifact_name)
wandb_cfg: wandb.config.Config = run.config # Get run configuration

print(f"Get artifact from {artifact_name = } corresponding to {checkpoint = }")
artifact: Artifact
is_final: bool
if checkpoint is not None:
is_final = False
artifact = match_checkpoint(checkpoint, run, step_prefix="iter-")
else:
is_final = True
# Get latest checkpoint
print("Loading latest checkpoint")
artifact_name = f"{artifact_name}:latest"
artifact = api.artifact(artifact_name)
checkpoint = get_step(artifact, step_prefix="iter-")
print(f"Found checkpoint {checkpoint}, {artifact_name}")

artifact_name_sanitized: str = sanitize_fname(artifact.name.replace(":", "-"))
print(f"download model {artifact_name_sanitized = }")

download_dir: Path = output_path / "temp" / artifact_name_sanitized

artifact.download(root=download_dir)
# get the single .zanj file in the download dir
download_path = next(download_dir.glob("*.zanj"))
print(f"\tDownloaded model to '{download_path}'")

print(f"load and re-save model with better filename and more data")
print(f"\tLoading model from '{download_path}'")
model: ZanjHookedTransformer = ZanjHookedTransformer.read(download_path)

# add metadata to model training records
model.training_records.update(
dict(
run_id=run_id,
project=project,
checkpoint=checkpoint,
is_final=is_final,
original_download_path=download_path,
version=artifact.name.split(":")[-1],
)
)

# save as proper model name
updated_save_path: Path = output_path / (
f"model.{artifact_name_sanitized}.iter_{checkpoint}.zanj"
if not is_final
else f"model.{artifact_name_sanitized}.final.zanj"
)
model.save(updated_save_path)

print(f"\tSaved model to '{updated_save_path}'")

return model
75 changes: 75 additions & 0 deletions notebooks/get_wandb_models.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from maze_transformer.training.config import ZanjHookedTransformer\n",
"from maze_transformer.utils.get_wandb_models import load_wandb_pt_model_as_zanj, load_wandb_zanj"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"uncomment the code below to get a wandb pytorch model (in @afspies format) to a zanj pytorch model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# MODEL: ZanjHookedTransformer = load_wandb_pt_model_as_zanj(\n",
"# \tproject=\"aisc-search/alex\", \n",
"# \trun_id=\"jerpkipj\", \n",
"# \tcheckpoint=None,\n",
"# \tsave_zanj_model=True,\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load_wandb_zanj(\n",
"# run_id=\"1n570yl5\",\n",
"# \tproject=\"aisc-search/understanding-search\",\n",
"# \tcheckpoint=None,\n",
"# \toutput_path=\"./downloaded_models\",\n",
"# \tmodel=True,\n",
"# )\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "maze-transformer-2cGx2R0F-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
87 changes: 0 additions & 87 deletions notebooks/wandb_to_zanj.ipynb

This file was deleted.

Loading