Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Aug 6, 2023
1 parent 144daa8 commit 16bc51e
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
4 changes: 3 additions & 1 deletion maze_transformer/evaluation/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ def evaluate_model(
}

if dataset_tokens is None:
dataset_tokens = dataset.as_tokens(model.zanj_model_config.maze_tokenizer, join_tokens_individual_maze=False)
dataset_tokens = dataset.as_tokens(
model.zanj_model_config.maze_tokenizer, join_tokens_individual_maze=False
)
else:
assert len(dataset) == len(
dataset_tokens
Expand Down
2 changes: 1 addition & 1 deletion maze_transformer/evaluation/plot_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from circuitsvis.tokens import colored_tokens_multi
from jaxtyping import Float
from maze_dataset import CoordTup, MazeDataset, MazeDatasetConfig, SolvedMaze
from maze_dataset.tokenization.token_utils import coord_str_to_tuple_noneable
from maze_dataset.plotting import MazePlot
from maze_dataset.tokenization.token_utils import coord_str_to_tuple_noneable

# Utilities
from muutils.json_serialize import SerializableDataclass, serializable_dataclass
Expand Down
3 changes: 1 addition & 2 deletions maze_transformer/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def train(
evals_enabled = False

val_dataset_tokens: list[list[str]] = val_dataset.as_tokens(
model.zanj_model_config.maze_tokenizer,
join_tokens_individual_maze=False
model.zanj_model_config.maze_tokenizer, join_tokens_individual_maze=False
)

# compute intervals
Expand Down

0 comments on commit 16bc51e

Please sign in to comment.