From f5594583b9c09de6b8ca2ceb2ffaf2c4fc93b6ff Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 4 Sep 2023 21:15:38 -0600 Subject: [PATCH] direct logit attribution logit diff wip I have no idea why, but the scaling is wrong. The relationship is clearly linear, but it's different depending on whether it's the diff to all other tokens or to some random set of tokens. --- notebooks/direct_logit_attribution.ipynb | 269 ++++++++++++++++------- 1 file changed, 187 insertions(+), 82 deletions(-) diff --git a/notebooks/direct_logit_attribution.ipynb b/notebooks/direct_logit_attribution.ipynb index b98d5655..44b4b17b 100644 --- a/notebooks/direct_logit_attribution.ipynb +++ b/notebooks/direct_logit_attribution.ipynb @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 106, "metadata": { "pycharm": { "name": "#%%\n" @@ -57,6 +57,7 @@ "# Numerical Computing\n", "import numpy as np\n", "import torch\n", + "import pandas as pd\n", "# import torch.nn.functional as F\n", "from fancy_einsum import einsum\n", "import einops\n", @@ -81,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": { "pycharm": { "name": "#%%\n" @@ -98,10 +99,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 6, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -129,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -164,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "metadata": { "pycharm": { "name": "#%%\n" @@ -211,7 +212,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -249,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -293,7 +294,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "metadata": {}, "outputs": 
[], "source": [ @@ -305,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -340,7 +341,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -358,7 +359,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -379,7 +380,7 @@ " ))" ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -395,7 +396,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -417,7 +418,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -449,7 +450,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -457,15 +458,25 @@ "def logits_to_avg_logit_diff(\n", " final_logits: Float[torch.Tensor, \"n_mazes d_vocab\"],\n", " answer_tokens: Int[torch.Tensor, \"n_mazes\"],\n", + " compare_to: Int[torch.Tensor, \"n_mazes\"]|None = None,\n", " per_prompt: bool = True,\n", " ) -> Float[torch.Tensor, \"n_mazes\"]|float:\n", "\n", " # logit on the answer token for each sample\n", " answer_logits: Float[torch.Tensor, \"n_mazes\"] = torch.gather(final_logits, 1, answer_tokens.unsqueeze(1)).squeeze(1)\n", - " # logits of all tokens for each sample\n", - " all_logits: Float[torch.Tensor, \"n_mazes\"] = torch.sum(final_logits, dim=1)\n", + " \n", + " output: Float[torch.Tensor, \"n_mazes\"]\n", + " if compare_to is None:\n", + " # logits of all tokens for each sample\n", + " all_logits: Float[torch.Tensor, \"n_mazes\"] = torch.sum(final_logits, dim=1)\n", + " output = answer_logits - (all_logits - answer_logits)\n", + " else:\n", + " # specifically the comparison tokens\n", + " compare_to_logits: Float[torch.Tensor, \"n_mazes\"] = torch.gather(final_logits, 1, 
compare_to.unsqueeze(1)).squeeze(1)\n", + " output = answer_logits - compare_to_logits\n", + "\n", + " assert output.shape == answer_tokens.shape\n", "\n", - " output: Float[torch.Tensor, \"n_mazes\"] = answer_logits - (all_logits - answer_logits)\n", " if per_prompt:\n", " return output\n", " else:\n", @@ -476,96 +487,169 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 101, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "avg logit diff of target: 31.570226669311523\n", - "avg logit diff of predicted: 31.570226669311523\n", - "avg logit diff of sampled: 31.570226669311523\n", - "avg logit diff of random: 0.02409425750374794\n" - ] - } - ], + "outputs": [], "source": [ + "def logit_diff_residual_stream(\n", + "\tmodel: ZanjHookedTransformer,\n", + "\tcache: ActivationCache,\n", + "\tanswer_tokens: Int[torch.Tensor, \"n_mazes\"],\n", + "\tcompare_to: Int[torch.Tensor, \"n_mazes\"]|None = None,\n", + ") -> dict:\n", + "\t# embed the whole vocab first\n", + "\td_vocab: int = model.zanj_model_config.maze_tokenizer.vocab_size\n", + "\tvocab_tensor: Float[torch.Tensor, \"d_vocab\"] = torch.arange(d_vocab, dtype=torch.long)\n", + "\tvocab_residual_directions = model.tokens_to_residual_directions(vocab_tensor)\n", + "\t# get embedding of answer tokens\n", + "\tanswer_residual_directions = vocab_residual_directions[answer_tokens]\n", + "\t# get the directional difference\n", + "\tlogit_diff_directions: Float[torch.Tensor, \"n_mazes d_model\"]\n", + "\tif compare_to is None:\n", + "\t\tlogit_diff_directions = answer_residual_directions - vocab_residual_directions[~answer_tokens]\n", + "\telse:\n", + "\t\tlogit_diff_directions = answer_residual_directions - vocab_residual_directions[compare_to]\n", + "\n", "\n", - "for k, d in {\n", - " \"target\": DATASET_TARGET_IDS, \n", - " \"predicted\": LAST_TOK_LOGITS.argmax(dim=-1), \n", - " \"sampled\": torch.multinomial(torch.softmax(LAST_TOK_LOGITS, dim=-1), 
num_samples=1).squeeze(-1),\n", - " \"random\": torch.randint_like(DATASET_TARGET_IDS, low=0, high=d_vocab),\n", - "}.items():\n", - "\tresult: float = logits_to_avg_logit_diff(\n", - "\t\tfinal_logits=LAST_TOK_LOGITS, \n", - "\t\tanswer_tokens=d,\n", - "\t\tper_prompt=False,\n", + "\t# get the values from the cache at the last layer and last token\n", + "\tfinal_token_residual_stream = cache[\"resid_post\", -1][:, -1, :]\n", + "\t# scaling the values in residual stream with layer norm\n", + "\tscaled_final_token_residual_stream = cache.apply_ln_to_stack(\n", + "\t\tfinal_token_residual_stream, layer = -1, pos_slice=-1,\n", "\t)\n", - "\tprint(f\"avg logit diff of {k}: {result}\")" + "\n", + "\n", + "\taverage_logit_diff = torch.dot(\n", + "\t\tscaled_final_token_residual_stream.flatten(),\n", + "\t\tlogit_diff_directions.flatten(),\n", + "\t) / logit_diff_directions.shape[0]\n", + "\n", + "\treturn average_logit_diff.item()\n", + "\n" ] }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 142, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "vocab_residual_directions.shape = torch.Size([75, 128])\n", - "answer_residual_directions.shape = torch.Size([100, 128])\n", - "logit_diff_directions.shape = torch.Size([100, 128])\n" + " test compare_to result_orig result_res diff \\\n", + "0 target all 31.570227 13.288536 18.281691 \n", + "1 predicted all 31.570227 13.288536 18.281691 \n", + "2 sampled all 31.570227 13.288536 18.281691 \n", + "3 noise=1.0 all 31.570227 13.288536 18.281691 \n", + "4 noise=1.0722672220103233 all 31.570227 13.288536 18.281691 \n", + ".. ... ... ... ... ... 
\n", + "203 noise=811.1308307896873 random 0.325303 0.350737 -0.025433 \n", + "204 noise=869.7490026177834 random 1.277788 1.303297 -0.025508 \n", + "205 noise=932.60334688322 random 0.080973 0.156748 -0.075774 \n", + "206 noise=1000.0 random 0.775237 0.807143 -0.031906 \n", + "207 random random 0.230697 0.403727 -0.173030 \n", + "\n", + " ratio \n", + "0 2.375749 \n", + "1 2.375749 \n", + "2 2.375749 \n", + "3 2.375749 \n", + "4 2.375749 \n", + ".. ... \n", + "203 0.927486 \n", + "204 0.980428 \n", + "205 0.516585 \n", + "206 0.960470 \n", + "207 0.571418 \n", + "\n", + "[208 rows x 6 columns]\n" ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "def logit_diff_residual_stream(\n", + "def logits_diff(\n", "\tmodel: ZanjHookedTransformer,\n", "\tcache: ActivationCache,\n", - "\tanswer_tokens: Int[torch.Tensor, \"n_mazes\"],\n", - ") -> dict:\n", - "\t# embed the whole vocab first\n", - "\tvocab_tensor: Float[torch.Tensor, \"d_vocab\"] = torch.arange(model.zanj_model_config.maze_tokenizer.vocab_size, dtype=torch.long)\n", - "\tvocab_residual_directions = model.tokens_to_residual_directions(vocab_tensor)\n", - "\t# get embedding of answer tokens\n", - "\tanswer_residual_directions = vocab_residual_directions[answer_tokens]\n", + "\tdataset_target_ids: Int[torch.Tensor, \"n_mazes\"],\n", + "\tlast_tok_logits: Float[torch.Tensor, \"n_mazes d_vocab\"],\n", + "\tnoise_sigmas: list[float] = [1, 2, 3, 5, 10],\n", + ") -> pd.DataFrame:\n", + "\t\n", + "\ttest_logits: dict[str, Float[torch.Tensor, \"n_mazes\"]] = {\n", + "\t\t\"target\": dataset_target_ids, \n", + "\t\t\"predicted\": last_tok_logits.argmax(dim=-1), \n", + "\t\t\"sampled\": torch.multinomial(torch.softmax(last_tok_logits, dim=-1), num_samples=1).squeeze(-1),\n", + "\t\t**{\n", + "\t\t\tf\"noise={s}\": (last_tok_logits + s*torch.randn_like(last_tok_logits)).argmax(dim=-1)\n", + "\t\t\tfor s in noise_sigmas\n", + "\t\t},\n", + "\t\t\"random\": torch.randint_like(dataset_target_ids, low=0, high=d_vocab),\n", + "\t}\n", + "\tcompare_rand: Float[torch.Tensor, \"n_mazes\"] = torch.randint_like(dataset_target_ids, low=0, high=d_vocab)\n", "\n", - "\t# get the difference in direction between the answer token and the rest of the vocab\t\n", - "\t# for i in range(len(answer_tokens)):\n", - "\t# \t_temp = answer_residual_directions[i] - vocab_residual_directions[~answer_tokens[i]]\n", - "\t# \tprint(f\"{_temp.shape = }\")\n", + "\toutputs: list[dict] = list()\n", "\n", - "\t# logit_diff_directions = torch.cat(\n", - "\t# \t[\n", - "\t# \t\tanswer_residual_directions[i] - 
vocab_residual_directions[~answer_tokens[i]]\n", - "\t# \t\tfor i in range(len(answer_tokens))\n", - "\t# \t],\n", - "\t# \tdim=-1,\n", - "\t# )\n", - "\tlogit_diff_directions = answer_residual_directions - vocab_residual_directions[~answer_tokens]\n", + "\tfor compare_to in [None, compare_rand]:\n", + "\t\tfor k, d in test_logits.items():\n", + "\t\t\tresult_orig: float = logits_to_avg_logit_diff(\n", + "\t\t\t\tfinal_logits=last_tok_logits, \n", + "\t\t\t\tanswer_tokens=d,\n", + "\t\t\t\tper_prompt=False,\n", + "\t\t\t\tcompare_to=compare_to,\n", + "\t\t\t)\n", + "\t\t\tresult_res: float = logit_diff_residual_stream(\n", + "\t\t\t\tmodel=model,\n", + "\t\t\t\tcache=cache,\n", + "\t\t\t\tanswer_tokens=d,\n", + "\t\t\t\tcompare_to=compare_to,\n", + "\t\t\t)\n", + "\t\t\t# print(f\"logit diff of {k}\\tcompare:\\t{'all' if compare_to is None else 'random'}\\t{result = }\\t{result_res = }\")\n", + "\t\t\toutputs.append(dict(\n", + "\t\t\t\ttest=k,\n", + "\t\t\t\tcompare_to=\"all\" if compare_to is None else \"random\",\n", + "\t\t\t\tresult_orig=result_orig,\n", + "\t\t\t\tresult_res=result_res,\n", + "\t\t\t))\n", "\n", - "\tprint(f\"{vocab_residual_directions.shape = }\")\n", - "\tprint(f\"{answer_residual_directions.shape = }\")\n", - "\tprint(f\"{logit_diff_directions.shape = }\")\n", + "\tdf_out: pd.DataFrame = pd.DataFrame(outputs)\n", + "\tdf_out[\"diff\"] = df_out[\"result_orig\"] - df_out[\"result_res\"]\n", + "\tdf_out[\"ratio\"] = df_out[\"result_orig\"] / df_out[\"result_res\"]\n", "\n", - "\t# get the values from the cache at the last layer and last token\n", - "\tfinal_token_residual_stream = cache[\"resid_post\", -1][:, -1, :]\n", - "\t# scaling the values in residual stream with layer norm\n", - "\tscaled_final_token_residual_stream = cache.apply_ln_to_stack(\n", - "\t\tfinal_token_residual_stream, layer = -1, pos_slice=-1,\n", - "\t)\n", "\n", - "\taverage_logit_diff = einsum(\n", - "\t\t\"batch d_model, batch d_model -> \", \n", - 
"\t\tscaled_final_token_residual_stream, \n", - "\t\tlogit_diff_directions,\n", - "\t) / len(answer_tokens)\n", + "\treturn df_out\n", + "\n", + "LOGIT_DIFF_DF: pd.DataFrame = logits_diff(\n", + "\tmodel=MODEL,\n", + "\tcache=CACHE,\n", + "\tdataset_target_ids=DATASET_TARGET_IDS,\n", + "\tlast_tok_logits=LAST_TOK_LOGITS,\n", + "\tnoise_sigmas=np.logspace(0, 3, 100),\n", + ")\n", "\n", + "print(LOGIT_DIFF_DF)\n", "\n", - "logit_diff_residual_stream(MODEL, CACHE, DATASET_TARGET_IDS)" + "# plt.scatter(LOGIT_DIFF_DF['result_orig'], LOGIT_DIFF_DF['result_res'])\n", + "# scatter separately for \"all\" vs \"random\"\n", + "fig, ax = plt.subplots()\n", + "for compare_to in [\"all\", \"random\"]:\n", + "\tdf = LOGIT_DIFF_DF[LOGIT_DIFF_DF[\"compare_to\"] == compare_to]\n", + "\tax.scatter(df['result_orig'], df['result_res'], label=compare_to)\n", + "ax.legend()\n", + "plt.xlabel('result_orig')\n", + "plt.ylabel('result_res')\n", + "plt.title('Scatter Plot between result_orig and result_res')\n", + "plt.show()\n" ] }, { @@ -573,6 +657,27 @@ "execution_count": null, "metadata": {}, "outputs": [], + "source": [ + "# linear " + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "HookedTransformer.__init__() missing 1 required positional argument: 'cfg'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[16], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m answer_residual_directions \u001b[39m=\u001b[39m MODEL\n\u001b[0;32m 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtransformer_lens\u001b[39;00m \u001b[39mimport\u001b[39;00m HookedTransformer\n\u001b[1;32m----> 3\u001b[0m HookedTransformer()\u001b[39m.\u001b[39mtokens_to_residual_directions()\n", + "\u001b[1;31mTypeError\u001b[0m: HookedTransformer.__init__() missing 
1 required positional argument: 'cfg'" + ] + } + ], "source": [ "answer_residual_directions = MODEL\n", "from transformer_lens import HookedTransformer\n",