Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
seplee committed Jul 23, 2024
1 parent 7a91220 commit 3025de9
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ def set_SAE(self, sae_name):
self.cfg_dict = cfg_dict

def _get_sae_out_and_feature_activations(self):
# given the words in steering_vectore_prompt, the SAE predicts that the neurons(aka features) in activateCache will be activated
# given the words in steering_vector_prompt, the SAE predicts that the neurons(aka features) in activateCache will be activated
sv_logits, activationCache = self.model.run_with_cache(self.steering_vector_prompt, prepend_bos=True)
sv_feature_acts = self.sae.encode(activationCache[self.sae.cfg.hook_name])
# get top_k of 1
# self.sae_out = sae.decode(sv_feature_acts)
return self.sae.decode(sv_feature_acts), sv_feature_acts

def _hooked_generate(self, prompt_batch, fwd_hooks, seed=None, **kwargs):
Expand Down Expand Up @@ -101,12 +99,10 @@ def _get_steering_hooks(self):
def _run_generate(self, example_prompt, steering_on: bool):

self.model.reset_hooks()
steer_hooks = self._get_steering_hooks()
editing_hooks = [ (self.sae_id, steer_hook) for steer_hook in steer_hooks]
# editing_hooks = [(self.sae_id, steer_hook)]
# ^^change this to support steer_hooks being a list of steer_hooks
print(f"steering by {len(editing_hooks)} hooks")
if steering_on:
steer_hooks = self._get_steering_hooks()
editing_hooks = [ (self.sae_id, steer_hook) for steer_hook in steer_hooks]
print(f"steering by {len(editing_hooks)} hooks")
res = self._hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **self.sampling_kwargs)
else:
tokenized = self.model.to_tokens([example_prompt])
Expand All @@ -129,12 +125,12 @@ def generate(self, message: str, steering_on: bool):



MODEL = "gemma-2b"
PRETRAINED_SAE = "gemma-2b-res-jb"
# MODEL = "gemma-2b"
# PRETRAINED_SAE = "gemma-2b-res-jb"
MODEL = "gpt2-small"
PRETRAINED_SAE = "gpt2-small-res-jb"
LAYER = 10
chatbot_model = Inference(MODEL,PRETRAINED_SAE, LAYER)
chatbot_model = Inference(MODEL, PRETRAINED_SAE, LAYER)


import time
Expand Down Expand Up @@ -187,7 +183,6 @@ def slow_echo_steering(message, history):
with gr.Row():
temp = gr.Slider(0, 5, 1, label="Temperature", info="Temperature is..", interactive=True)

# Set up an action when the sliders change
temp.change(chatbot_model.set_temperature, inputs=[temp], outputs=[])
coeff.change(chatbot_model.set_coeff, inputs=[coeff], outputs=[])
chatbot_model.set_steering_vector_prompt(steering_prompt.value)
Expand Down

0 comments on commit 3025de9

Please sign in to comment.