Skip to content

Commit

Permalink
Merge pull request #39 from CarperAI/architext_import_updates
Browse files Browse the repository at this point in the history
Architext import update
  • Loading branch information
shahbuland committed Dec 14, 2022
2 parents 58408a5 + 2ce977b commit 526958c
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions examples/architext/architext_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from cheese.api import CHEESE
from cheese import CHEESE
from cheese.client.gradio_client import GradioFront, InvalidInputException
from cheese.data import BatchElement
from cheese.models import BaseModel
from cheese.pipeline.write_only import WriteOnlyPipeline
from dataclasses import dataclass
from examples.architext.architext_util import prompt_to_layout

from datetime import datetime
import gradio as gr
import json
import random
import pickle
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
Expand All @@ -34,11 +34,19 @@ def fetch(self) -> ArchitextBatchElement:

def post(self, batch_element : ArchitextBatchElement):
data = batch_element
layout_image_path = "./dataset/" + str(datetime.timestamp(datetime.now())) + "_" + data.prompt + ".png"
layout_image = data.result['image']
layout_image.save(layout_image_path)
layout_after_removed = None
if data.result is not None and 'layout_after_removed' in data.result.keys():
layout_after_removed = data.result['layout_after_removed']
self.add_row_to_dataset(
{
"prompt" : data.prompt,
"creativity" : data.creativity,
"result" : pickle.dumps(data.result),
"layout" : data.result['layout'],
"layout_image_path" : layout_image_path,
"layout_after_removed" : layout_after_removed,
"feedback" : data.feedback,
"score" : int(data.score),
"rule" : data.rule,
Expand Down Expand Up @@ -120,7 +128,6 @@ def receive(self, *inp):
result_layout_with_removed.pop(removed_space_name, None)

task.data.result['layout_after_removed'] = json.dumps(result_layout_with_removed)

return task

def present(self, task):
Expand Down Expand Up @@ -185,7 +192,7 @@ def space_names_camel_case_to_title_case(self, space_names):
cheese = CHEESE(
ArchitextPipeline, ArchitextFront, ArchitextModel,
pipeline_kwargs = {
"write_path" : "./architext_dataset_res", "force_new" : True
"write_path" : "./dataset/architext_dataset_res.csv", "force_new" : False
}
)
url = cheese.launch()
Expand Down

0 comments on commit 526958c

Please sign in to comment.