fix tokenizer to work with new version of transformers library (#208)
various fixes to make tokenizers work with the latest versions of HF `transformers` and `transformer_lens`

# commit history
* Try and fix tokenizer to work with new version of transformers library

The proposed solution is probably not backwards compatible, and is fairly hacky (it strips spaces, and I am not sure it properly assigns vocab / special tokens):

There is an issue with our tokenization in the new version of `transformers`. In particular, in the `tokenize` function from `transformers.tokenization_utils.py`, the line `tokens = self.tokens_trie.split(text)` returns a list of tokens that includes spaces when the input sequence is `<path_start> (1,0)…` (i.e. contains spaces). This wasn't the case before, and I suspect it stems from having to change how the vocabulary is added in our tokenizer (to work with the new way of handling token addition via the `_add_tokens` method; we can't just overwrite the dicts, as these are now properties). As a temporary fix we can manually remove spaces from sequences, but that's quite an ugly hack (see the sketch below).

The best option might be to create tokenizer JSONs and push a tokenizer to Hugging Face.
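A minimal sketch of the behaviour described above and of the stop-gap fix, using the `Trie` from `transformers.tokenization_utils` (the token strings are illustrative, matching the example above; this is not the exact code that landed in this PR):

```python
from transformers.tokenization_utils import Trie

# Build a trie over a few maze tokens, roughly what `_add_tokens` maintains internally.
trie = Trie()
for tok in ["<path_start>", "(1,0)", "(0,0)"]:
    trie.add(tok)

# With recent `transformers`, splitting a space-separated prompt keeps the
# separators, so downstream code sees " " entries it did not see before:
print(trie.split("<path_start> (1,0) (0,0)"))
# -> ['<path_start>', ' ', '(1,0)', ' ', '(0,0)']

# Temporary workaround: drop empty / whitespace-only tokens before id lookup.
tokens = [t for t in trie.split("<path_start> (1,0) (0,0)") if t.strip()]
print(tokens)
# -> ['<path_start>', '(1,0)', '(0,0)']
```

The fix that actually landed handles this inside `HuggingMazeTokenizer._tokenize` and `_convert_token_to_id` (see the `maze_transformer/tokenizer.py` diff below), with an extra assert in `_process_context` guarding against stray empty or space tokens.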

* Updated poetry dependencies. `poetry.lock` now has `transformers 4.38.1`, `transformer-lens 1.14.0` among many other updates.

* Added `self.init_kwargs["add_bos_token"] = True` as an uninformed band-aid. Need to discuss if this makes any sense.

* Tiny fix to `HuggingMazeTokenizer._tokenize` as described in the GitHub comment above. One unit test eliminated; other unit tests and notebook tests pass. A few notebooks are dumping their outputs directly to `notebooks/` instead of a temp directory; didn't delete those outputs, so they can serve as a reference for a future fix.

* Unit tests pass; my CPU won't let me run `make test` right now.

* All tests pass

* Updated `black` dependency to match CI version. Reran formatting.

* run formatters

* minor type hint fix

* our special tokens aren't what HF special tokens are

* re-run format??

* improved `test_maze_to_tokens_roundtrip`, added comparison with manually inspected tokenization

* throw exception on an empty space token

* moved tokenizer test to maze-dataset

* format

---------

Co-authored-by: aaron-sandoval <[email protected]>
Co-authored-by: mivanit <[email protected]>
3 people committed Mar 5, 2024
1 parent b3417f9 commit 495d8d3
Showing 16 changed files with 3,784 additions and 1,436 deletions.
3 changes: 3 additions & 0 deletions maze_transformer/evaluation/baseline_models.py
@@ -215,6 +215,9 @@ def _process_context(
)
elif isinstance(context, str):
tokens = self.tokenizer.tokenize(context)
assert (
"" not in tokens and " " not in tokens
), "Tokenizer error, split `context` includes bad token strings."
else:
raise TypeError(f"Expected list[str], str, or tensor, got {type(context)}")

3 changes: 1 addition & 2 deletions maze_transformer/evaluation/path_evals.py
@@ -23,8 +23,7 @@ def __call__(
maze: LatticeMaze | None = None,
solution: CoordArray | None = None,
prediction: CoordArray | None = None,
) -> float:
...
) -> float: ...


def path_as_segments_iter(path: CoordArray) -> typing.Iterable[tuple]:
19 changes: 10 additions & 9 deletions maze_transformer/mechinterp/direct_logit_attribution.py
@@ -126,22 +126,23 @@ def plot_direct_logit_attribution(
answer_tokens: Int[torch.Tensor, "n_mazes"],
do_neurons: bool = False,
show: bool = True,
layer_index_normalization: typing.Callable[[float, int], float]
| None = lambda contrib, layer_idx: contrib,
layer_index_normalization: (
typing.Callable[[float, int], float] | None
) = lambda contrib, layer_idx: contrib,
) -> tuple[plt.Figure, plt.Axes, dict[str, Float[np.ndarray, "layer head/neuron"]]]:
"""compute, process, and plot direct logit attribution
Layer index normalization allows us to process the contribution according to the layer index.
by default, its the identity map for contribs:
`layer_index_normalization: typing.Callable[[float, int], float]|None = lambda contrib, layer_idx: contrib`
"""
dla_data: dict[
str, Float[np.ndarray, "layer head/neuron"]
] = compute_direct_logit_attribution(
model=model,
cache=cache,
answer_tokens=answer_tokens,
do_neurons=do_neurons,
dla_data: dict[str, Float[np.ndarray, "layer head/neuron"]] = (
compute_direct_logit_attribution(
model=model,
cache=cache,
answer_tokens=answer_tokens,
do_neurons=do_neurons,
)
)
if layer_index_normalization is not None:
dla_data = {
6 changes: 2 additions & 4 deletions maze_transformer/mechinterp/logit_attrib_task.py
@@ -22,8 +22,7 @@ def get_token_first_index(search_token: str, token_list: list[str]) -> int:
class DLAProtocol(typing.Protocol):
"""should take a dataset's tokens, and return a tuple of (prompts, targets)"""

def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup:
...
def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup: ...


class DLAProtocolFixed(typing.Protocol):
@@ -32,8 +31,7 @@ class DLAProtocolFixed(typing.Protocol):
this variant signifies it's ready to be used -- no keyword arguments are needed
"""

def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup:
...
def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup: ...


def token_after_fixed_start_token(
18 changes: 9 additions & 9 deletions maze_transformer/mechinterp/logit_diff.py
@@ -86,9 +86,9 @@ def logit_diff_residual_stream(
vocab_tensor: Float[torch.Tensor, "d_vocab"] = torch.arange(
d_vocab, dtype=torch.long
)
vocab_residual_directions: Float[
torch.Tensor, "d_vocab d_model"
] = model.tokens_to_residual_directions(vocab_tensor)
vocab_residual_directions: Float[torch.Tensor, "d_vocab d_model"] = (
model.tokens_to_residual_directions(vocab_tensor)
)
# get embedding of answer tokens
answer_residual_directions = vocab_residual_directions[tokens_correct]
# get the directional difference between logits and corrent and logits on {all other tokens, comparison tokens}
@@ -108,12 +108,12 @@
][:, -1, :]

# scaling the values in residual stream with layer norm
scaled_final_token_residual_stream: Float[
torch.Tensor, "samples d_model"
] = cache.apply_ln_to_stack(
final_token_residual_stream,
layer=-1,
pos_slice=-1,
scaled_final_token_residual_stream: Float[torch.Tensor, "samples d_model"] = (
cache.apply_ln_to_stack(
final_token_residual_stream,
layer=-1,
pos_slice=-1,
)
)

# measure similarity between the logit diff directions and the residual stream at final layer directions
6 changes: 3 additions & 3 deletions maze_transformer/mechinterp/plot_attention.py
@@ -289,9 +289,9 @@ def mazeplot_attention(
node_values=node_values,
color_map=cmap,
target_token_coord=target_coord,
preceeding_tokens_coords=[final_prompt_coord]
if final_prompt_coord is not None
else None,
preceeding_tokens_coords=(
[final_prompt_coord] if final_prompt_coord is not None else None
),
colormap_center=colormap_center_val,
colormap_max=colormap_max,
hide_colorbar=hide_colorbar,
14 changes: 8 additions & 6 deletions maze_transformer/mechinterp/residual_stream_structure.py
@@ -68,9 +68,11 @@ def process_tokens_for_pca(tokenizer: MazeTokenizer) -> list[TokenPlottingInfo]:
tokenizer.token_arr,
tokens_coords,
[
coordinate_to_color(coord, max_val=max_coord)
if isinstance(coord, tuple)
else (0.0, 1.0, 0.0)
(
coordinate_to_color(coord, max_val=max_coord)
if isinstance(coord, tuple)
else (0.0, 1.0, 0.0)
)
for coord in tokens_coords
],
)
@@ -249,9 +251,9 @@ def compute_distances_and_correlation(
# embedding_distances /= embedding_distances.max()

# Convert the distances to a square matrix
embedding_distances_matrix: Float[
np.ndarray, "n_coord_tokens n_coord_tokens"
] = squareform(embedding_distances)
embedding_distances_matrix: Float[np.ndarray, "n_coord_tokens n_coord_tokens"] = (
squareform(embedding_distances)
)

# Calculate the correlation between the embedding and coordinate distances
coordinate_coordinates: Float[np.ndarray, "n_coord_tokens 2"] = np.array(
32 changes: 23 additions & 9 deletions maze_transformer/tokenizer.py
@@ -24,9 +24,6 @@ class HuggingMazeTokenizer(PreTrainedTokenizer):
unk_token: str = "<UNK>"

vocab_size: int = 0
additional_special_tokens: list[str] = [
x for x in SPECIAL_TOKENS.values() if x not in [SPECIAL_TOKENS.PADDING]
]

# Overwrite class attributes
# as of https://github.com/neelnanda-io/TransformerLens/pull/344 this gets overwritten to "right" on `HookedTransformer.__init__()`
@@ -68,6 +65,9 @@ def __init__(
# utils.py:1075, in get_tokenizer_with_bos(tokenizer)
# -> 1075 pretrained_model_name_or_path = init_kwargs.pop("name_or_path")
self.init_kwargs["name_or_path"] = self.name_or_path
# utils.py:1075, in get_tokenizer_with_bos(tokenizer)
# -> 1078 add_bos_token = init_kwargs.pop("add_bos_token", None)
self.init_kwargs["add_bos_token"] = True

assert isinstance(
seq_len_max, int
@@ -84,13 +84,12 @@ def __init__(
vocab[self.unk_token] = len(vocab)
self.vocab: dict[str, int] = vocab

self.added_tokens_encoder: dict[str, int] = vocab
self.added_tokens_decoder: dict[int, str] = {
i: token for token, i in vocab.items()
}
special_tokens = list(SPECIAL_TOKENS.values())
normal_tokens = [x for x in token_arr if x not in special_tokens]
self._add_tokens(normal_tokens)
self._add_tokens(special_tokens)

self.unique_no_split_tokens = token_arr
self._create_trie(self.unique_no_split_tokens)
self.unique_no_split_tokens = token_arr # Trie is updated automatically?

# IDs specified during construction
self.bos_token_id: int = self.added_tokens_encoder[self.bos_token]
@@ -113,12 +112,22 @@ def __call__(self, text, **kwargs) -> BatchEncoding:

def _tokenize(self, text: str, **kwargs) -> list[str]:
assert len(kwargs) == 0, f"kwargs not supported: {kwargs}"
if text == " ": # In transformers ^4.34, this input is passed.
return (
[]
) # Necessary to maintain output of `PreTrainedTokenizer.tokenize` from transformers <=4.33

return text.split(" ")

def _convert_token_to_id(self, token: str) -> int:
if token in self.vocab:
return self.vocab[token]
elif (
token == " "
): # for some reason transformers trie now returns ' ' as tokens
raise ValueError(
f"Found a space token in `_convert_token_to_id`: '{token}'"
)
else:
raise ValueError(f"Token not in vocab: '{token}'")

@@ -151,3 +160,8 @@ def to_ascii(

lattice_maze = LatticeMaze.from_tokens(str_sequence, self._maze_tokenizer)
return MazePlot(lattice_maze).to_ascii()

def get_vocab(self) -> dict[str, int]:
if hasattr(self, "vocab"):
return self.vocab
return {}
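For context on the `_add_tokens` / `get_vocab` changes above, here is a self-contained toy version of the same pattern, written against `transformers` ≈4.38. It is deliberately simplified and is not the project's `HuggingMazeTokenizer` (no special tokens, padding handling, or `transformer_lens` hooks):

```python
from transformers import PreTrainedTokenizer


class ToyMazeTokenizer(PreTrainedTokenizer):
    """Toy whitespace tokenizer that registers its vocab via `_add_tokens`."""

    def __init__(self, token_arr: list[str], **kwargs):
        self._token_arr = list(token_arr)
        super().__init__(**kwargs)
        # vocab dicts are read-only properties in recent `transformers`,
        # so tokens have to go through the supported `_add_tokens` API
        # (which also keeps `tokens_trie` up to date)
        self._add_tokens(self._token_arr)

    @property
    def vocab_size(self) -> int:
        return 0  # every token lives in `added_tokens_encoder`

    def get_vocab(self) -> dict[str, int]:
        return dict(self.added_tokens_encoder)

    def _tokenize(self, text: str, **kwargs) -> list[str]:
        if text == " ":  # recent `transformers` passes the bare separators here
            return []
        return text.split(" ")

    def _convert_token_to_id(self, token: str) -> int:
        return self.added_tokens_encoder[token]

    def _convert_id_to_token(self, index: int) -> str:
        return self.added_tokens_decoder[index].content


tok = ToyMazeTokenizer(["<path_start>", "(1,0)", "(0,0)", "<path_end>"])
print(tok.tokenize("<path_start> (1,0) (0,0) <path_end>"))
# expected: ['<path_start>', '(1,0)', '(0,0)', '<path_end>'] -- no '' or ' ' entries
```

The real class additionally sets `init_kwargs["name_or_path"]` and `init_kwargs["add_bos_token"]` because `get_tokenizer_with_bos` in `transformer_lens` pops both keys from `init_kwargs`, as noted in the inline comments above.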
34 changes: 18 additions & 16 deletions maze_transformer/training/config.py
@@ -214,7 +214,9 @@ def get_intervals(
)

except ValueError as e:
_debug_vals: str = f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
_debug_vals: str = (
f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
)
raise ValueError(f"{_debug_vals}\ntriggered error:\n{e}") from e

# disable if set to 0 or negative
@@ -235,9 +237,9 @@ def get_intervals(
# actually return the intervals
if mod_batch_size:
return {
k: max(1, v // self.batch_size)
if isinstance(v, int)
else v # if float, leave it as is since its float("inf")
k: (
max(1, v // self.batch_size) if isinstance(v, int) else v
) # if float, leave it as is since its float("inf")
for k, v in intervals_new.items()
}
else:
@@ -459,9 +461,9 @@ def summary(self) -> str:
"model_cfg": self.model_cfg.summary(),
"train_cfg": self.train_cfg.summary(),
"pretrainedtokenizer_kwargs": self.pretrainedtokenizer_kwargs,
"maze_tokenizer": self.maze_tokenizer.summary()
if self.maze_tokenizer is not None
else None,
"maze_tokenizer": (
self.maze_tokenizer.summary()
if self.maze_tokenizer is not None
else None
),
}

@property
@@ -473,7 +477,9 @@ def tokenizer(self) -> PreTrainedTokenizer:
"""get a tokenizer via a pretrainedtokenizer_kwargs, or a hugging maze tokenizer"""
if self._tokenizer is None:
if self.pretrainedtokenizer_kwargs is not None:
return PreTrainedTokenizer(**self.pretrainedtokenizer_kwargs)
raise ValueError(
"Obsolete tokenizer initialization, caller should revise `ConfigHolder` initialization."
)
elif self.maze_tokenizer is not None:
return HuggingMazeTokenizer(
seq_len_max=self.dataset_cfg.seq_len_max,
@@ -486,8 +492,7 @@ def tokenizer(self) -> PreTrainedTokenizer:
)
else:
raise ValueError("no tokenizer specified")
else:
return self._tokenizer
return self._tokenizer

@cached_property
def hooked_transformer_cfg(self) -> HookedTransformerConfig:
@@ -655,12 +660,9 @@ def _load_state_dict_wrapper(
self.zanj_model_config.model_cfg.weight_processing["are_layernorms_folded"]
or fold_ln
)
self.zanj_model_config.model_cfg.weight_processing[
"are_weights_processed"
] = self.zanj_model_config.model_cfg.weight_processing[
"are_weights_processed"
] or (
not recover_exact
self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"] = (
self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"]
or (not recover_exact)
)

self.load_and_process_state_dict(
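As a side note on the `get_intervals` hunk above: interval values are interpreted per sample and, when `mod_batch_size` is set, divided down to per-batch counts, while `float("inf")` passes through unchanged as a way to disable an interval. A hedged, self-contained illustration of that arithmetic (the dict keys and batch size are made up, not the real config fields):

```python
batch_size = 32
intervals = {"print_loss": 1000, "checkpoint": float("inf")}  # illustrative values

# per-sample intervals -> per-batch intervals; floats (i.e. inf) stay as-is
intervals_mod = {
    k: (max(1, v // batch_size) if isinstance(v, int) else v)
    for k, v in intervals.items()
}
print(intervals_mod)  # {'print_loss': 31, 'checkpoint': inf}
```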
10 changes: 4 additions & 6 deletions maze_transformer/training/train_save_files.py
@@ -20,12 +20,10 @@ class TRAIN_SAVE_FILES:
config_holder: str = "config.json"
checkpoints: str = "checkpoints"
log: str = "log.jsonl"
model_checkpt_zanj: Callable[
[int], str
] = lambda iteration: f"model.iter_{iteration}.zanj"
model_checkpt_zanj: Callable[[int], str] = (
lambda iteration: f"model.iter_{iteration}.zanj"
)
model_final_zanj: str = "model.final.zanj"
model_run_dir: Callable[
[ConfigHolder], str
] = (
model_run_dir: Callable[[ConfigHolder], str] = (
lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
)