From 684cb1924e8be899354aafc64b145e0be5a49b84 Mon Sep 17 00:00:00 2001 From: Emerson Havener Date: Tue, 13 Dec 2022 16:02:37 -0800 Subject: [PATCH 1/4] update CHEESE import to support API separation update --- examples/architext/architext_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/architext/architext_test.py b/examples/architext/architext_test.py index 6c2c27c..9a6ced9 100644 --- a/examples/architext/architext_test.py +++ b/examples/architext/architext_test.py @@ -1,4 +1,4 @@ -from cheese.api import CHEESE +from cheese import CHEESE from cheese.client.gradio_client import GradioFront, InvalidInputException from cheese.data import BatchElement from cheese.models import BaseModel @@ -51,7 +51,7 @@ class ArchitextModel(BaseModel): def __init__(self): super().__init__() - self.model = AutoModelForCausalLM.from_pretrained("architext/gptj-162M") + self.model = AutoModelForCausalLM.from_pretrained("/Users/emerson/gptj-162M") self.tokenizer = AutoTokenizer.from_pretrained("gpt2") self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -119,7 +119,7 @@ def receive(self, *inp): result_layout_with_removed.pop(removed_space_name, None) task.data.result['layout_after_removed'] = json.dumps(result_layout_with_removed) - + # print("task:", task) return task def present(self, task): From 0aa8c86c9111c5dab75f870b654617f2bcae35bc Mon Sep 17 00:00:00 2001 From: Emerson Havener Date: Tue, 13 Dec 2022 16:06:49 -0800 Subject: [PATCH 2/4] update CHEESE import to support API separation update --- examples/architext/architext_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/architext/architext_test.py b/examples/architext/architext_test.py index 9a6ced9..6de9634 100644 --- a/examples/architext/architext_test.py +++ b/examples/architext/architext_test.py @@ -51,7 +51,7 @@ class ArchitextModel(BaseModel): def __init__(self): super().__init__() - self.model = AutoModelForCausalLM.from_pretrained("/Users/emerson/gptj-162M") + self.model = AutoModelForCausalLM.from_pretrained("architext/gptj-162M") self.tokenizer = AutoTokenizer.from_pretrained("gpt2") self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -119,7 +119,6 @@ def receive(self, *inp): result_layout_with_removed.pop(removed_space_name, None) task.data.result['layout_after_removed'] = json.dumps(result_layout_with_removed) - # print("task:", task) return task def present(self, task): From abc36ac9a0e9a17affa185ba7883b3bbd6756470 Mon Sep 17 00:00:00 2001 From: Emerson Havener Date: Tue, 13 Dec 2022 18:53:34 -0800 Subject: [PATCH 3/4] do not save pickle of image to csv --- examples/architext/architext_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/architext/architext_test.py b/examples/architext/architext_test.py index 3acf1a4..2942d6d 100644 --- a/examples/architext/architext_test.py +++ b/examples/architext/architext_test.py @@ -38,7 +38,8 @@ def post(self, batch_element : ArchitextBatchElement): { "prompt" : data.prompt, "creativity" : data.creativity, - "result" : pickle.dumps(data.result), + "layout" : data.result['layout'], + "layout_after_removed" : data.result['layout_after_removed'], "feedback" : data.feedback, "score" : int(data.score), "rule" : data.rule, @@ -184,7 +185,7 @@ def space_names_camel_case_to_title_case(self, space_names): cheese = CHEESE( ArchitextPipeline, ArchitextFront, ArchitextModel, pipeline_kwargs = { - "write_path" : "./architext_dataset_res", "force_new" : True + "write_path" : "./dataset/architext_dataset_res.csv", "force_new" : True } ) url = cheese.launch() From 2ce977b8b777c26d80023ba5255b2ca83c148807 Mon Sep 17 00:00:00 2001 From: Emerson Havener Date: Tue, 13 Dec 2022 19:36:37 -0800 Subject: [PATCH 4/4] do not replace existing records, save images outside of dataset files (in dataset dir, named by prompt and timestamp) --- examples/architext/architext_test.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/architext/architext_test.py b/examples/architext/architext_test.py index 2942d6d..56f8d02 100644 --- a/examples/architext/architext_test.py +++ b/examples/architext/architext_test.py @@ -6,10 +6,10 @@ from dataclasses import dataclass from examples.architext.architext_util import prompt_to_layout +from datetime import datetime import gradio as gr import json import random -import pickle import time import torch from transformers import AutoTokenizer, AutoModelForCausalLM @@ -34,12 +34,19 @@ def fetch(self) -> ArchitextBatchElement: def post(self, batch_element : ArchitextBatchElement): data = batch_element + layout_image_path = "./dataset/" + str(datetime.timestamp(datetime.now())) + "_" + data.prompt + ".png" + layout_image = data.result['image'] + layout_image.save(layout_image_path) + layout_after_removed = None + if data.result is not None and 'layout_after_removed' in data.result.keys(): + layout_after_removed = data.result['layout_after_removed'] self.add_row_to_dataset( { "prompt" : data.prompt, "creativity" : data.creativity, "layout" : data.result['layout'], - "layout_after_removed" : data.result['layout_after_removed'], + "layout_image_path" : layout_image_path, + "layout_after_removed" : layout_after_removed, "feedback" : data.feedback, "score" : int(data.score), "rule" : data.rule, @@ -185,7 +192,7 @@ def space_names_camel_case_to_title_case(self, space_names): cheese = CHEESE( ArchitextPipeline, ArchitextFront, ArchitextModel, pipeline_kwargs = { - "write_path" : "./dataset/architext_dataset_res.csv", "force_new" : True + "write_path" : "./dataset/architext_dataset_res.csv", "force_new" : False } ) url = cheese.launch()