From 16bc51ec329ff364225bc8c1364aeffd4db14576 Mon Sep 17 00:00:00 2001 From: mivanit Date: Sun, 6 Aug 2023 17:24:23 -0400 Subject: [PATCH] format --- maze_transformer/evaluation/eval_model.py | 4 +++- maze_transformer/evaluation/plot_attention.py | 2 +- maze_transformer/training/training.py | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/maze_transformer/evaluation/eval_model.py b/maze_transformer/evaluation/eval_model.py index 3dd0565c..21a02e21 100644 --- a/maze_transformer/evaluation/eval_model.py +++ b/maze_transformer/evaluation/eval_model.py @@ -208,7 +208,9 @@ def evaluate_model( } if dataset_tokens is None: - dataset_tokens = dataset.as_tokens(model.zanj_model_config.maze_tokenizer, join_tokens_individual_maze=False) + dataset_tokens = dataset.as_tokens( + model.zanj_model_config.maze_tokenizer, join_tokens_individual_maze=False + ) else: assert len(dataset) == len( dataset_tokens diff --git a/maze_transformer/evaluation/plot_attention.py b/maze_transformer/evaluation/plot_attention.py index bb0f287f..1910643d 100644 --- a/maze_transformer/evaluation/plot_attention.py +++ b/maze_transformer/evaluation/plot_attention.py @@ -15,8 +15,8 @@ from circuitsvis.tokens import colored_tokens_multi from jaxtyping import Float from maze_dataset import CoordTup, MazeDataset, MazeDatasetConfig, SolvedMaze -from maze_dataset.tokenization.token_utils import coord_str_to_tuple_noneable from maze_dataset.plotting import MazePlot +from maze_dataset.tokenization.token_utils import coord_str_to_tuple_noneable # Utilities from muutils.json_serialize import SerializableDataclass, serializable_dataclass diff --git a/maze_transformer/training/training.py b/maze_transformer/training/training.py index f9c77184..63570b3f 100644 --- a/maze_transformer/training/training.py +++ b/maze_transformer/training/training.py @@ -101,8 +101,7 @@ def train( evals_enabled = False val_dataset_tokens: list[list[str]] = val_dataset.as_tokens( - model.zanj_model_config.maze_tokenizer, - join_tokens_individual_maze=False + model.zanj_model_config.maze_tokenizer, join_tokens_individual_maze=False ) # compute intervals