run formatters
mivanit committed Mar 5, 2024
1 parent b1abc37 commit 1a961bb
Showing 8 changed files with 41 additions and 35 deletions.
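The commit reformats long annotated assignments: the previous style wrapped the right-hand side of the assignment in parentheses, while the new style splits the annotation's subscript across lines instead, and the `...` bodies of protocol methods move onto their own line. A minimal, self-contained sketch of the two equivalent styles (hypothetical names, not taken from the repository):

def compute_value(x: int) -> dict[str, int]:
    return {"value": x}

# style removed in this commit: right-hand side hugged by parentheses
result_old: dict[str, int] = (
    compute_value(
        x=1,
    )
)

# style added in this commit: the annotation subscript is split instead
result_new: dict[
    str, int
] = compute_value(
    x=1,
)

assert result_old == result_new  # only the line wrapping differs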
maze_transformer/evaluation/path_evals.py (3 changes: 2 additions & 1 deletion)

@@ -23,7 +23,8 @@ def __call__(
         maze: LatticeMaze | None = None,
         solution: CoordArray | None = None,
         prediction: CoordArray | None = None,
-    ) -> float: ...
+    ) -> float:
+        ...


 def path_as_segments_iter(path: CoordArray) -> typing.Iterable[tuple]:
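The trailing context line names path_as_segments_iter. As a purely hypothetical sketch of what iterating a coordinate path as consecutive segments could look like (not the repository's implementation):

import typing

import numpy as np

def path_as_segments_iter_sketch(path: np.ndarray) -> typing.Iterable[tuple]:
    # yield consecutive coordinate pairs: (p[0], p[1]), (p[1], p[2]), ...
    for start, end in zip(path[:-1], path[1:]):
        yield (tuple(start), tuple(end))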
maze_transformer/mechinterp/direct_logit_attribution.py (14 changes: 7 additions & 7 deletions)

@@ -136,13 +136,13 @@ def plot_direct_logit_attribution(
     by default, its the identity map for contribs:
     `layer_index_normalization: typing.Callable[[float, int], float]|None = lambda contrib, layer_idx: contrib`
     """
-    dla_data: dict[str, Float[np.ndarray, "layer head/neuron"]] = (
-        compute_direct_logit_attribution(
-            model=model,
-            cache=cache,
-            answer_tokens=answer_tokens,
-            do_neurons=do_neurons,
-        )
+    dla_data: dict[
+        str, Float[np.ndarray, "layer head/neuron"]
+    ] = compute_direct_logit_attribution(
+        model=model,
+        cache=cache,
+        answer_tokens=answer_tokens,
+        do_neurons=do_neurons,
     )
     if layer_index_normalization is not None:
         dla_data = {
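Per the docstring shown above, layer_index_normalization maps a (contribution, layer index) pair to a normalized contribution and defaults to the identity. A short sketch of a custom normalization one might pass instead (illustrative only, not from the repository):

def downweight_later_layers(contrib: float, layer_idx: int) -> float:
    # example normalization: scale each contribution down by its layer depth
    return contrib / (layer_idx + 1)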
maze_transformer/mechinterp/logit_attrib_task.py (6 changes: 4 additions & 2 deletions)

@@ -22,7 +22,8 @@ def get_token_first_index(search_token: str, token_list: list[str]) -> int:
 class DLAProtocol(typing.Protocol):
     """should take a dataset's tokens, and return a tuple of (prompts, targets)"""

-    def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup: ...
+    def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup:
+        ...


 class DLAProtocolFixed(typing.Protocol):
@@ -31,7 +32,8 @@ class DLAProtocolFixed(typing.Protocol):
     this variant signifies it's ready to be used -- no keyword arguments are needed
     """

-    def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup: ...
+    def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup:
+        ...


 def token_after_fixed_start_token(
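The two protocols differ only in whether keyword arguments are still required: DLAProtocol takes a dataset's tokens plus task-specific kwargs and returns (prompts, targets), while DLAProtocolFixed is the ready-to-use variant. A sketch of how a task function could satisfy both, assuming TaskSetup is a (prompts, targets) tuple (the repository's actual alias may differ):

import functools

TaskSetup = tuple[list[list[str]], list[str]]  # assumed shape: (prompts, targets)

def token_after_token_task(
    dataset_tokens: list[list[str]], *, start_token: str
) -> TaskSetup:
    # hypothetical task: predict the token immediately after `start_token`
    prompts: list[list[str]] = []
    targets: list[str] = []
    for tokens in dataset_tokens:
        idx = tokens.index(start_token)
        prompts.append(tokens[: idx + 1])
        targets.append(tokens[idx + 1])
    return prompts, targets

# binding the keyword argument yields a callable matching DLAProtocolFixed:
fixed_task = functools.partial(token_after_token_task, start_token="<PATH_START>")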
maze_transformer/mechinterp/logit_diff.py (18 changes: 9 additions & 9 deletions)

@@ -86,9 +86,9 @@ def logit_diff_residual_stream(
     vocab_tensor: Float[torch.Tensor, "d_vocab"] = torch.arange(
         d_vocab, dtype=torch.long
     )
-    vocab_residual_directions: Float[torch.Tensor, "d_vocab d_model"] = (
-        model.tokens_to_residual_directions(vocab_tensor)
-    )
+    vocab_residual_directions: Float[
+        torch.Tensor, "d_vocab d_model"
+    ] = model.tokens_to_residual_directions(vocab_tensor)
     # get embedding of answer tokens
     answer_residual_directions = vocab_residual_directions[tokens_correct]
     # get the directional difference between logits and corrent and logits on {all other tokens, comparison tokens}
@@ -108,12 +108,12 @@
     ][:, -1, :]

     # scaling the values in residual stream with layer norm
-    scaled_final_token_residual_stream: Float[torch.Tensor, "samples d_model"] = (
-        cache.apply_ln_to_stack(
-            final_token_residual_stream,
-            layer=-1,
-            pos_slice=-1,
-        )
+    scaled_final_token_residual_stream: Float[
+        torch.Tensor, "samples d_model"
+    ] = cache.apply_ln_to_stack(
+        final_token_residual_stream,
+        layer=-1,
+        pos_slice=-1,
     )

     # measure similarity between the logit diff directions and the residual stream at final layer directions
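The hunks above come from the standard logit-difference-in-the-residual-stream computation: build residual-stream directions for the answer tokens, take the final-position residual stream, scale it with the final layer norm via apply_ln_to_stack, and project. A hedged sketch of that pattern using the TransformerLens API on a stand-in model (token ids and variable names are illustrative, not the repository's code):

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # stand-in model
tokens = model.to_tokens("1 2 3 4")
logits, cache = model.run_with_cache(tokens)

# residual-stream directions whose dot product with the stream gives each token's logit
correct_dir = model.tokens_to_residual_directions(torch.tensor([352]))
compare_dir = model.tokens_to_residual_directions(torch.tensor([642]))
logit_diff_direction = (correct_dir - compare_dir).squeeze()

# residual stream at the final position, scaled by the final layer norm
final_token_residual_stream = cache["resid_post", -1][:, -1, :]
scaled_resid = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

# projecting onto the direction approximates the logit difference
approx_logit_diff = scaled_resid @ logit_diff_direction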
maze_transformer/mechinterp/residual_stream_structure.py (6 changes: 3 additions & 3 deletions)

@@ -251,9 +251,9 @@ def compute_distances_and_correlation(
     # embedding_distances /= embedding_distances.max()

     # Convert the distances to a square matrix
-    embedding_distances_matrix: Float[np.ndarray, "n_coord_tokens n_coord_tokens"] = (
-        squareform(embedding_distances)
-    )
+    embedding_distances_matrix: Float[
+        np.ndarray, "n_coord_tokens n_coord_tokens"
+    ] = squareform(embedding_distances)

     # Calculate the correlation between the embedding and coordinate distances
     coordinate_coordinates: Float[np.ndarray, "n_coord_tokens 2"] = np.array(
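The surrounding code converts a condensed distance vector into a square matrix with scipy's squareform and then correlates embedding distances with coordinate distances. A small sketch of that pattern on toy data (illustrative, not the repository's code):

import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr

coords = np.array([[0, 0], [0, 1], [1, 0], [2, 2]], dtype=float)
embeddings = np.random.default_rng(0).normal(size=(4, 16))

coord_distances = pdist(coords)          # condensed vector, length n*(n-1)/2
embedding_distances = pdist(embeddings)  # same condensed layout

embedding_distances_matrix = squareform(embedding_distances)  # (4, 4) square matrix

# correlation between lattice distance and embedding distance
corr, p_value = pearsonr(coord_distances, embedding_distances)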
maze_transformer/training/config.py (13 changes: 7 additions & 6 deletions)

@@ -214,9 +214,7 @@ def get_intervals(
             )

         except ValueError as e:
-            _debug_vals: str = (
-                f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
-            )
+            _debug_vals: str = f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
             raise ValueError(f"{_debug_vals}\ntriggered error:\n{e}") from e

         # disable if set to 0 or negative
@@ -660,9 +658,12 @@ def _load_state_dict_wrapper(
             self.zanj_model_config.model_cfg.weight_processing["are_layernorms_folded"]
             or fold_ln
         )
-        self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"] = (
-            self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"]
-            or (not recover_exact)
+        self.zanj_model_config.model_cfg.weight_processing[
+            "are_weights_processed"
+        ] = self.zanj_model_config.model_cfg.weight_processing[
+            "are_weights_processed"
+        ] or (
+            not recover_exact
         )

         self.load_and_process_state_dict(
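The second hunk reformats a sticky-flag update on the weight_processing bookkeeping: once a flag is True it stays True, and loading without exact recovery marks the weights as processed. A minimal sketch of the pattern (dictionary and function names are illustrative, not the repository's code):

weight_processing: dict[str, bool] = {
    "are_layernorms_folded": False,
    "are_weights_processed": False,
}

def record_load(fold_ln: bool, recover_exact: bool) -> None:
    # `or` keeps a flag True once it has been set
    weight_processing["are_layernorms_folded"] = (
        weight_processing["are_layernorms_folded"] or fold_ln
    )
    weight_processing["are_weights_processed"] = (
        weight_processing["are_weights_processed"] or (not recover_exact)
    )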
maze_transformer/training/train_save_files.py (10 changes: 6 additions & 4 deletions)

@@ -20,10 +20,12 @@ class TRAIN_SAVE_FILES:
     config_holder: str = "config.json"
     checkpoints: str = "checkpoints"
     log: str = "log.jsonl"
-    model_checkpt_zanj: Callable[[int], str] = (
-        lambda iteration: f"model.iter_{iteration}.zanj"
-    )
+    model_checkpt_zanj: Callable[
+        [int], str
+    ] = lambda iteration: f"model.iter_{iteration}.zanj"
     model_final_zanj: str = "model.final.zanj"
-    model_run_dir: Callable[[ConfigHolder], str] = (
+    model_run_dir: Callable[
+        [ConfigHolder], str
+    ] = (
         lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
     )
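TRAIN_SAVE_FILES acts as a namespace of filename conventions, so the callables above can be invoked through the class. A short usage sketch (the output strings follow directly from the lambdas shown):

from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES

print(TRAIN_SAVE_FILES.model_checkpt_zanj(1000))  # "model.iter_1000.zanj"
print(TRAIN_SAVE_FILES.model_final_zanj)          # "model.final.zanj"
# model_run_dir(cfg) combines the sanitized config name with a timestamp,
# e.g. something like "my_run_2024-03-05-12-00-00"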
tests/unit/maze_transformer/test_tokenizers.py (6 changes: 3 additions & 3 deletions)

@@ -83,11 +83,11 @@ def test_tokenization_encoding(
 )
 def test_to_ascii(tok_mode):
     # Check that the ascii encoding works for multiple different inputs
-    maze_str_tokens: list[str] = (
-        """<ADJLIST_START> (1,1) <--> (2,1) ; (2,0) <--> (1,0) ; (0,1) <--> (0,0) ;
+    maze_str_tokens: list[
+        str
+    ] = """<ADJLIST_START> (1,1) <--> (2,1) ; (2,0) <--> (1,0) ; (0,1) <--> (0,0) ;
 (2,2) <--> (2,1) ; (2,0) <--> (2,1) ; (0,2) <--> (1,2) ; (0,0) <--> (1,0) ; (0,2) <--> (0,1) ;
 <ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (0,0) (1,0) (2,0) (2,1) <PATH_END>""".split()
-    )

     target: list[str] = [
         "#######",
