Skip to content

Commit

Permalink
moved plot logits out
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Sep 5, 2023
1 parent 61f6f95 commit 729dc1f
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 107 deletions.
99 changes: 99 additions & 0 deletions maze_transformer/mechinterp/plot_logits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Generic
import typing

# Numerical Computing
import numpy as np
import torch
from jaxtyping import Float, Int, Bool
import matplotlib.pyplot as plt

from muutils.misc import shorten_numerical_to_str

# Our Code
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode

_DEFAULT_SUBPLOTS_KWARGS: dict = dict(
figsize=(20, 20),
height_ratios=[3, 1],
)

def plot_logits(
last_tok_logits: Float[torch.Tensor, "n_mazes d_vocab"],
target_idxs: Int[torch.Tensor, "n_mazes"],
tokenizer: MazeTokenizer,
n_bins: int = 50,
mark_incorrect: bool = True,
mark_correct: bool = False,
subplots_kwargs: dict|None = None,
show: bool = True,
) -> None:
# set up figure
# --------------------------------------------------
n_mazes: int; d_vocab: int
n_mazes, d_vocab = last_tok_logits.shape
if subplots_kwargs is None:
subplots_kwargs = _DEFAULT_SUBPLOTS_KWARGS

fig, (ax_all, ax_sum) = plt.subplots(2, 1, **{**_DEFAULT_SUBPLOTS_KWARGS, **subplots_kwargs})

# fig.subplots_adjust(hspace=0.5, bottom=0.1, top=0.9, left=0.1, right=0.9)

# plot heatmap of logits
# --------------------------------------------------
# all vocab elements
ax_all.set_xlabel("vocab element logit")
ax_all.set_ylabel("maze index")
# add vocab as xticks
ax_all.set_xticks(ticks=np.arange(d_vocab), labels=tokenizer.token_arr, rotation=90)
ax_all.imshow(last_tok_logits.numpy(), aspect="auto")
# set colorbar
plt.colorbar(ax_all.imshow(last_tok_logits.numpy(), aspect="auto"), ax=ax_all)

if mark_correct:
# place yellow x at max logit token
ax_all.scatter(last_tok_logits.argmax(dim=1), np.arange(n_mazes), marker="x", color="yellow")
# place red dot at correct token
ax_all.scatter(target_idxs, np.arange(n_mazes), marker=".", color="red")
if mark_incorrect:
raise ValueError("mark_correct and mark_incorrect cannot both be True")

if mark_incorrect:
# place a red dot wherever the max logit token is not the correct token
ax_all.scatter(
last_tok_logits.argmax(dim=1)[last_tok_logits.argmax(dim=1) != target_idxs],
np.arange(n_mazes)[last_tok_logits.argmax(dim=1) != target_idxs],
marker=".",
color="red",
)

# histogram of logits for correct and incorrect tokens
# --------------------------------------------------
ax_sum.set_ylabel("probability density")
ax_sum.set_xlabel("logit value")

# get correct token logits
correct_token_logits: Float[torch.Tensor, "n_mazes"] = torch.gather(last_tok_logits, 1, target_idxs.unsqueeze(1)).squeeze(1)
mask = torch.ones(n_mazes, d_vocab, dtype=torch.bool)
mask.scatter_(1, target_idxs.unsqueeze(1), False)
other_token_logits: Float[torch.Tensor, "n_mazes d_vocab-1"] = last_tok_logits[mask].reshape(n_mazes, d_vocab - 1)

# plot histogram
bins: Float[np.ndarray, "n_bins"] = np.linspace(last_tok_logits.min(), last_tok_logits.max(), n_bins)
ax_sum.hist(
correct_token_logits.numpy(),
density=True,
bins=bins,
label="correct token",
)
ax_sum.hist(
other_token_logits.numpy().flatten(),
density=True,
bins=bins,
label="other token",
)
ax_sum.legend()

if show:
plt.show()

return fig, (ax_all, ax_sum)
148 changes: 41 additions & 107 deletions notebooks/direct_logit_attribution.ipynb

Large diffs are not rendered by default.

0 comments on commit 729dc1f

Please sign in to comment.