diff --git a/cheese/__init__.py b/cheese/__init__.py new file mode 100644 index 0000000..3401503 --- /dev/null +++ b/cheese/__init__.py @@ -0,0 +1,291 @@ +from typing import ClassVar, Iterable, Tuple, Dict, Any, Callable + +from cheese.client import ClientManager, ClientStatistics +from cheese.client.gradio_client import GradioClientManager +from cheese.pipeline import Pipeline +from cheese.models import BaseModel + +import cheese.utils.msg_constants as msg_constants +from cheese.utils.rabbit_utils import rabbitmq_callback + +import pickle +from b_rabbit import BRabbit +from tqdm import tqdm +import time + +# Master object for CHEESE +class CHEESE: + """ + Main object to use for running tasks with CHEESE + + :param pipeline_cls: Class for pipeline + :type pipeline_cls: Callable[, Pipeline] + + :param client_cls: Class for client + :type client_cls: Callable[,GradioFront] if gradio + + :param model_cls: Class for model + :type model_cls: Callable[,BaseModel] + + :param pipeline_kwargs: Additional keyword arguments to pass to pipeline constructor + :type pipeline_kwargs: Dict[str, Any] + + :param gradio: Whether to use gradio or custom frontend + :type gradio: bool + + :param draw_always: If true, doesn't check for free clients before drawing a task. + This is useful if you are trying to feed data directly to model and don't need to worry about having free clients. + :type draw_always: bool + + :param host: Host for rabbitmq server. 
Normally just locahost if you are running locally + :type host: str + + :param port: Port to run rabbitmq server on + :type port: int + """ + def __init__( + self, + pipeline_cls = None, client_cls = None, model_cls = None, + pipeline_kwargs : Dict[str, Any] = {}, model_kwargs : Dict[str, Any] = {}, + gradio : bool = True, draw_always : bool = False, + host : str = 'localhost', port : int = 5672 + ): + + self.gradio = gradio + self.draw_always = draw_always + + # Initialize rabbit MQ server + self.connection = BRabbit(host=host, port=port) + + # Channel for client to notify of task completion + self.subscriber = self.connection.EventSubscriber( + b_rabbit = self.connection, + routing_key = 'main', + publisher_name = 'client', + event_listener = self.client_ping + ) + + # Receive tasks via API + self.api_subscriber = self.connection.EventSubscriber( + b_rabbit = self.connection, + routing_key = 'main', + publisher_name = 'api', + event_listener = self.api_ping + ) + + # Send data back through API + self.api_publisher = self.connection.EventPublisher( + b_rabbit = self.connection, + publisher_name = 'main' + ) + + self.subscriber.subscribe_on_thread() + self.api_subscriber.subscribe_on_thread() + + # components initialized + self.pipeline : Pipeline = pipeline_cls(**pipeline_kwargs) + self.model : BaseModel = model_cls(**model_kwargs) if model_cls is not None else None + + self.client_cls = client_cls + if gradio: + self.client_manager = GradioClientManager() + else: + self.client_manager = ClientManager() + + self.pipeline.init_connection(self.connection) + self.client_manager.init_connection(self.connection) + if self.model is not None: self.model.init_connection(self.connection) + + self.clients = 0 + self.busy_clients = 0 + + self.finished = False # For when pipeline is exhausted + self.launched = False + + # Communication with API + self.receive_buffer = [] + + self.url = None + + def launch(self) -> str: + """ + Launch the frontend and return URL for users to 
access it. + """ + if not self.launched: + url = self.client_manager.init_front(self.client_cls) + else: + raise Exception("CHEESE has already been launched") + + self.launched = True + self.url = url + return url + + def start_listening(self, verbose : bool = True, listen_every : float = 1.0): + """ + If using as a server, call this before running client. + + :param verbose: Whether to print status updates + :type verbose: bool + + :param run_every: Listen for messages every x seconds + """ + + def send(msg : Any): + self.api_publisher.publish('api', pickle.dumps(msg)) + + while True: + if self.receive_buffer: + if verbose: + print("Received a message", self.receive_buffer[0]) + msg = self.receive_buffer.pop(0).split("|") + if msg[0] == msg_constants.READY: + send(True) + elif msg[0] == msg_constants.LAUNCH: + send(self.launch()) + elif msg[0] == msg_constants.ADD: + send(self.create_client(int(msg[1]))) + elif msg[0] == msg_constants.REMOVE: + self.remove_client(int(msg[1])) + elif msg[0] == msg_constants.STATS: + send(self.get_stats()) + elif msg[0] == msg_constants.DRAW: + self.draw() + else: + print("Warning: Unknown message received", msg) + time.sleep(listen_every) + + @rabbitmq_callback + def api_ping(self, msg): + """ + All API calls are routed through this method. Message is parsed to execute some function. + """ + # Needs: + # - ready + # - launch + # - Create client + # - Remove client + # - Get stats + # - draw + + try: + self.receive_buffer.append(pickle.loads(msg)) + except Exception as e: + # Check if the error has to do with receive_buffer not being defined yet + if "receive_buffer" in str(e): + print("Warning: RabbitMQ queue non-empty at startup. 
Consider restarting RabbitMQ server if unexpected errors arise.") + + + @rabbitmq_callback + def client_ping(self, msg): + """ + Method for ClientManager to ping the API when it needs more tasks or has taken a task + """ + msg = msg.decode('utf-8') + if msg == msg_constants.SENT: + # Client sent task to pipeline, needs a new one + self.busy_clients -= 1 + self.draw() + elif msg == msg_constants.RECEIVED: + self.busy_clients += 1 + else: + raise Exception("Error: Client pinged master with unknown message") + + def create_client(self, id : int) -> Tuple[int, int]: + """ + Create a client instance with given id. + + :param id: A unique identifying number for the client. + :type id: int + + :return: Username and password user can use to log in to CHEESE + """ + + id, pwd = self.client_manager.add_client(id) + self.clients += 1 + self.draw() # pre-emptively draw a task for the client to pick up + return id, pwd + + def remove_client(self, id : int): + """ + Remove client with given id. + + :param id: A unique identifying number for the client. + :type id: int + """ + self.client_manager.remove_client(id) + self.clients -= 1 + + def get_stats(self) -> Dict: + """ + Get various statistics in the form of a dictionary. 
+ + :return: Dictionary containing following statistics and values + - url: URL for accessing CHEESE frontend + - finished: Whether pipeline is exhausted + - num_clients: Number of clients connected to CHEESE + - num_busy_clients: Number of clients currently working on a task + - num_tasks: Number of tasks completed overall + - client_stats: Dictionary of client statistics + - model_stats: Dictionary of model statistics + - pipeline_stats: Dictionary of pipeline statistics + """ + client_stats = self.client_manager.client_statistics + + # Get num_tasks from all clients + num_tasks = 0 + for client in client_stats: + stat : ClientStatistics = client_stats[client] + num_tasks += stat.total_tasks + + return { + 'url' : self.url, + 'finished' : self.finished, + 'num_clients' : self.clients, + 'num_busy_clients' : self.busy_clients, + 'num_tasks' : num_tasks, + 'client_stats' : client_stats, + 'model_stats' : self.model.get_stats() if self.model else None, + 'pipeline_stats' : self.pipeline.get_stats() + } + + def draw(self): + """ + Draws a sample from data pipeline and creates a task to send to clients. Does nothing if no free clients. + This check if overriden if draw_always is set to True. + """ + + if not self.draw_always and self.busy_clients >= self.clients: + return + + exhausted = not self.pipeline.queue_task() + + if exhausted and self.pipeline.exhausted(): + # finished, so we can stop + self.finished = True + + + def progress_bar(self, max_tasks : int, access_stat : Callable, call_every : Callable = None, check_every : float = 1.0): + """ + This function shows a progress bar via tqdm some given stat. Blocks execution. + Not recommended for interactive use. + + :param max_tasks: The maximum number of tasks to show progress to before returning + :type max_tasks: int + + :param access_stat: Some callable that returns a stat we want to see progress for (i.e. as an integer). 
+ :type access_stat: Callable[, int] + + :param call_every: Some callable that we want to call every time stat is updated. + :type call_every: Callable[, None] + + :param check_every: How often to check for updates to the stat in seconds. + :type check_every: float + """ + + for i in tqdm(range(max_tasks)): + current_stat = access_stat() + while True: + if call_every: call_every() + if current_stat != access_stat(): + break + time.sleep(check_every) diff --git a/cheese/api/__init__.py b/cheese/api/__init__.py index 89a63f4..a868b6c 100644 --- a/cheese/api/__init__.py +++ b/cheese/api/__init__.py @@ -8,97 +8,93 @@ import cheese.utils.msg_constants as msg_constants from cheese.utils.rabbit_utils import rabbitmq_callback +import pickle from b_rabbit import BRabbit from tqdm import tqdm import time # Master object for CHEESE -class CHEESE: +class CHEESEAPI: """ - Main object to use for running tasks with CHEESE + API to access CHEESE master object. Assumes - :param pipeline_cls: Class for pipeline - :type pipeline_cls: Callable[, Pipeline] + :param host: Host for rabbitmq server. Normally just locahost if you are running locally + :type host: str - :param client_cls: Class for client - :type client_cls: Callable[,GradioFront] if gradio + :param port: Port to run rabbitmq server on + :type port: int - :param model_cls: Class for model - :type model_cls: Callable[,BaseModel] - - :param pipeline_kwargs: Additional keyword arguments to pass to pipeline constructor - :type pipeline_kwargs: Dict[str, Any] - - :param gradio: Whether to use gradio or custom frontend - :type gradio: bool - - :param draw_always: If true, doesn't check for free clients before drawing a task. - This is useful if you are trying to feed data directly to model and don't need to worry about having free clients. 
- :type draw_always: bool + :param timeout: Timeout for waiting for main server to respond + :type timeout: float """ - def __init__( - self, - pipeline_cls, client_cls = None, model_cls = None, - pipeline_kwargs : Dict[str, Any] = {}, model_kwargs : Dict[str, Any] = {}, - gradio : bool = True, draw_always : bool = False - ): - - self.gradio = gradio - self.draw_always = draw_always + def __init__(self, host : str = 'localhost', port : int = 5672, timeout : float = 10): + self.timeout = timeout # Initialize rabbit MQ server - self.connection = BRabbit(host='localhost', port=5672) + self.connection = BRabbit(host=host, port=port) - # Channel for client to notify of task completion + # Channel to get results back from main server self.subscriber = self.connection.EventSubscriber( b_rabbit = self.connection, - routing_key = 'main', - publisher_name = 'client', - event_listener = self.client_ping + routing_key = 'api', + publisher_name = 'main', + event_listener = self.main_listener ) self.subscriber.subscribe_on_thread() - - # API components initialized - self.pipeline : Pipeline = pipeline_cls(**pipeline_kwargs) - self.model : BaseModel = model_cls(**model_kwargs) if model_cls is not None else None - self.client_cls = client_cls - if gradio: - self.client_manager = GradioClientManager() - else: - self.client_manager = ClientManager() + # Channel to send commands to main server + self.publisher = self.connection.EventPublisher( + b_rabbit = self.connection, + publisher_name = 'api' + ) - self.pipeline.init_connection(self.connection) - self.client_manager.init_connection(self.connection) - if self.model is not None: self.model.init_connection(self.connection) + # Any received values from main will be placed here + self.buffer : Any = None - self.clients = 0 - self.busy_clients = 0 + # Check if main server is running + self.connected_to_main : bool = False + self.publisher.publish('main', pickle.dumps(msg_constants.READY)) + self.connected_to_main = True - 
self.finished = False # For when pipeline is exhausted + if not self.await_result(): + raise Exception("Main server not running") - def launch(self) -> str: + @rabbitmq_callback + def main_listener(self, msg : str): """ - Launch the frontend and return URL for users to access it. + Callback for main server. Receives messages from main server and places them in buffer. """ - return self.client_manager.init_front(self.client_cls) + if not self.connected_to_main: + print("Warning: RabbitMQ queue non-empty at startup. Consider restarting RabbitMQ server if unexpected errors arise.") + return + msg = pickle.loads(msg) + self.buffer = msg - @rabbitmq_callback - def client_ping(self, msg): + def await_result(self, time_step : float = 0.5): """ - Method for ClientManager to ping the API when it needs more tasks or has taken a task + Assuming buffer is none """ - msg = msg.decode('utf-8') - if msg == msg_constants.SENT: - # Client sent task to pipeline, needs a new one - self.busy_clients -= 1 - self.draw() - elif msg == msg_constants.RECEIVED: - self.busy_clients += 1 - else: - raise Exception("Error: Client pinged master with unknown message") + total_time = 0 + while self.buffer is None: + time.sleep(time_step) + total_time += time_step + if total_time > self.timeout: + print("Warning: Timeout exceeded awaiting API result.") + return None + + res = self.buffer + self.buffer = None + return res + + def launch(self) -> str: + """ + Launch the frontend and return URL for users to access it. + """ + self.publisher.publish('main', pickle.dumps(msg_constants.LAUNCH)) + return self.await_result() + def create_client(self, id : int) -> Tuple[int, int]: """ Create a client instance with given id. 
@@ -108,11 +104,10 @@ def create_client(self, id : int) -> Tuple[int, int]: :return: Username and password user can use to log in to CHEESE """ + msg = f"{msg_constants.ADD}|{id}" + self.publisher.publish('main', pickle.dumps(msg)) - id, pwd = self.client_manager.add_client(id) - self.clients += 1 - self.draw() # pre-emptively draw a task for the client to pick up - return id, pwd + return self.await_result() def remove_client(self, id : int): """ @@ -121,8 +116,10 @@ def remove_client(self, id : int): :param id: A unique identifying number for the client. :type id: int """ - self.client_manager.remove_client(id) - self.clients -= 1 + msg = f"{msg_constants.REMOVE}|{id}" + self.publisher.publish('main', pickle.dumps(msg)) + + return self.await_result() def get_stats(self) -> Dict: """ @@ -136,38 +133,16 @@ def get_stats(self) -> Dict: - model_stats: Dictionary of model statistics - pipeline_stats: Dictionary of pipeline statistics """ - client_stats = self.client_manager.client_statistics + self.publisher.publish('main', pickle.dumps(msg_constants.STATS)) - # Get num_tasks from all clients - num_tasks = 0 - for client in client_stats: - stat : ClientStatistics = client_stats[client] - num_tasks += stat.total_tasks - - return { - 'num_clients' : self.clients, - 'num_busy_clients' : self.busy_clients, - 'num_tasks' : num_tasks, - 'client_stats' : client_stats, - 'model_stats' : self.model.get_stats() if self.model else None, - 'pipeline_stats' : self.pipeline.get_stats() - } + return self.await_result() def draw(self): """ Draws a sample from data pipeline and creates a task to send to clients. Does nothing if no free clients. This check if overriden if draw_always is set to True. 
""" - - if not self.draw_always and self.busy_clients >= self.clients: - return - - exhausted = not self.pipeline.queue_task() - - if exhausted and self.pipeline.exhausted(): - # finished, so we can stop - self.finished = True - + self.publisher.publish('main', pickle.dumps(msg_constants.DRAW)) def progress_bar(self, max_tasks : int, access_stat : Callable, call_every : Callable = None, check_every : float = 1.0): """ diff --git a/cheese/pipeline/datasets.py b/cheese/pipeline/datasets.py index 53d5724..16629a9 100644 --- a/cheese/pipeline/datasets.py +++ b/cheese/pipeline/datasets.py @@ -4,16 +4,49 @@ from datasets import Dataset from cheese.pipeline import Pipeline +from cheese.utils import safe_mkdir + +import pandas as pd class DatasetPipeline(Pipeline): """ Base class for any pipeline thats data destination is a datasets.Dataset object + + :param format: Format to save result dataset to. Defaults to arrow. Can be arrow or csv. + :type format: str + + :param save_every: Save dataset whenever this number of rows is added. + :type save_every: int """ - def __init__(self): + def __init__(self, format : str = "csv", save_every : int = 1): super().__init__() self.write_path : str = None self.res_dataset : Dataset = None + self.format = format + + self.save_every = save_every + self.save_accum = 0 + + def load_dataset(self) -> bool: + """ + Loads the results dataset from a given path. Returns false if load fails. Assumes write_path has been set already. 
+ + :return: Whether load was successful + :rtype: bool + """ + + if self.write_path is None: + raise Exception("Error: Attempted to load results dataset without ever specifiying a path to write it to") + + try: + if self.format == "arrow": + self.res_dataset = Dataset.load_from_disk(self.write_path) + elif self.format == "csv": + self.res_dataset = pd.read_csv(self.write_path) + return True + except: + return False def save_dataset(self): """ @@ -23,9 +56,12 @@ def save_dataset(self): if self.res_dataset is None: return if self.write_path is None: - raise Exception("Error: Attempted to save result dataset without every specifiying a path to write to") + raise Exception("Error: Attempted to save result dataset without ever specifiying a path to write to") - self.res_dataset.save_to_disk(self.write_path) + if self.format == "arrow": + self.res_dataset.save_to_disk(self.write_path) + elif self.format == "csv": + self.res_dataset.to_csv(self.write_path, index = False) def add_row_to_dataset(self, row : Dict[str, Any]): """ @@ -34,10 +70,18 @@ def add_row_to_dataset(self, row : Dict[str, Any]): :param row: The row, as a dictionary, to add to the result dataset :type row: Dict[str, Any] """ + row = {key : [row[key]] for key in row} if self.res_dataset is None: - row = {key : [row[key]] for key in row} - self.res_dataset = Dataset.from_dict(row) + self.res_dataset = Dataset.from_dict(row) if self.format == "arrow" else pd.DataFrame(row) else: - self.res_dataset = self.res_dataset.add_item(row) - self.save_dataset() + if self.format == "arrow": + self.res_dataset = self.res_dataset.append(row) + else: + new_df = pd.DataFrame(row) + self.res_dataset = pd.concat([self.res_dataset, new_df], ignore_index = True) + + self.save_accum += 1 + if self.save_accum >= self.save_every: + self.save_dataset() + self.save_accum = 0 diff --git a/cheese/pipeline/generative.py b/cheese/pipeline/generative.py index 5ef1b6a..b825fed 100644 --- a/cheese/pipeline/generative.py +++ 
b/cheese/pipeline/generative.py @@ -6,6 +6,7 @@ import threading import joblib +import time class GenerativePipeline(WriteOnlyPipeline): """ @@ -120,7 +121,11 @@ def populate_buffer(self): def fetch(self) -> BatchElement: if len(self.buffer) == 0: - raise Exception("Error: Tried to fetch data before any was created. Please wait longer for buffer to fill or increase its capacity.") + # Stall until buffer is ready + print("Warning: Tried to fetch data before any was created. Please wait longer for buffer to fill or increase its capacity. Execution will now stall until buffer is ready.") + while len(self.buffer) == 0: + time.sleep(1) + elem = self.buffer.pop(0) return elem diff --git a/cheese/pipeline/iterable_dataset.py b/cheese/pipeline/iterable_dataset.py index 1e80c71..15aa589 100644 --- a/cheese/pipeline/iterable_dataset.py +++ b/cheese/pipeline/iterable_dataset.py @@ -27,8 +27,8 @@ class IterablePipeline(DatasetPipeline): :param max_length: Maximum number of entries to produce for output dataset. Defaults to infinity. 
""" - def __init__(self, iter : Iterable, write_path : str, force_new : bool = False, max_length = np.inf): - super().__init__() + def __init__(self, iter : Iterable, write_path : str, force_new : bool = False, max_length = np.inf, **kwargs): + super().__init__(**kwargs) self.data_source = iter self.iter_steps = 0 # How many steps through iterator have been taken (counting bad data) @@ -41,7 +41,7 @@ def __init__(self, iter : Iterable, write_path : str, force_new : bool = False, try: assert not force_new - self.res_dataset = load_from_disk(write_path) + assert self.load_dataset() self.iter_steps, self.progress = joblib.load("save_data/progress.joblib") for _ in range(self.iter_steps): next(self.data_source) diff --git a/cheese/pipeline/text_captions.py b/cheese/pipeline/text_captions.py index 9cb51e7..13fc436 100644 --- a/cheese/pipeline/text_captions.py +++ b/cheese/pipeline/text_captions.py @@ -27,7 +27,7 @@ def __init__(self, read_path : str, write_path : str, force_new : bool = False): # Captioned dataset try: assert not force_new - self.res_dataset = load_from_disk(write_path) + assert self.load_dataset() self.finished_items = len(self.res_dataset["text"]) except: # intialize empty dataset as pandas df, then convert to dataset diff --git a/cheese/pipeline/wav_folder.py b/cheese/pipeline/wav_folder.py index 43e0257..3a9fc30 100644 --- a/cheese/pipeline/wav_folder.py +++ b/cheese/pipeline/wav_folder.py @@ -42,7 +42,7 @@ def __init__(self, read_path : str, write_path : str, force_new : bool = False): self.total_items = len(os.listdir(self.read_path)) try: assert not force_new - self.res_dataset = load_from_disk(write_path) + assert self.load_dataset() self.index_book = joblib.load("save_data/index_book.joblib") except: diff --git a/cheese/pipeline/write_only.py b/cheese/pipeline/write_only.py index 25a666a..ebdc15e 100644 --- a/cheese/pipeline/write_only.py +++ b/cheese/pipeline/write_only.py @@ -19,15 +19,16 @@ class WriteOnlyPipeline(DatasetPipeline): :param 
force_new: Whether to force a new dataset to be created, even if one already exists at the write path :type force_new: bool """ - def __init__(self, write_path : str, force_new : bool = False): - super().__init__() + def __init__(self, write_path : str, force_new : bool = False, **kwargs): + super().__init__(**kwargs) self.write_path = write_path self.force_new = force_new try: assert not force_new - self.res_dataset = load_from_disk(write_path) + assert self.load_dataset() + print(f"Succesfully loaded dataset with {len(self.res_dataset)} entries.") except: pass diff --git a/cheese/utils/msg_constants.py b/cheese/utils/msg_constants.py index 202f5db..2c2869d 100644 --- a/cheese/utils/msg_constants.py +++ b/cheese/utils/msg_constants.py @@ -1,3 +1,11 @@ # Constants for basic messages SENT = "0" -RECEIVED = "1" \ No newline at end of file +RECEIVED = "1" + +# Constants for API commands +READY = "ready" +LAUNCH = "launch" +ADD = "add" +REMOVE = "remove" +STATS = "stats" +DRAW = "draw" \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt index bc4bb8b..543364d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,3 +5,4 @@ b_rabbit datasets webdataset joblib +altair \ No newline at end of file diff --git a/docs/source/cheese/api.rst b/docs/source/cheese/api.rst new file mode 100644 index 0000000..327b394 --- /dev/null +++ b/docs/source/cheese/api.rst @@ -0,0 +1,9 @@ +.. _api: + +CHEESE API +**************** + +The API is used to access the server from other scripts/applications. + +.. autoclass:: cheese.api.CHEESEAPI + :members: \ No newline at end of file diff --git a/docs/source/backend/api.rst b/docs/source/cheese/cheese.rst similarity index 79% rename from docs/source/backend/api.rst rename to docs/source/cheese/cheese.rst index 2776269..f007140 100644 --- a/docs/source/backend/api.rst +++ b/docs/source/cheese/cheese.rst @@ -1,10 +1,10 @@ -.. _api: +.. 
_cheese: -CHEESE API +CHEESE *********** Once you have a pipeline, data, and a frontend set up, running CHEESE (and adding or removing clients) is simple! You just launch CHEESE and add as many clients as you wish. -.. autoclass:: cheese.api.CHEESE +.. autoclass:: cheese.CHEESE :members: diff --git a/docs/source/backend/data.rst b/docs/source/cheese/data.rst similarity index 100% rename from docs/source/backend/data.rst rename to docs/source/cheese/data.rst diff --git a/docs/source/backend/gradio_client.rst b/docs/source/cheese/gradio_client.rst similarity index 100% rename from docs/source/backend/gradio_client.rst rename to docs/source/cheese/gradio_client.rst diff --git a/docs/source/backend/model.rst b/docs/source/cheese/model.rst similarity index 100% rename from docs/source/backend/model.rst rename to docs/source/cheese/model.rst diff --git a/docs/source/backend/pipeline.rst b/docs/source/cheese/pipeline.rst similarity index 100% rename from docs/source/backend/pipeline.rst rename to docs/source/cheese/pipeline.rst diff --git a/docs/source/customtask.rst b/docs/source/customtask.rst index 9e58a8a..868200e 100644 --- a/docs/source/customtask.rst +++ b/docs/source/customtask.rst @@ -107,7 +107,7 @@ will run the experiment on two strings and post the results to a folder called s .. code-block:: python - from cheese.api import CHEESE + from cheese import CHEESE import time data = ["The goose went to the store and was very happy", "The goose went to the store and was very sad"] @@ -189,3 +189,41 @@ to construction of the CHEESE object is specifiying a model class. "max_length" : 5 } ) + + # For the API to function. More information below. + cheese.start_listening() + +Finally, what if we wanted to access results from another script or machine? Generally, you can +create the API object without specifying a host address and port. 
However, +if you need to change this, you can simply pass your desired host and port +to both the server constructor and the api constructor. The below example +shows the default values for both. Be sure to call `cheese.start_listening()` on the +server object before constructing the API object, as it will rely on this +to make the initial connection. + +.. code-block:: python + + from cheese.api import CHEESEAPI + api = CHEESEAPI( + timeout = 10, + host = 'localhost', + port = 5672 + ) + + # Can now use API as we'd expect + + # Trying to launch when already launched will cause an error + # So ensure that the server did not call launch beforehand + api.launch() + stats = api.get_stats() + + # If you need the URL after launching, you can access it from stats + url = stats["url"] + + usr, passwd = api.create_client(1) + + while True: + time.sleep(10) + stats = api.get_stats() + if stats["finished"]: + break diff --git a/docs/source/faq.rst b/docs/source/faq.rst index 8669892..d32c1d4 100644 --- a/docs/source/faq.rst +++ b/docs/source/faq.rst @@ -18,4 +18,10 @@ Q: How do I run CHEESE? A: You can refer to :ref:`getting started ` to get CHEESE running. For more info on creating your own tasks in CHEESE you can refer to :ref:`custom tasks in gradio `. +Q: Does the server have to be run from my application? + +A: Nope! There are many use cases in which you may want to run CHEESE separately from your application that you wish to connect to CHEESE. +This is the purpose of :code:`cheese.api`. You can run the server as you would normally, then you can call upon +the :code:`cheese.api.CHEESEAPI` object to use any of the functionality of :code:`cheese.CHEESE`. 
+ diff --git a/docs/source/index.rst b/docs/source/index.rst index b1039cd..92c8d64 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -16,11 +16,12 @@ CHEESE is a Co-adaptive Harness for Effective Evaluation, Steering and Enhanceme :maxdepth: 1 :caption: The Backend - backend/api - backend/data - backend/pipeline - backend/model - backend/gradio_client + cheese/api + cheese/cheese + cheese/data + cheese/pipeline + cheese/model + cheese/gradio_client Indices and tables ================== diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/audio_test.py b/examples/audio_test.py index 3d57633..37c0239 100644 --- a/examples/audio_test.py +++ b/examples/audio_test.py @@ -2,7 +2,7 @@ from cheese.data import BatchElement from cheese.client.gradio_client import GradioFront -from cheese.api import CHEESE +from cheese import CHEESE from datasets import Dataset import gradio as gr diff --git a/examples/docs_example.py b/examples/docs_example.py index 7f7fc6c..ef7cad2 100644 --- a/examples/docs_example.py +++ b/examples/docs_example.py @@ -81,7 +81,7 @@ def present(self, task): data : SentimentElement = task.data return [data.text] # Return list for gradio outputs -from cheese.api import CHEESE +from cheese import CHEESE import time data = ["The goose went to the store and was very happy", "The goose went to the store and was very sad"] diff --git a/examples/image_selection.py b/examples/image_selection.py index 4b4d5e1..8c76701 100644 --- a/examples/image_selection.py +++ b/examples/image_selection.py @@ -3,7 +3,7 @@ from cheese.data import BatchElement from cheese.client.gradio_client import GradioFront -from cheese.api import CHEESE +from cheese import CHEESE from dataclasses import dataclass @@ -57,7 +57,7 @@ def post(self, be : ImageSelectionBatchElement): """ Post takes a finished (labelled) batch element and posts it to result dataset. 
""" - row = {"img1_url" : be.img1_url, "img2_url" : be.img2_url, "selection" : be.select, "time" : be.time} + row = {"img1_url" : be.img1_url, "img2_url" : be.img2_url, "select" : be.select, "time" : be.time} # IterablePipeline.post_row(...) takes a dict and adds it as a row to end of the result dataset # It also saves the result dataset and updates progress (in most cases it should always be called in post) # We check for bad data and avoid it @@ -157,6 +157,7 @@ def present(self, task): "iter" : make_iter(), "write_path" : "./img_dataset_res", "force_new" : True, "max_length" : 5 } ) + print(cheese.launch()) print(cheese.create_client(15)) diff --git a/examples/instruct_hf_pipeline.py b/examples/instruct_hf_pipeline.py new file mode 100644 index 0000000..b002572 --- /dev/null +++ b/examples/instruct_hf_pipeline.py @@ -0,0 +1,197 @@ +""" + This example does an instruct type annotation task in which labellers are given + multiple prompt completions and asked to rank them in order of preference. + It collects and write a dataset of preferences for completions. 
+""" + +from dataclasses import dataclass +from typing import List, Iterable + +from transformers import pipeline +import gradio as gr +from cheese.pipeline.generative import GenerativePipeline +from cheese.models import BaseModel +from cheese.data import BatchElement +from cheese.client.gradio_client import GradioFront +from cheese import CHEESE + +@dataclass +class LMGenerationElement(BatchElement): + query : str = None + completions : List[str] = None + rankings : List[int] = None # Ordering for the completions w.r.t indices + +class LMPipeline(GenerativePipeline): + def __init__(self, n_samples = 5, **kwargs): + super().__init__(**kwargs) + + self.n_samples = n_samples + self.pipe = pipeline(task="text-generation", model = 'gpt2', device=0) + self.pipe.tokenizer.pad_token_id = self.pipe.model.config.eos_token_id + # prevents annoying messages + + + self.init_buffer() + + def generate(self, model_input : Iterable[str]) -> List[LMGenerationElement]: + """ + Generates a batch of elements using the pipeline's iterator. + """ + print("Generate called") + elements = [] + for i in range(self.batch_size): + query = model_input[i] + completions = self.pipe(query, max_length=100, num_return_sequences=self.n_samples) + completions = [completion["generated_text"] for completion in completions] + elements.append(LMGenerationElement(query=query, completions=completions)) + return elements + + def extract_data(self, batch_element : LMGenerationElement) -> dict: + """ + Extracts data from a batch element. 
+ """ + return { + "query" : batch_element.query, + "completions" : batch_element.completions, + "rankings" : batch_element.rankings + } + +def make_iter(length : int = 20): + print("Creating prompt iterator...") + pipe = pipeline(task="text-generation", model = 'gpt2', device=0) + pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id + chunk_size = 16 + meta_prompt = f"As an example, below is a list of {chunk_size + 3} prompts you could feed to a language model:\n"+\ + "\"What is the capital of France?\"\n"+\ + "\"Write a story about geese\"\n"+\ + "\"Tell me a fun fact about rabbits\"\n" + + def extract_prompts(entire_generation : str): + generation = entire_generation[len(meta_prompt):] + prompts = generation.split("\n") + prompts = [prompt[1:-1] for prompt in prompts] # Remove quotes + return prompts[:chunk_size] + + prompt_buffer = [] + + while len(prompt_buffer) < length: + prompts = pipe(meta_prompt, max_length=128, num_return_sequences=chunk_size) + prompts = sum([extract_prompts(prompt["generated_text"]) for prompt in prompts], []) + prompt_buffer += prompts + + del pipe + + return iter(prompt_buffer) + +class LMFront(GradioFront): + def main(self): + pressed = gr.State([]) + with gr.Column(): + gr.Button("On the left you will see a prompt. On the right you will "+ \ + "see various possible completions. 
Select the completions in order of "+ \ + "best to worst", interactive = False, show_label = False) + with gr.Row(): + query = gr.Textbox("Prompt", interactive = False, show_label = False) + with gr.Column(): + gr.Textbox("Completions:", interactive = False, show_label = False) + + completions = [gr.Button("", interactive = True) for _ in range(5)] + + + submit = gr.Button("Submit") + + # When a button is pressed, append index to state, and make button not visible + + def press_button(i, pressed_val): + print("Pressed button", i) + pressed_val.append(i) + + updates = [gr.update(visible = False if j in pressed_val else True) for j in range(5)] + + return [pressed_val] + updates + + def press_btn_1(pressed_val): + return press_button(0, pressed_val) + + def press_btn_2(pressed_val): + return press_button(1, pressed_val) + + def press_btn_3(pressed_val): + return press_button(2, pressed_val) + + def press_btn_4(pressed_val): + return press_button(3, pressed_val) + + def press_btn_5(pressed_val): + return press_button(4, pressed_val) + + completions[0].click( + press_btn_1, + inputs = [pressed], + outputs = [pressed] + completions + ) + + completions[1].click( + press_btn_2, + inputs = [pressed], + outputs = [pressed] + completions + ) + + completions[2].click( + press_btn_3, + inputs = [pressed], + outputs = [pressed] + completions + ) + + completions[3].click( + press_btn_4, + inputs = [pressed], + outputs = [pressed] + completions + ) + + completions[4].click( + press_btn_5, + inputs = [pressed], + outputs = [pressed] + completions + ) + + # When submit is pressed, run response, reset state, and set all buttons to visible + + self.wrap_event(submit.click)( + self.response, inputs = [pressed], outputs = [pressed, query] + completions + ) + + return [pressed, query] + completions + + def receive(self, *inp): + _, task, pressed_vals = inp + task.rankings = pressed_vals + + return task + + def present(self, task): + data : LMGenerationElement = task.data + + updates = 
[gr.update(value = data.completions[i], visible = True) for i in range(5)] + return [[], data.query] + updates + +if __name__ == "__main__": + write_path = "./rankings_dataset" + cheese = CHEESE( + LMPipeline, + LMFront, + pipeline_kwargs = { + "iterator" : make_iter(), + "write_path" : write_path, + "max_length" : 20, + "buffer_size" : 20, + "batch_size" : 20, + "force_new" : True, + "log_progress" : True + }, + gradio = True + ) + + print(cheese.launch()) + + print(cheese.create_client(1)) \ No newline at end of file diff --git a/examples/server/generic_client.py b/examples/server/generic_client.py new file mode 100644 index 0000000..2aefbd1 --- /dev/null +++ b/examples/server/generic_client.py @@ -0,0 +1,14 @@ +from cheese.api import CHEESEAPI +import time + +if __name__ == "__main__": + api = CHEESEAPI(timeout = 10) + print("Connected to server") + print(api.launch()) + time.sleep(1) + print("adding client") + print(api.create_client(1)) + + _ = input("Press enter to view stats") + + print(api.get_stats()["client_stats"]) \ No newline at end of file diff --git a/examples/server/image_selection_server.py b/examples/server/image_selection_server.py new file mode 100644 index 0000000..00afbe4 --- /dev/null +++ b/examples/server/image_selection_server.py @@ -0,0 +1,40 @@ +from atexit import register +from cheese.pipeline.iterable_dataset import IterablePipeline, InvalidDataException +from cheese.data import BatchElement +from cheese.client.gradio_client import GradioFront + +from cheese import CHEESE + +from dataclasses import dataclass + +from PIL import Image +from cheese.utils.img_utils import url2img + +import gradio as gr +import datasets +import time + +""" + In this example task, we present two images from the laion-art dataset to our labellers, + and have them select which of the two they prefer. If an image fails to load for them, + they are given an error button to report that they are not seeing any data. 
+""" + +from examples.image_selection import ( + ImageSelectionBatchElement, + ImageSelectionPipeline, + make_iter, + ImageSelectionFront, +) + +if __name__ == "__main__": + # The pipeline kwargs are inherited from IterablePipeline + cheese = CHEESE( + ImageSelectionPipeline, ImageSelectionFront, + pipeline_kwargs = { + "iter" : make_iter(), "write_path" : "./img_dataset_res", "force_new" : True, "max_length" : 5 + } + ) + + print("Waiting on client...") + cheese.start_listening(verbose = True) \ No newline at end of file diff --git a/examples/stablediffusion_ratings.py b/examples/stablediffusion_ratings.py index 5186dc7..3d7068d 100644 --- a/examples/stablediffusion_ratings.py +++ b/examples/stablediffusion_ratings.py @@ -5,7 +5,7 @@ from cheese.models import BaseModel from cheese.data import BatchElement from cheese.client.gradio_client import GradioFront -from cheese.api import CHEESE +from cheese import CHEESE from PIL import Image diff --git a/examples/test_api.py b/examples/test_api.py index cbc328e..2158565 100644 --- a/examples/test_api.py +++ b/examples/test_api.py @@ -1,6 +1,6 @@ from cheese.pipeline.text_captions import TextCaptionPipeline from cheese.client.text_captions import TextCaptionClient -from cheese.api import CHEESE +from cheese import CHEESE from datasets import load_from_disk import time diff --git a/examples/test_gradio_api.py b/examples/test_gradio_api.py index 8d60286..d50771b 100644 --- a/examples/test_gradio_api.py +++ b/examples/test_gradio_api.py @@ -1,6 +1,6 @@ from cheese.pipeline.text_captions import TextCaptionPipeline from cheese.client.gradio_text_captions import GradioTextCaptionClient -from cheese.api import CHEESE +from cheese import CHEESE from datasets import load_from_disk import time diff --git a/requirements.txt b/requirements.txt index dcf12ad..e9a98b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ markupsafe pandas joblib jinja2 +altair \ No newline at end of file diff --git 
a/tests/checkpoint_test.py b/tests/checkpoint_test.py new file mode 100644 index 0000000..53c5b07 --- /dev/null +++ b/tests/checkpoint_test.py @@ -0,0 +1,97 @@ +""" + This test script ensures that datasets can be saved, recovered and saved to again. +""" + +from examples.server.image_selection_server import * +from cheese import CHEESE +from dataclasses import dataclass +from datasets import Dataset +import pandas as pd + +from b_rabbit import BRabbit + +@dataclass +class ImageSelectionBatchElement(BatchElement): + img1_url : str = None + img2_url : str = None + select : int = 0 # 0 None, -1 Left, 1 Right + time : float = 0 # Time in seconds it took for user to select image + +if __name__ == "__main__": + cheese = CHEESE( + ImageSelectionPipeline, ImageSelectionFront, + pipeline_kwargs = { + "iter" : make_iter(), "write_path" : "./img_dataset_res", "force_new" : True, "max_length" : 5, "format" : "csv" + } + ) + + # Add something + data = [] + + data.append({ + "img1_url" : "a", "img2_url" : "b", + "select" : 1, "time" : 1 + }) + cheese.pipeline.post( + ImageSelectionBatchElement( + **data[0] + ) + ) + + data.append({ + "img1_url" : "c", "img2_url" : "d", + "select" : 1, "time" : 1 + }) + cheese.pipeline.post( + ImageSelectionBatchElement( + **data[1] + ) + ) + + del cheese + + # Check dataset to make sure the data was saved + # Assert each row is what we expect + dataset = pd.read_csv("./img_dataset_res") + for i in range(len(dataset)): + assert dataset.loc[i].to_dict() == data[i] + + cheese = CHEESE( + ImageSelectionPipeline, ImageSelectionFront, + pipeline_kwargs = { + "iter" : make_iter(), "write_path" : "./img_dataset_res", "force_new" : False, "max_length" : 5, "format" : "csv" + } + ) + + data.append({ + "img1_url" : "e", "img2_url" : "f", + "select" : 1, "time" : 1 + }) + cheese.pipeline.post( + ImageSelectionBatchElement( + **data[2] + ) + ) + + data.append({ + "img1_url" : "g", "img2_url" : "h", + "select" : 1, "time" : 1 + }) + cheese.pipeline.post( + 
ImageSelectionBatchElement( + **data[3] + ) + ) + + print("==== 2 ====") + for i in range(len(dataset)): + assert dataset.loc[i].to_dict() == data[i] + + dataset = pd.read_csv("./img_dataset_res") + + print(" ==== 3 ====") + for i in range(len(dataset)): + assert dataset.loc[i].to_dict() == data[i] + + print("All Tests Passed") +