Skip to content

Commit

Permalink
minor fixes: loading models from wandb, update some deps (#198)
Browse files Browse the repository at this point in the history
* bump muutils, transformer_lens, and maze-dataset
* organize deps
* downloading wandb zanj models
* reorg wandb downloading notebooks/modules
* forcing huggingface transformers to 4.33.3 due to interface change
  • Loading branch information
mivanit authored Oct 5, 2023
1 parent a08bfd7 commit 9330d9a
Show file tree
Hide file tree
Showing 5 changed files with 645 additions and 518 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import wandb
from maze_dataset import MazeDatasetConfig
from muutils.misc import shorten_numerical_to_str
from muutils.misc import sanitize_fname, shorten_numerical_to_str
from transformer_lens import HookedTransformer
from wandb.sdk.wandb_run import Artifact, Run

Expand All @@ -16,14 +16,20 @@
)


def get_step(artifact: Artifact) -> int:
# Find the alias beginning with "step="
def get_step(
artifact: Artifact, step_prefix: str = "step=", except_if_invalid: bool = False
) -> int:
step_alias: list[str] = [
alias for alias in artifact.aliases if alias.startswith("step=")
alias for alias in artifact.aliases if alias.startswith(step_prefix)
]
if len(step_alias) != 1: # if we have multiple, skip as well
return -1
return int(step_alias[0].split("=")[-1])
if except_if_invalid:
raise KeyError(
f"Could not find step alias in {artifact.name} " f"{artifact.aliases}",
)
else:
return -1
return int(step_alias[0].replace(step_prefix, ""))


def load_model(
Expand All @@ -44,11 +50,43 @@ def load_model(
return model


def match_checkpoint(
checkpoint: int | None,
run: Run,
step_prefix: str = "step=",
) -> Artifact:
# Match checkpoint
# available_checkpoints = [
# artifact for artifact in run.logged_artifacts() if artifact.type == "model"
# ]
available_checkpoints: list[Artifact] = list(run.logged_artifacts())
artifact: list[Artifact] = [
aft for aft in available_checkpoints if get_step(aft, step_prefix) == checkpoint
]
if len(artifact) != 1:
raise KeyError(
f"Could not find checkpoint {checkpoint} in {run.name} "
f"Available checkpoints: ",
str(
[
f"{artifact.name} | steps: {get_step(artifact, step_prefix)}"
for artifact in available_checkpoints
]
),
"\n",
str([(x.name, x.aliases) for x in available_checkpoints]),
)

artifact = artifact[0]
print(f"Loading checkpoint {checkpoint}")
return artifact


def load_wandb_run(
project="aisc-search/alex",
run_id="sa973hyn",
output_path="./downloaded_models",
checkpoint=None,
project: str = "aisc-search/alex",
run_id: str = "sa973hyn",
output_path: str = "./downloaded_models",
checkpoint: int | None = None,
) -> tuple[HookedTransformer, ConfigHolder]:
api: wandb.Api = wandb.Api()

Expand All @@ -58,24 +96,9 @@ def load_wandb_run(
wandb_cfg: wandb.config.Config = run.config # Get run configuration

# -- Get / Match checkpoint --
artifact: Artifact
if checkpoint is not None:
# Match checkpoint
available_checkpoints = [
artifact for artifact in run.logged_artifacts() if artifact.type == "model"
]
available_checkpoints = list(run.logged_artifacts())
artifact = [aft for aft in available_checkpoints if get_step(aft) == checkpoint]
if len(artifact) != 1:
print(f"Could not find checkpoint {checkpoint} in {artifact_name}")
print("Available checkpoints:")
[
print(artifact.name, "| Steps: ", get_step(artifact))
for artifact in available_checkpoints
]
return

artifact = artifact[0]
print("Loading checkpoint", checkpoint)
artifact = match_checkpoint(checkpoint, run)
else:
# Get latest checkpoint
print("Loading latest checkpoint")
Expand Down Expand Up @@ -170,3 +193,70 @@ def load_wandb_pt_model_as_zanj(
print(f"Saved model to {model_zanj_save_path.as_posix()}")

return model_zanj


def load_wandb_zanj(
run_id: str,
project: str = "aisc-search/alex",
checkpoint: int | None = None,
output_path: str | Path = "./downloaded_models",
model: bool = True,
) -> ZanjHookedTransformer:
output_path = Path(output_path)
api: wandb.Api = wandb.Api()
artifact_name: str = f"{project.rstrip('/')}/{run_id}"
run: Run = api.run(artifact_name)
wandb_cfg: wandb.config.Config = run.config # Get run configuration

print(f"Get artifact from {artifact_name = } corresponding to {checkpoint = }")
artifact: Artifact
is_final: bool
if checkpoint is not None:
is_final = False
artifact = match_checkpoint(checkpoint, run, step_prefix="iter-")
else:
is_final = True
# Get latest checkpoint
print("Loading latest checkpoint")
artifact_name = f"{artifact_name}:latest"
artifact = api.artifact(artifact_name)
checkpoint = get_step(artifact, step_prefix="iter-")
print(f"Found checkpoint {checkpoint}, {artifact_name}")

artifact_name_sanitized: str = sanitize_fname(artifact.name.replace(":", "-"))
print(f"download model {artifact_name_sanitized = }")

download_dir: Path = output_path / "temp" / artifact_name_sanitized

artifact.download(root=download_dir)
# get the single .zanj file in the download dir
download_path = next(download_dir.glob("*.zanj"))
print(f"\tDownloaded model to '{download_path}'")

print(f"load and re-save model with better filename and more data")
print(f"\tLoading model from '{download_path}'")
model: ZanjHookedTransformer = ZanjHookedTransformer.read(download_path)

# add metadata to model training records
model.training_records.update(
dict(
run_id=run_id,
project=project,
checkpoint=checkpoint,
is_final=is_final,
original_download_path=download_path,
version=artifact.name.split(":")[-1],
)
)

# save as proper model name
updated_save_path: Path = output_path / (
f"model.{artifact_name_sanitized}.iter_{checkpoint}.zanj"
if not is_final
else f"model.{artifact_name_sanitized}.final.zanj"
)
model.save(updated_save_path)

print(f"\tSaved model to '{updated_save_path}'")

return model
75 changes: 75 additions & 0 deletions notebooks/get_wandb_models.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from maze_transformer.training.config import ZanjHookedTransformer\n",
"from maze_transformer.utils.get_wandb_models import load_wandb_pt_model_as_zanj, load_wandb_zanj"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"uncomment the code below to get a wandb pytorch model (in @afspies format) to a zanj pytorch model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# MODEL: ZanjHookedTransformer = load_wandb_pt_model_as_zanj(\n",
"# \tproject=\"aisc-search/alex\", \n",
"# \trun_id=\"jerpkipj\", \n",
"# \tcheckpoint=None,\n",
"# \tsave_zanj_model=True,\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load_wandb_zanj(\n",
"# run_id=\"1n570yl5\",\n",
"# \tproject=\"aisc-search/understanding-search\",\n",
"# \tcheckpoint=None,\n",
"# \toutput_path=\"./downloaded_models\",\n",
"# \tmodel=True,\n",
"# )\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "maze-transformer-2cGx2R0F-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
87 changes: 0 additions & 87 deletions notebooks/wandb_to_zanj.ipynb

This file was deleted.

Loading

0 comments on commit 9330d9a

Please sign in to comment.