wip
mivanit committed Sep 5, 2023
1 parent 729dc1f commit 0d0d652
Showing 1 changed file with 94 additions and 11 deletions.
105 changes: 94 additions & 11 deletions notebooks/direct_logit_attribution.ipynb
@@ -449,52 +449,135 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# From Neels explanatory notebook: https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/main/Exploratory_Analysis_Demo.ipynb\n",
"def logits_to_ave_logit_diff(\n",
"def logits_to_avg_logit_diff(\n",
" final_logits: Float[torch.Tensor, \"n_mazes d_vocab\"],\n",
" answer_tokens: Int[torch.Tensor, \"n_mazes\"],\n",
" ) -> Float[torch.Tensor, \"n_mazes\"]:\n",
" per_prompt: bool = True,\n",
" ) -> Float[torch.Tensor, \"n_mazes\"]|float:\n",
"\n",
" # logit on the answer token for each sample\n",
" answer_logits: Float[torch.Tensor, \"n_mazes\"] = torch.gather(final_logits, 1, answer_tokens.unsqueeze(1)).squeeze(1)\n",
" # logits of all tokens for each sample\n",
" all_logits: Float[torch.Tensor, \"n_mazes\"] = torch.sum(final_logits, dim=1)\n",
"\n",
" return answer_logits - (all_logits - answer_logits)\n",
" output: Float[torch.Tensor, \"n_mazes\"] = answer_logits - (all_logits - answer_logits)\n",
" if per_prompt:\n",
" return output\n",
" else:\n",
" return output.mean().item()\n",
"\n",
" # return answer_logits / (all_logits - answer_logits)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"avg_logit_diff.shape = torch.Size([100])\n",
"avg_logit_diff.mean().item() = 31.570226669311523\n"
"avg logit diff of target: 31.570226669311523\n",
"avg logit diff of predicted: 31.570226669311523\n",
"avg logit diff of sampled: 31.570226669311523\n",
"avg logit diff of random: 0.02409425750374794\n"
]
}
],
"source": [
"avg_logit_diff: Float[torch.Tensor, \"n_mazes\"] = logits_to_ave_logit_diff(final_logits=LAST_TOK_LOGITS, answer_tokens=DATASET_TARGET_IDS)\n",
"\n",
"print(f\"{avg_logit_diff.shape = }\")\n",
"print(f\"{avg_logit_diff.mean().item() = }\")"
"for k, d in {\n",
" \"target\": DATASET_TARGET_IDS, \n",
" \"predicted\": LAST_TOK_LOGITS.argmax(dim=-1), \n",
" \"sampled\": torch.multinomial(torch.softmax(LAST_TOK_LOGITS, dim=-1), num_samples=1).squeeze(-1),\n",
" \"random\": torch.randint_like(DATASET_TARGET_IDS, low=0, high=d_vocab),\n",
"}.items():\n",
"\tresult: float = logits_to_avg_logit_diff(\n",
"\t\tfinal_logits=LAST_TOK_LOGITS, \n",
"\t\tanswer_tokens=d,\n",
"\t\tper_prompt=False,\n",
"\t)\n",
"\tprint(f\"avg logit diff of {k}: {result}\")"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"vocab_residual_directions.shape = torch.Size([75, 128])\n",
"answer_residual_directions.shape = torch.Size([100, 128])\n",
"logit_diff_directions.shape = torch.Size([100, 128])\n"
]
}
],
"source": [
"def logit_diff_residual_stream(\n",
"\tmodel: ZanjHookedTransformer,\n",
"\tcache: ActivationCache,\n",
"\tanswer_tokens: Int[torch.Tensor, \"n_mazes\"],\n",
") -> dict:\n",
"\t# embed the whole vocab first\n",
"\tvocab_tensor: Float[torch.Tensor, \"d_vocab\"] = torch.arange(model.zanj_model_config.maze_tokenizer.vocab_size, dtype=torch.long)\n",
"\tvocab_residual_directions = model.tokens_to_residual_directions(vocab_tensor)\n",
"\t# get embedding of answer tokens\n",
"\tanswer_residual_directions = vocab_residual_directions[answer_tokens]\n",
"\n",
"\t# get the difference in direction between the answer token and the rest of the vocab\t\n",
"\t# for i in range(len(answer_tokens)):\n",
"\t# \t_temp = answer_residual_directions[i] - vocab_residual_directions[~answer_tokens[i]]\n",
"\t# \tprint(f\"{_temp.shape = }\")\n",
"\n",
"\t# logit_diff_directions = torch.cat(\n",
"\t# \t[\n",
"\t# \t\tanswer_residual_directions[i] - vocab_residual_directions[~answer_tokens[i]]\n",
"\t# \t\tfor i in range(len(answer_tokens))\n",
"\t# \t],\n",
"\t# \tdim=-1,\n",
"\t# )\n",
"\tlogit_diff_directions = answer_residual_directions - vocab_residual_directions[~answer_tokens]\n",
"\n",
"\tprint(f\"{vocab_residual_directions.shape = }\")\n",
"\tprint(f\"{answer_residual_directions.shape = }\")\n",
"\tprint(f\"{logit_diff_directions.shape = }\")\n",
"\n",
"\t# get the values from the cache at the last layer and last token\n",
"\tfinal_token_residual_stream = cache[\"resid_post\", -1][:, -1, :]\n",
"\t# scaling the values in residual stream with layer norm\n",
"\tscaled_final_token_residual_stream = cache.apply_ln_to_stack(\n",
"\t\tfinal_token_residual_stream, layer = -1, pos_slice=-1,\n",
"\t)\n",
"\n",
"\taverage_logit_diff = einsum(\n",
"\t\t\"batch d_model, batch d_model -> \", \n",
"\t\tscaled_final_token_residual_stream, \n",
"\t\tlogit_diff_directions,\n",
"\t) / len(answer_tokens)\n",
"\n",
"\n",
"logit_diff_residual_stream(MODEL, CACHE, DATASET_TARGET_IDS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"answer_residual_directions = MODEL\n",
"from transformer_lens import HookedTransformer\n",
"HookedTransformer().tokens_to_residual_directions()"
]
},
{
"attachments": {},