Skip to content

Commit

Permalink
Add instruct example
Browse files Browse the repository at this point in the history
  • Loading branch information
shahbuland committed Dec 13, 2022
1 parent 8dcfb2a commit 51517fd
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cheese/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def send(msg : Any):
send(self.get_stats())
elif msg[0] == msg_constants.DRAW:
self.draw()
else:
print("Warning: Unknown message received", msg)
time.sleep(listen_every)

@rabbitmq_callback
Expand Down
197 changes: 197 additions & 0 deletions examples/instruct_hf_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""
This example does an instruct type annotation task in which labellers are given
multiple prompt completions and asked to rank them in order of preference.
It collects and writes a dataset of preferences for completions.
"""

from dataclasses import dataclass
from typing import List, Iterable

from transformers import pipeline
import gradio as gr
from cheese.pipeline.generative import GenerativePipeline
from cheese.models import BaseModel
from cheese.data import BatchElement
from cheese.client.gradio_client import GradioFront
from cheese import CHEESE

@dataclass
class LMGenerationElement(BatchElement):
    """One ranking task: a prompt plus candidate completions to be ordered by preference."""
    query : str = None  # The prompt the completions were generated from
    completions : List[str] = None  # Candidate completions shown to the labeller
    rankings : List[int] = None # Ordering for the completions w.r.t indices

class LMPipeline(GenerativePipeline):
    """
    Generative pipeline that produces batches of GPT-2 completions for
    labellers to rank.

    :param n_samples: Number of completions generated per prompt
    :param device: Device for the HF text-generation pipeline
        (0 = first GPU, the previous hard-coded default; -1 = CPU)
    :param kwargs: Forwarded to GenerativePipeline
    """
    def __init__(self, n_samples = 5, device = 0, **kwargs):
        super().__init__(**kwargs)

        self.n_samples = n_samples
        # device used to be hard-coded to GPU 0; it is now a parameter with
        # the same default, so CPU-only machines can pass device=-1.
        self.pipe = pipeline(task="text-generation", model = 'gpt2', device=device)
        # Setting pad_token_id up front prevents the repeated
        # "Setting pad_token_id to eos_token_id" messages during generation.
        self.pipe.tokenizer.pad_token_id = self.pipe.model.config.eos_token_id

        self.init_buffer()

    def generate(self, model_input : Iterable[str]) -> List[LMGenerationElement]:
        """
        Generate one batch of elements, each holding n_samples completions
        for a single prompt taken from model_input.

        :param model_input: Prompts for this batch; at least batch_size items
        :return: One LMGenerationElement per prompt (rankings left unset)
        """
        print("Generate called")
        # Materialize first: the parameter is typed Iterable, which does not
        # support the indexing the original code relied on.
        prompts = list(model_input)
        elements = []
        for i in range(self.batch_size):
            query = prompts[i]
            completions = self.pipe(query, max_length=100, num_return_sequences=self.n_samples)
            completions = [completion["generated_text"] for completion in completions]
            elements.append(LMGenerationElement(query=query, completions=completions))
        return elements

    def extract_data(self, batch_element : LMGenerationElement) -> dict:
        """
        Convert a (labelled) element into a row dict for the written dataset.
        """
        return {
            "query" : batch_element.query,
            "completions" : batch_element.completions,
            "rankings" : batch_element.rankings
        }

def make_iter(length : int = 20, device : int = 0):
    """
    Bootstrap a pool of prompts by asking GPT-2 to continue a meta-prompt
    that lists example prompts, and return them as an iterator.

    :param length: Exact number of prompts to produce
    :param device: Device for the HF pipeline (0 = first GPU, -1 = CPU)
    :return: Iterator over `length` generated prompt strings
    """
    print("Creating prompt iterator...")
    # device was previously hard-coded to GPU 0; default unchanged.
    pipe = pipeline(task="text-generation", model = 'gpt2', device=device)
    # Prevents repeated pad_token_id warnings during generation.
    pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
    chunk_size = 16
    meta_prompt = f"As an example, below is a list of {chunk_size + 3} prompts you could feed to a language model:\n"+\
    "\"What is the capital of France?\"\n"+\
    "\"Write a story about geese\"\n"+\
    "\"Tell me a fun fact about rabbits\"\n"

    def extract_prompts(entire_generation : str):
        # Drop the meta-prompt prefix, split the continuation into lines,
        # and strip the surrounding quote characters from each line.
        generation = entire_generation[len(meta_prompt):]
        prompts = generation.split("\n")
        prompts = [prompt[1:-1] for prompt in prompts] # Remove quotes
        return prompts[:chunk_size]

    prompt_buffer = []

    while len(prompt_buffer) < length:
        generations = pipe(meta_prompt, max_length=128, num_return_sequences=chunk_size)
        # Extend in place rather than sum(..., []), which concatenates
        # quadratically.
        for generation in generations:
            prompt_buffer += extract_prompts(generation["generated_text"])

    del pipe  # release the model before the labelling app starts

    # The loop above can overshoot `length`; trim so callers get exactly
    # as many prompts as they asked for.
    return iter(prompt_buffer[:length])

class LMFront(GradioFront):
    """
    Gradio frontend for the ranking task: shows a prompt alongside five
    completions; the labeller clicks the completions best-to-worst, then
    presses Submit.
    """

    def main(self):
        # Indices of the completion buttons pressed so far; click order is
        # the ranking (best first).
        pressed = gr.State([])
        with gr.Column():
            gr.Button("On the left you will see a prompt. On the right you will "+ \
                "see various possible completions. Select the completions in order of "+ \
                "best to worst", interactive = False, show_label = False)
            with gr.Row():
                query = gr.Textbox("Prompt", interactive = False, show_label = False)
                with gr.Column():
                    gr.Textbox("Completions:", interactive = False, show_label = False)
                    completions = [gr.Button("", interactive = True) for _ in range(5)]
            submit = gr.Button("Submit")

        # When a button is pressed, append its index to state and hide it.
        def press_button(i, pressed_val):
            print("Pressed button", i)
            pressed_val.append(i)

            updates = [gr.update(visible = False if j in pressed_val else True) for j in range(5)]

            return [pressed_val] + updates

        # Wire each completion button to press_button with its own index.
        # Binding `i` as a lambda default captures its current value and
        # avoids the late-binding-closure pitfall; this replaces five
        # near-identical press_btn_N wrapper functions.
        for i, button in enumerate(completions):
            button.click(
                lambda pressed_val, i=i: press_button(i, pressed_val),
                inputs = [pressed],
                outputs = [pressed] + completions
            )

        # When submit is pressed, run response, reset state, and set all buttons to visible
        self.wrap_event(submit.click)(
            self.response, inputs = [pressed], outputs = [pressed, query] + completions
        )

        return [pressed, query] + completions

    def receive(self, *inp):
        """Attach the click-order ranking to the task before handing it back."""
        _, task, pressed_vals = inp
        task.rankings = pressed_vals

        return task

    def present(self, task):
        """Render a new element: its prompt plus its five completions, all visible."""
        data : LMGenerationElement = task.data

        updates = [gr.update(value = data.completions[i], visible = True) for i in range(5)]
        # Reset pressed state ([]) and refresh the query box alongside the buttons.
        return [[], data.query] + updates

if __name__ == "__main__":
    # Directory the collected preference dataset is written to.
    write_path = "./rankings_dataset"

    # Configuration forwarded to LMPipeline / GenerativePipeline.
    pipeline_kwargs = {
        "iterator" : make_iter(),
        "write_path" : write_path,
        "max_length" : 20,
        "buffer_size" : 20,
        "batch_size" : 20,
        "force_new" : True,
        "log_progress" : True
    }

    cheese = CHEESE(
        LMPipeline,
        LMFront,
        pipeline_kwargs = pipeline_kwargs,
        gradio = True
    )

    print(cheese.launch())

    print(cheese.create_client(1))

0 comments on commit 51517fd

Please sign in to comment.