From 51517fdb9e6c313252993f7453b51e6b849845da Mon Sep 17 00:00:00 2001
From: s
Date: Tue, 13 Dec 2022 18:09:12 -0500
Subject: [PATCH] Add instruct example

---
 cheese/__init__.py               |   2 +
 examples/instruct_hf_pipeline.py | 197 +++++++++++++++++++++++++++++++
 2 files changed, 199 insertions(+)
 create mode 100644 examples/instruct_hf_pipeline.py

diff --git a/cheese/__init__.py b/cheese/__init__.py
index 98eb1c2..3401503 100644
--- a/cheese/__init__.py
+++ b/cheese/__init__.py
@@ -150,6 +150,8 @@ def send(msg : Any):
             send(self.get_stats())
         elif msg[0] == msg_constants.DRAW:
             self.draw()
+        else:
+            print("Warning: Unknown message received", msg)
         time.sleep(listen_every)
 
     @rabbitmq_callback
diff --git a/examples/instruct_hf_pipeline.py b/examples/instruct_hf_pipeline.py
new file mode 100644
index 0000000..b002572
--- /dev/null
+++ b/examples/instruct_hf_pipeline.py
@@ -0,0 +1,197 @@
+"""
+    This example implements an instruct-style annotation task in which labellers
+    are shown several completions of a prompt and asked to rank them in order of
+    preference. It collects and writes a dataset of preference rankings over
+    completions.
+"""
+
+from dataclasses import dataclass
+from typing import List, Iterable
+
+from transformers import pipeline
+import gradio as gr
+from cheese.pipeline.generative import GenerativePipeline
+from cheese.models import BaseModel
+from cheese.data import BatchElement
+from cheese.client.gradio_client import GradioFront
+from cheese import CHEESE
+
+@dataclass
+class LMGenerationElement(BatchElement):
+    query : str = None
+    completions : List[str] = None
+    rankings : List[int] = None # Completion indices, ordered from best to worst
+
+class LMPipeline(GenerativePipeline):
+    def __init__(self, n_samples = 5, **kwargs):
+        super().__init__(**kwargs)
+
+        self.n_samples = n_samples
+        self.pipe = pipeline(task="text-generation", model = 'gpt2', device=0)
+        self.pipe.tokenizer.pad_token_id = self.pipe.model.config.eos_token_id
+        # Setting pad_token_id explicitly silences repeated padding warnings
+
+        self.init_buffer()
+
+    def generate(self, model_input : Iterable[str]) -> List[LMGenerationElement]:
+        """
+        Generates a batch of elements from the queries supplied by the
+        pipeline's prompt iterator.
+        """
+        print("Generate called")
+        model_input = list(model_input) # An arbitrary iterable may not support indexing
+        elements = []
+        for i in range(self.batch_size):
+            query = model_input[i]
+            completions = self.pipe(query, max_length=100, num_return_sequences=self.n_samples)
+            completions = [completion["generated_text"] for completion in completions]
+            elements.append(LMGenerationElement(query=query, completions=completions))
+        return elements
+
+    def extract_data(self, batch_element : LMGenerationElement) -> dict:
+        """
+        Extracts the fields to be written to the dataset from a batch element.
+        """
+        return {
+            "query" : batch_element.query,
+            "completions" : batch_element.completions,
+            "rankings" : batch_element.rankings
+        }
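+
+# For reference, once a labeller has submitted, each record extract_data emits
+# should look roughly like the sketch below (values are illustrative, not real
+# GPT-2 output):
+#
+#   {
+#       "query" : "Tell me a fun fact about rabbits",
+#       "completions" : ["Rabbits can ...", "..."],  # n_samples strings
+#       "rankings" : [2, 0, 4, 1, 3]                 # indices, best to worst
+#   }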
+ """ + return { + "query" : batch_element.query, + "completions" : batch_element.completions, + "rankings" : batch_element.rankings + } + +def make_iter(length : int = 20): + print("Creating prompt iterator...") + pipe = pipeline(task="text-generation", model = 'gpt2', device=0) + pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id + chunk_size = 16 + meta_prompt = f"As an example, below is a list of {chunk_size + 3} prompts you could feed to a language model:\n"+\ + "\"What is the capital of France?\"\n"+\ + "\"Write a story about geese\"\n"+\ + "\"Tell me a fun fact about rabbits\"\n" + + def extract_prompts(entire_generation : str): + generation = entire_generation[len(meta_prompt):] + prompts = generation.split("\n") + prompts = [prompt[1:-1] for prompt in prompts] # Remove quotes + return prompts[:chunk_size] + + prompt_buffer = [] + + while len(prompt_buffer) < length: + prompts = pipe(meta_prompt, max_length=128, num_return_sequences=chunk_size) + prompts = sum([extract_prompts(prompt["generated_text"]) for prompt in prompts], []) + prompt_buffer += prompts + + del pipe + + return iter(prompt_buffer) + +class LMFront(GradioFront): + def main(self): + pressed = gr.State([]) + with gr.Column(): + gr.Button("On the left you will see a prompt. On the right you will "+ \ + "see various possible completions. Select the completions in order of "+ \ + "best to worst", interactive = False, show_label = False) + with gr.Row(): + query = gr.Textbox("Prompt", interactive = False, show_label = False) + with gr.Column(): + gr.Textbox("Completions:", interactive = False, show_label = False) + + completions = [gr.Button("", interactive = True) for _ in range(5)] + + + submit = gr.Button("Submit") + + # When a button is pressed, append index to state, and make button not visible + + def press_button(i, pressed_val): + print("Pressed button", i) + pressed_val.append(i) + + updates = [gr.update(visible = False if j in pressed_val else True) for j in range(5)] + + return [pressed_val] + updates + + def press_btn_1(pressed_val): + return press_button(0, pressed_val) + + def press_btn_2(pressed_val): + return press_button(1, pressed_val) + + def press_btn_3(pressed_val): + return press_button(2, pressed_val) + + def press_btn_4(pressed_val): + return press_button(3, pressed_val) + + def press_btn_5(pressed_val): + return press_button(4, pressed_val) + + completions[0].click( + press_btn_1, + inputs = [pressed], + outputs = [pressed] + completions + ) + + completions[1].click( + press_btn_2, + inputs = [pressed], + outputs = [pressed] + completions + ) + + completions[2].click( + press_btn_3, + inputs = [pressed], + outputs = [pressed] + completions + ) + + completions[3].click( + press_btn_4, + inputs = [pressed], + outputs = [pressed] + completions + ) + + completions[4].click( + press_btn_5, + inputs = [pressed], + outputs = [pressed] + completions + ) + + # When submit is pressed, run response, reset state, and set all buttons to visible + + self.wrap_event(submit.click)( + self.response, inputs = [pressed], outputs = [pressed, query] + completions + ) + + return [pressed, query] + completions + + def receive(self, *inp): + _, task, pressed_vals = inp + task.rankings = pressed_vals + + return task + + def present(self, task): + data : LMGenerationElement = task.data + + updates = [gr.update(value = data.completions[i], visible = True) for i in range(5)] + return [[], data.query] + updates + +if __name__ == "__main__": + write_path = "./rankings_dataset" + cheese = CHEESE( + LMPipeline, + 
+if __name__ == "__main__":
+    write_path = "./rankings_dataset"
+    cheese = CHEESE(
+        LMPipeline,
+        LMFront,
+        pipeline_kwargs = {
+            "iterator" : make_iter(),
+            "write_path" : write_path,
+            "max_length" : 20,
+            "buffer_size" : 20,
+            "batch_size" : 20,
+            "force_new" : True,
+            "log_progress" : True
+        },
+        gradio = True
+    )
+
+    print(cheese.launch()) # Prints the URL at which the front end is served
+
+    print(cheese.create_client(1)) # Prints login credentials for labeller id 1
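+
+    # After annotation, the collected data can be inspected. A rough sketch,
+    # assuming the pipeline persists its results in HuggingFace datasets
+    # format at write_path (adjust if your cheese version writes differently):
+    #
+    #   from datasets import load_from_disk
+    #   ds = load_from_disk(write_path)
+    #   print(ds[0]["query"], ds[0]["rankings"])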