
Commit

format
mivanit committed Aug 3, 2023
1 parent 3355c24 commit 1142760
Showing 13 changed files with 123 additions and 75 deletions.
7 changes: 4 additions & 3 deletions maze_transformer/evaluation/baseline_models.py
@@ -12,11 +12,10 @@
SolvedMaze,
)
from maze_dataset.tokenization.token_utils import (
strings_to_coords,
coords_to_strings,
get_origin_tokens,
get_path_tokens,
get_target_tokens,
strings_to_coords,
)
from transformer_lens import HookedTransformer

@@ -105,7 +104,9 @@ def _generate_path(
steps_to_predict: int,
) -> list[str]:
# assemble the maze from the tokens
maze: LatticeMaze = LatticeMaze.from_tokens(tokens, self.tokenizer._maze_tokenizer)
maze: LatticeMaze = LatticeMaze.from_tokens(
tokens, self.tokenizer._maze_tokenizer
)
origin_coord: CoordTup = strings_to_coords(get_origin_tokens(tokens))[0]
target_coord: CoordTup = strings_to_coords(get_target_tokens(tokens))[0]
solution: CoordArray = maze.find_shortest_path(origin_coord, target_coord)
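For context on the reformatted `_generate_path` call above, here is a minimal sketch of how the maze_dataset token utilities used in this hunk compose. It is illustrative only, assuming the call signatures visible in the diff; the real method obtains its tokenizer from `self.tokenizer._maze_tokenizer`, and the "gen_dfs" maze here is just a stand-in input.

# Illustrative sketch -- not part of this commit; assumes the signatures shown above.
from maze_dataset import LatticeMaze
from maze_dataset.generation import get_maze_with_solution
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from maze_dataset.tokenization.token_utils import (
    get_origin_tokens,
    get_target_tokens,
    strings_to_coords,
)

tok = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=None
)
solved = get_maze_with_solution("gen_dfs", (3, 3))  # hypothetical test maze
tokens: list[str] = solved.as_tokens(tok)

# same sequence of calls as the hunk above
maze: LatticeMaze = LatticeMaze.from_tokens(tokens, tok)
origin = strings_to_coords(get_origin_tokens(tokens))[0]
target = strings_to_coords(get_target_tokens(tokens))[0]
path = maze.find_shortest_path(origin, target)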
4 changes: 1 addition & 3 deletions maze_transformer/evaluation/eval_model.py
@@ -261,9 +261,7 @@ def evaluate_logits(
for tokens in prediction_tokens:
# this returns first path_start to end of list. Early in training there may be multiple path_start tokens, so results should be treated with caution
path_tokens = get_path_tokens(tokens.split(" "))
path_coords = strings_to_coords(
path_tokens, when_noncoord="skip"
)
path_coords = strings_to_coords(path_tokens, when_noncoord="skip")
predicted_paths.append(cast(list[tuple[int, int]], path_coords))

maze_tokens = [
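A small sketch of the path extraction described in the comment above, assuming `get_path_tokens` returns everything from the first path-start token onward and that `when_noncoord="skip"` drops the non-coordinate special tokens; the prediction string here is illustrative.

# Illustrative sketch -- assumed behavior, not part of this commit.
from maze_dataset.tokenization.token_utils import get_path_tokens, strings_to_coords

prediction = "<PATH_START> (0,0) (1,0) (2,0) (2,1) <PATH_END>"
path_tokens = get_path_tokens(prediction.split(" "))
# non-coordinate tokens such as <PATH_START> / <PATH_END> are skipped,
# leaving only (row, col) tuples, e.g. [(0, 0), (1, 0), (2, 0), (2, 1)]
path_coords = strings_to_coords(path_tokens, when_noncoord="skip")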
4 changes: 3 additions & 1 deletion maze_transformer/evaluation/plot_attention.py
@@ -52,7 +52,9 @@ def from_model_and_dataset(
for i in range(n_mazes):
# get the maze from the dataset and process into tokens
solved_maze: SolvedMaze = dataset[i]
tokens: list[str] = solved_maze.as_tokens(model.zanj_model_config.maze_tokenizer)
tokens: list[str] = solved_maze.as_tokens(
model.zanj_model_config.maze_tokenizer
)
tokens_context: list[str]

if context_maze_only:
2 changes: 1 addition & 1 deletion maze_transformer/test_helpers/assertions.py
@@ -4,7 +4,7 @@
from jaxtyping import Int
from zanj.torchutil import ConfigMismatchException, assert_model_cfg_equality

from maze_transformer.training.config import BaseGPTConfig, ZanjHookedTransformer
from maze_transformer.training.config import ZanjHookedTransformer


def _check_except_config_equality_modulo_weight_processing(
20 changes: 12 additions & 8 deletions maze_transformer/tokenizer.py
@@ -1,18 +1,16 @@
# Avoid circular import from training/config.py
from typing import TYPE_CHECKING, Sequence, Union # need Union as "a" | "b" doesn't work
from typing import TYPE_CHECKING, Sequence # need Union as "a" | "b" doesn't work

import torch
from maze_dataset import SPECIAL_TOKENS, LatticeMaze
from maze_dataset.dataset.dataset import GPTDatasetConfig
from maze_dataset.plotting import MazePlot
from maze_dataset.tokenization import MazeTokenizer
from muutils.tensor_utils import ATensor, NDArray
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils import BatchEncoding

from maze_dataset.tokenization import MazeTokenizer

if TYPE_CHECKING:
from maze_transformer.training.config import ConfigHolder
pass

# pylint: disable=unused-import, abstract-method

@@ -53,9 +51,15 @@ def __init__(
self.vocab_size = self._vocab_size
self._tokenizer_map = maze_tokenizer.tokenizer_map

assert isinstance(seq_len_max, int), f"seq_len_max must be an int, got {seq_len_max = } {type(seq_len_max) = }"
assert isinstance(token_arr, Sequence), f"token_arr must be a Sequence, got {token_arr = } {type(token_arr) = }"
assert isinstance(len(token_arr), int), f"token_arr must be a Sequence, got {token_arr = } {type(token_arr) = }"
assert isinstance(
seq_len_max, int
), f"seq_len_max must be an int, got {seq_len_max = } {type(seq_len_max) = }"
assert isinstance(
token_arr, Sequence
), f"token_arr must be a Sequence, got {token_arr = } {type(token_arr) = }"
assert isinstance(
len(token_arr), int
), f"token_arr must be a Sequence, got {token_arr = } {type(token_arr) = }"

# We are having to do evil things here
vocab: dict[str, int] = {token: i for i, token in enumerate(token_arr)}
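A tiny sketch of the vocab mapping built just after these assertions; the token strings here are hypothetical, and in the real class `token_arr` comes from the `MazeTokenizer`.

# Illustrative only -- token strings are made up for the example.
token_arr: list[str] = ["<ADJLIST_START>", "<ADJLIST_END>", "(0,0)", "(0,1)"]
vocab: dict[str, int] = {token: i for i, token in enumerate(token_arr)}
# -> {"<ADJLIST_START>": 0, "<ADJLIST_END>": 1, "(0,0)": 2, "(0,1)": 3}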
20 changes: 10 additions & 10 deletions maze_transformer/training/config.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import json
import math
import typing
import warnings
from functools import cached_property
@@ -11,6 +10,7 @@
import torch
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.dataset.dataset import GPTDatasetConfig
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from muutils.dictmagic import kwargs_to_nested_dict
from muutils.json_serialize import (
JSONitem,
@@ -25,8 +25,6 @@
from zanj.loading import load_item_recursive
from zanj.torchutil import ConfiguredModel, set_config_class

from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from maze_dataset.constants import SPECIAL_TOKENS
from maze_transformer.tokenizer import HuggingMazeTokenizer


@@ -369,6 +367,7 @@ def summary(self) -> dict:
cfg.name: cfg for cfg in _TRAINING_CONFIG_LIST
}


def _load_maze_tokenizer(data: dict) -> MazeTokenizer:
"""load the maze tokenizer, including vocab size from a legacy config"""
if "maze_tokenizer" in data:
@@ -381,9 +380,8 @@ def _load_maze_tokenizer(data: dict) -> MazeTokenizer:
max_grid_size=None,
)
else:
raise ValueError(
"Could not find vocab size in legacy config"
)
raise ValueError("Could not find vocab size in legacy config")


@serializable_dataclass(kw_only=True)
class ConfigHolder(SerializableDataclass):
@@ -405,8 +403,8 @@ class ConfigHolder(SerializableDataclass):
pretrainedtokenizer_kwargs: dict[str, JSONitem] | None = serializable_field(
default=None
)
maze_tokenizer: MazeTokenizer|None = serializable_field(
default_factory=lambda : None,
maze_tokenizer: MazeTokenizer | None = serializable_field(
default_factory=lambda: None,
loading_fn=_load_maze_tokenizer,
)

@@ -419,7 +417,7 @@ def __post_init__(self):
if self.pretrainedtokenizer_kwargs is None:
if self.maze_tokenizer is None:
self.maze_tokenizer = MazeTokenizer(
tokenization_mode=TokenizationMode.AOTP_UT_uniform,
tokenization_mode=TokenizationMode.AOTP_UT_uniform,
max_grid_size=None,
)

@@ -436,7 +434,9 @@ def summary(self) -> str:
"model_cfg": self.model_cfg.summary(),
"train_cfg": self.train_cfg.summary(),
"pretrainedtokenizer_kwargs": self.pretrainedtokenizer_kwargs,
"maze_tokenizer": self.maze_tokenizer.summary() if self.maze_tokenizer is not None else None,
"maze_tokenizer": self.maze_tokenizer.summary()
if self.maze_tokenizer is not None
else None,
}

@property
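For context on the `maze_tokenizer` default in `__post_init__` above: when neither `pretrainedtokenizer_kwargs` nor an explicit `maze_tokenizer` is supplied, a uniform-tokenization `MazeTokenizer` is constructed. A minimal sketch, mirroring the `test_cfg_post_init` unit test further down in this diff; the argument values are illustrative.

# Illustrative sketch -- not part of this commit; mirrors test_cfg_post_init below.
from maze_dataset import MazeDatasetConfig
from maze_dataset.tokenization import MazeTokenizer
from maze_transformer.training.config import BaseGPTConfig, ConfigHolder, TrainConfig

cfg = ConfigHolder(
    train_cfg=TrainConfig(name="example-train"),
    dataset_cfg=MazeDatasetConfig(name="example-data", grid_n=5, n_mazes=10),
    model_cfg=BaseGPTConfig(
        name="example-model",
        act_fn="dummy-act-fn",
        d_model=500,
        d_head=60,
        n_layers=4,
    ),
)

# __post_init__ fills in a default MazeTokenizer when none is supplied
assert isinstance(cfg.maze_tokenizer, MazeTokenizer)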
4 changes: 2 additions & 2 deletions maze_transformer/training/training.py
@@ -4,8 +4,8 @@

import torch
from jaxtyping import Float
from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from maze_dataset import MazeDataset, SolvedMaze
from maze_dataset.tokenization import MazeTokenizer
from muutils.statcounter import StatCounter
from torch.utils.data import DataLoader
from transformer_lens.HookedTransformer import SingleLoss
77 changes: 60 additions & 17 deletions tests/unit/maze_transformer/test_tokenizers.py
@@ -4,33 +4,40 @@
We may want a separate set of tests for different tokenization schemes
"""
import sys
from itertools import product

import torch
from maze_dataset import MazeDatasetConfig, SolvedMaze
from maze_dataset.generation import get_maze_with_solution
import pytest
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from pytest import mark, param
from transformer_lens import HookedTransformer

from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from maze_transformer.training.config import BaseGPTConfig, ConfigHolder


@mark.parametrize(
"tok_mode,grid_size,grid_size_max",
"tok_mode,grid_size,grid_size_max",
[
param(tok_mode, grid_size, grid_size_max, id=f"{tok_mode.name.split('_')[-1]},g{grid_size},m{grid_size_max}")
param(
tok_mode,
grid_size,
grid_size_max,
id=f"{tok_mode.name.split('_')[-1]},g{grid_size},m{grid_size_max}",
)
for tok_mode, grid_size, grid_size_max in product(
TokenizationMode, [3, 4], [3, 4, 5, 6, 10, 50]
)
],
)
def test_tokenization_encoding(tok_mode: TokenizationMode, grid_size: int, grid_size_max: int):
def test_tokenization_encoding(
tok_mode: TokenizationMode, grid_size: int, grid_size_max: int
):
# create maze and tokenizer
solved_maze: SolvedMaze = get_maze_with_solution("gen_dfs", (3, 3))
tok: MazeTokenizer = MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=grid_size)
tok: MazeTokenizer = MazeTokenizer(
tokenization_mode=tok_mode, max_grid_size=grid_size
)

# convert to strings
maze_str_tokens: list[str] = solved_maze.as_tokens(tok)
@@ -66,12 +73,17 @@ def test_tokenization_encoding(tok_mode: TokenizationMode, grid_size: int, grid_
"tok_mode",
[
param(tok_mode, id=tok_mode.name)
for tok_mode in [TokenizationMode.AOTP_UT_uniform, TokenizationMode.AOTP_UT_rasterized]
for tok_mode in [
TokenizationMode.AOTP_UT_uniform,
TokenizationMode.AOTP_UT_rasterized,
]
],
)
def test_to_ascii(tok_mode):
# Check that the ascii encoding works for multiple different inputs
maze_str_tokens: list[str] = """<ADJLIST_START> (1,1) <--> (2,1) ; (2,0) <--> (1,0) ; (0,1) <--> (0,0) ;
maze_str_tokens: list[
str
] = """<ADJLIST_START> (1,1) <--> (2,1) ; (2,0) <--> (1,0) ; (0,1) <--> (0,0) ;
(2,2) <--> (2,1) ; (2,0) <--> (2,1) ; (0,2) <--> (1,2) ; (0,0) <--> (1,0) ; (0,2) <--> (0,1) ;
<ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (0,0) (1,0) (2,0) (2,1) <PATH_END>""".split()

@@ -106,7 +118,7 @@ def test_to_ascii(tok_mode):


@mark.parametrize(
"tok_mode",
"tok_mode",
[
param(TokenizationMode.AOTP_UT_uniform, id="AOTP_UT_uniform"),
param(TokenizationMode.AOTP_UT_rasterized, id="AOTP_UT_rasterized"),
@@ -181,15 +193,46 @@ def test_tokenizer_inside_hooked_transformer(tok_mode):
# Padding Tests
PAD_PLACEHOLDER = -1


@mark.parametrize(
"inp,expected,tok_mode",
"inp,expected,tok_mode",
[
param([1, 2, 3], [PAD_PLACEHOLDER, PAD_PLACEHOLDER, 1, 2, 3], TokenizationMode.AOTP_UT_uniform, id="short+uniform"),
param([1, 2, 3, 4, 5], [1, 2, 3, 4, 5], TokenizationMode.AOTP_UT_uniform, id="max_length+uniform"),
param([1, 2, 3, 4, 5, 6], [2, 3, 4, 5, 6], TokenizationMode.AOTP_UT_uniform, id="too_long+uniform"),
param([1, 2, 3], [PAD_PLACEHOLDER, PAD_PLACEHOLDER, 1, 2, 3], TokenizationMode.AOTP_UT_rasterized, id="short+rasterized"),
param([1, 2, 3, 4, 5], [1, 2, 3, 4, 5], TokenizationMode.AOTP_UT_rasterized, id="max_length+rasterized"),
param([1, 2, 3, 4, 5, 6], [2, 3, 4, 5, 6], TokenizationMode.AOTP_UT_rasterized, id="too_long+rasterized"),
param(
[1, 2, 3],
[PAD_PLACEHOLDER, PAD_PLACEHOLDER, 1, 2, 3],
TokenizationMode.AOTP_UT_uniform,
id="short+uniform",
),
param(
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
TokenizationMode.AOTP_UT_uniform,
id="max_length+uniform",
),
param(
[1, 2, 3, 4, 5, 6],
[2, 3, 4, 5, 6],
TokenizationMode.AOTP_UT_uniform,
id="too_long+uniform",
),
param(
[1, 2, 3],
[PAD_PLACEHOLDER, PAD_PLACEHOLDER, 1, 2, 3],
TokenizationMode.AOTP_UT_rasterized,
id="short+rasterized",
),
param(
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
TokenizationMode.AOTP_UT_rasterized,
id="max_length+rasterized",
),
param(
[1, 2, 3, 4, 5, 6],
[2, 3, 4, 5, 6],
TokenizationMode.AOTP_UT_rasterized,
id="too_long+rasterized",
),
],
)
def test_pad_sequence_param(inp, expected, tok_mode):
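For reference, the padding cases above pin down a left-pad / left-truncate rule with sequence length 5 and pad value `PAD_PLACEHOLDER = -1`. A minimal standalone sketch of that rule, not the tokenizer's actual implementation:

# Illustrative sketch of the semantics the parametrized cases encode.
def pad_sequence_left(seq: list[int], length: int, pad_value: int) -> list[int]:
    # over-long inputs keep only their last `length` elements
    if len(seq) >= length:
        return seq[-length:]
    # short inputs are left-padded up to `length`
    return [pad_value] * (length - len(seq)) + seq

assert pad_sequence_left([1, 2, 3], 5, -1) == [-1, -1, 1, 2, 3]
assert pad_sequence_left([1, 2, 3, 4, 5], 5, -1) == [1, 2, 3, 4, 5]
assert pad_sequence_left([1, 2, 3, 4, 5, 6], 5, -1) == [2, 3, 4, 5, 6]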
38 changes: 17 additions & 21 deletions tests/unit/maze_transformer/training/config/test_cfg_post_init.py
@@ -1,27 +1,23 @@
import json
from pathlib import Path

from maze_dataset import MazeDatasetConfig
from zanj import ZANJ

from maze_dataset.tokenization import MazeTokenizer

from maze_transformer.training.config import BaseGPTConfig, ConfigHolder, TrainConfig

def test_cfg_post_init():

cfg: ConfigHolder = ConfigHolder(
train_cfg=TrainConfig(name="test_cfg_save-train"),
dataset_cfg=MazeDatasetConfig(name="test_cfg_save-data", grid_n=5, n_mazes=10),
model_cfg=BaseGPTConfig(
name="test_cfg_save-model",
act_fn="dummy-act-fn",
d_model=500,
d_head=60,
n_layers=4,
),
)
def test_cfg_post_init():
cfg: ConfigHolder = ConfigHolder(
train_cfg=TrainConfig(name="test_cfg_save-train"),
dataset_cfg=MazeDatasetConfig(name="test_cfg_save-data", grid_n=5, n_mazes=10),
model_cfg=BaseGPTConfig(
name="test_cfg_save-model",
act_fn="dummy-act-fn",
d_model=500,
d_head=60,
n_layers=4,
),
)

assert isinstance(cfg.maze_tokenizer, MazeTokenizer)
assert isinstance(cfg.maze_tokenizer.max_grid_size, int)
assert cfg.maze_tokenizer.max_grid_size == 5
assert isinstance(cfg.maze_tokenizer.vocab_size, int)
assert isinstance(cfg.maze_tokenizer, MazeTokenizer)
assert isinstance(cfg.maze_tokenizer.max_grid_size, int)
assert cfg.maze_tokenizer.max_grid_size == 5
assert isinstance(cfg.maze_tokenizer.vocab_size, int)
1 change: 0 additions & 1 deletion tests/unit/maze_transformer/training/test_dataset.py
@@ -7,7 +7,6 @@
TEMP_DIR.mkdir(parents=True, exist_ok=True)



class TestGPTDataset:
class TestFromConfig:
cfg = MazeDatasetConfig(name="test", grid_n=3, n_mazes=1)
5 changes: 3 additions & 2 deletions tests/unit/maze_transformer/training/test_get_dataloader.py
@@ -1,12 +1,12 @@
import pytest

from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode

from maze_transformer.test_helpers.stub_logger import StubLogger
from maze_transformer.training.config import GPT_CONFIGS, TRAINING_CONFIGS, ConfigHolder
from maze_transformer.training.training import get_dataloader


@pytest.mark.parametrize(
"tok_mode",
[
@@ -32,7 +32,8 @@ def test_get_dataloader(tok_mode: TokenizationMode):

other_batch1 = next(iter(dataloader))
dataloader_mazes = [
SolvedMaze.from_tokens(tokens, config_holder.maze_tokenizer) for tokens in batch1
SolvedMaze.from_tokens(tokens, config_holder.maze_tokenizer)
for tokens in batch1
]

assert len(batch1) == 5