diff --git a/backend/db_repo.py b/backend/db_repo.py
index b6f3876c..782c61e2 100644
--- a/backend/db_repo.py
+++ b/backend/db_repo.py
@@ -917,6 +917,80 @@ def create_timing(self, **kwargs):
 
         return InternalResponse(payload, 'timing created successfully', True)
 
+    def bulk_create_timing(self, timing_data_list):
+        primary_image_id_list = []
+        shot_id_list = []
+        model_id_list = []
+        source_image_id_list = []
+        mask_id_list = []
+        canny_image_id_list = []
+
+        for d in timing_data_list:
+            if 'primary_image_id' in d:
+                primary_image_id_list.append(d['primary_image_id'])
+            if 'shot_id' in d:
+                shot_id_list.append(d['shot_id'])
+            if 'model_id' in d:
+                model_id_list.append(d['model_id'])
+            if 'source_image_id' in d:
+                source_image_id_list.append(d['source_image_id'])
+            if 'mask_id' in d:
+                mask_id_list.append(d['mask_id'])
+            if 'canny_image_id' in d:
+                canny_image_id_list.append(d['canny_image_id'])
+
+        primary_image_id_list = list(set(primary_image_id_list))
+        shot_id_list = list(set(shot_id_list))
+        model_id_list = list(set(model_id_list))
+        source_image_id_list = list(set(source_image_id_list))
+        mask_id_list = list(set(mask_id_list))
+        canny_image_id_list = list(set(canny_image_id_list))
+
+        file_list = InternalFileObject.objects.filter(uuid__in=primary_image_id_list + source_image_id_list + mask_id_list + canny_image_id_list, is_disabled=False).all()
+        model_list = AIModel.objects.filter(uuid__in=model_id_list, is_disabled=False).all()
+        shot_list = Shot.objects.filter(uuid__in=shot_id_list, is_disabled=False).all()
+
+        file_uuid_id_map = {str(file.uuid): file.id for file in file_list}
+        model_uuid_id_map = {str(model.uuid): model.id for model in model_list}
+        shot_uuid_id_map = {str(shot.uuid): shot.id for shot in shot_list}
+
+        res_timing_list = []
+        for data in timing_data_list:
+            kwargs = data
+            if 'primary_image_id' in kwargs and kwargs['primary_image_id'] in file_uuid_id_map:
+                kwargs['primary_image_id'] = file_uuid_id_map[kwargs['primary_image_id']]
+
+            if 'source_image_id' in kwargs and kwargs['source_image_id'] in file_uuid_id_map:
+                kwargs['source_image_id'] = file_uuid_id_map[kwargs['source_image_id']]
+
+            if 'mask_id' in kwargs and kwargs['mask_id'] in file_uuid_id_map:
+                kwargs['mask_id'] = file_uuid_id_map[kwargs['mask_id']]
+
+            if 'canny_image_id' in kwargs and kwargs['canny_image_id'] in file_uuid_id_map:
+                kwargs['canny_image_id'] = file_uuid_id_map[kwargs['canny_image_id']]
+
+            if 'shot_id' in kwargs and kwargs['shot_id'] in shot_uuid_id_map:
+                kwargs['shot_id'] = shot_uuid_id_map[kwargs['shot_id']]
+
+            if 'model_id' in kwargs and kwargs['model_id'] in model_uuid_id_map:
+                kwargs['model_id'] = model_uuid_id_map[kwargs['model_id']]
+
+            timing = Timing(**kwargs)
+            res_timing_list.append(timing)
+
+        with transaction.atomic():
+            for timing in res_timing_list:
+                timing.save()
+
+        payload = {
+            'data': [TimingDto(timing).data for timing in res_timing_list]
+        }
+
+        return InternalResponse(payload, 'timing list created successfully', True)
+
     def remove_existing_timing(self, project_uuid):
         if project_uuid:
             project: Project = Project.objects.filter(uuid=project_uuid, is_disabled=False).first()
@@ -944,6 +1018,84 @@ def add_interpolated_clip(self, uuid, **kwargs):
 
         return InternalResponse({}, 'success', True)
 
+    def update_bulk_timing(self, timing_uuid_list, data_list):
+        timing_list = Timing.objects.filter(uuid__in=timing_uuid_list, is_disabled=False).all()
+        if timing_uuid_list and not timing_list:
+            return InternalResponse({}, 'no timing objs found', False)
+
+        primary_image_id_list = []
+        shot_id_list = []
+        model_id_list = []
+        source_image_id_list = []
+        mask_id_list = []
+        canny_image_id_list = []
+
+        for d in data_list:
+            if 'primary_image_id' in d:
+                primary_image_id_list.append(d['primary_image_id'])
+            if 'shot_id' in d:
+                shot_id_list.append(d['shot_id'])
+            if 'model_id' in d:
+                model_id_list.append(d['model_id'])
+            if 'source_image_id' in d:
+                source_image_id_list.append(d['source_image_id'])
+            if 'mask_id' in d:
+                mask_id_list.append(d['mask_id'])
+            if 'canny_image_id' in d:
+                canny_image_id_list.append(d['canny_image_id'])
+
+        primary_image_id_list = list(set(primary_image_id_list))
+        shot_id_list = list(set(shot_id_list))
+        model_id_list = list(set(model_id_list))
+        source_image_id_list = list(set(source_image_id_list))
+        mask_id_list = list(set(mask_id_list))
+        canny_image_id_list = list(set(canny_image_id_list))
+
+        file_list = InternalFileObject.objects.filter(uuid__in=primary_image_id_list + source_image_id_list + mask_id_list + canny_image_id_list, is_disabled=False).all()
+        model_list = AIModel.objects.filter(uuid__in=model_id_list, is_disabled=False).all()
+        shot_list = Shot.objects.filter(uuid__in=shot_id_list, is_disabled=False).all()
+
+        file_uuid_id_map = {str(file.uuid): file.id for file in file_list}
+        model_uuid_id_map = {str(model.uuid): model.id for model in model_list}
+        shot_uuid_id_map = {str(shot.uuid): shot.id for shot in shot_list}
+
+        # filter(uuid__in=...) does not preserve the order of timing_uuid_list,
+        # so pair each timing with its update data explicitly by uuid
+        timing_uuid_map = {str(timing.uuid): timing for timing in timing_list}
+
+        res_timing_list = []
+        for timing_uuid, update_data in zip(timing_uuid_list, data_list):
+            timing = timing_uuid_map.get(str(timing_uuid))
+            if not timing:
+                continue
+
+            kwargs = update_data
+            if 'primary_image_id' in kwargs and kwargs['primary_image_id'] in file_uuid_id_map:
+                kwargs['primary_image_id'] = file_uuid_id_map[kwargs['primary_image_id']]
+
+            if 'source_image_id' in kwargs and kwargs['source_image_id'] in file_uuid_id_map:
+                kwargs['source_image_id'] = file_uuid_id_map[kwargs['source_image_id']]
+
+            if 'mask_id' in kwargs and kwargs['mask_id'] in file_uuid_id_map:
+                kwargs['mask_id'] = file_uuid_id_map[kwargs['mask_id']]
+
+            if 'canny_image_id' in kwargs and kwargs['canny_image_id'] in file_uuid_id_map:
+                kwargs['canny_image_id'] = file_uuid_id_map[kwargs['canny_image_id']]
+
+            if 'shot_id' in kwargs and kwargs['shot_id'] in shot_uuid_id_map:
+                kwargs['shot_id'] = shot_uuid_id_map[kwargs['shot_id']]
+
+            if 'model_id' in kwargs and kwargs['model_id'] in model_uuid_id_map:
+                kwargs['model_id'] = model_uuid_id_map[kwargs['model_id']]
+
+            for attr, value in kwargs.items():
+                setattr(timing, attr, value)
+
+            res_timing_list.append(timing)
+
+        with transaction.atomic():
+            for timing in res_timing_list:
+                timing.save()
+
+        payload = {
+            'data': [TimingDto(timing).data for timing in res_timing_list]
+        }
+
+        return InternalResponse(payload, 'timing list updated successfully', True)
+
     def update_specific_timing(self, uuid, **kwargs):
         timing = Timing.objects.filter(uuid=uuid, is_disabled=False).first()
         if not timing:
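An aside on the two bulk methods above, not part of this patch: both still persist row-by-row inside `transaction.atomic()`. If these lists grow large, Django's native bulk operations would collapse that into a constant number of queries. A minimal sketch, assuming nothing here relies on `Timing.save()` overrides or signals (both are skipped by the bulk APIs):

```python
from django.db import transaction

with transaction.atomic():
    # One INSERT for the whole list instead of N individual saves.
    Timing.objects.bulk_create(res_timing_list)

with transaction.atomic():
    # For updates, the changed columns must be named explicitly.
    Timing.objects.bulk_update(res_timing_list, fields=['is_disabled', 'primary_image_id'])
```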
@@ -999,15 +1151,7 @@ def update_specific_timing(self, uuid, **kwargs):
                 return InternalResponse({}, 'invalid canny image uuid', False)
 
             kwargs['canny_image_id'] = canny_image.id
-
-        if 'primay_image_id' in kwargs:
-            if kwargs['primay_image_id'] != None:
-                primay_image: InternalFileObject = InternalFileObject.objects.filter(uuid=kwargs['primay_image_id'], is_disabled=False).first()
-                if not primay_image:
-                    return InternalResponse({}, 'invalid primary image uuid', False)
-
-                kwargs['primay_image_id'] = primay_image.id
 
         for attr, value in kwargs.items():
             setattr(timing, attr, value)
diff --git a/readme.md b/readme.md
index 77e5c7e2..78b9ee83 100644
--- a/readme.md
+++ b/readme.md
@@ -1,4 +1,4 @@
-# Welcome to Dough v. 0.8.3 (beta)
+# Welcome to Dough v. 0.9.0 (beta)
 
 **⬇️ Scroll down for Setup Instructions - Currently available on Linux & Windows, hosted version coming soon.**
 
@@ -10,28 +10,27 @@ Below is a brief overview and some examples of outputs:
 
 ### With Dough, you can make guidance frames using Stable Diffusion XL, IP-Adapter, Fooocus Inpainting, and more:
 
-
+
 
 ### You can then assemble these frames into shots that you can granularly edit:
 
-
+
 
 ### And then animate these shots by defining parameters for each frame and selecting guidance videos via Motion LoRAs:
 
-
+
 
 ### As an example, here's a video that's guided with just images on high strength:
 
-
-
+
 
 ### While here's a more complex one, with low strength images driving it alongside a guidance video:
 
-
+
 
 ### And here's a more complex example combining high strength guidance with a guidance video strongly influencing the motion:
 
-
+
 
@@ -53,16 +52,16 @@
 
 3) During setup, open the relevant ports for Dough like below:
 
-
+
 
-
+
 
 4) When you’ve launched the pod, click into Jupyter Notebook:
 
-
+
 
-
+
 
 5) Follow the “Setup for Linux” below and come back here when you’ve gone through them.
 
@@ -70,11 +69,11 @@
 
 6) Once you’ve done that, grab the IP Address for your instance:
 
-
+
 
-
+
 
-
+
 
 Then put these into this form with a : between them like this:
 
@@ -155,3 +154,51 @@ cd Dough
 
 ---
 
 If you're having any issues, please share them in our [Discord](https://discord.com/invite/8Wx9dFu5tP).
+
+# Troubleshooting
+
+<details>
+  <summary>Common problems (click to expand)</summary>
+
+<details>
+  <summary>Issue during installation</summary>
+
+- Make sure you are using python3.10
+- If you are on Windows, make sure the permissions of the Dough folder are not restricted (try granting full access to everyone)
+- Double-check that you are not inside any system-protected folders like system32
+- Install the app in admin mode: open PowerShell as administrator, run `Set-ExecutionPolicy RemoteSigned`, then follow the installation instructions in this readme
+- If all of the above fail, run the following instructions one by one and report which one throws the error:
+  ```bash
+  call dough-env\Scripts\activate.bat
+  python.exe -m pip install --upgrade pip
+  pip install -r requirements.txt
+  pip install websocket
+  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+  pip install -r comfy_runner\requirements.txt
+  pip install -r ComfyUI\requirements.txt
+  ```
+</details>
+
+<details>
+  <summary>Unable to locate credentials</summary>
+
+  Make a copy of ".env.sample" and rename it to ".env"
+</details>
+
+<details>
+  <summary>Issue during runtime</summary>
+
+- If a particular node inside Comfy is throwing an error, delete that node and restart the app
+- Make sure you are using python3.10 and that the virtual environment is activated
+- Try `git pull origin main` to get the latest code
+</details>
+
+<details>
+  <summary>Generations stay in progress for a long time</summary>
+
+- Check the terminal to see whether progress is being made (generations can be very slow, especially upscaling)
+- Cancel stuck generations directly from the sidebar
+- If you don't see any logs in the terminal, make sure no other program is running on port 12345, as Dough uses that port (see the quick check below)
+</details>
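+
+  As a quick check for the port conflict mentioned above (a minimal sketch; any tool that probes TCP ports works just as well):
+
+  ```python
+  import socket
+
+  # connect_ex returns 0 when something is already listening on the port
+  in_use = socket.socket().connect_ex(("127.0.0.1", 12345)) == 0
+  print("port 12345 in use:", in_use)
+  ```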
+
+<details>
+  <summary>Some other error?</summary>
+
+  Drop in our [Discord](https://discord.com/invite/8Wx9dFu5tP).
+</details>
+
+</details>
diff --git a/sample_assets/example_generations/guy-1.png b/sample_assets/example_generations/guy-1.png
deleted file mode 100644
index e9f81cd4..00000000
Binary files a/sample_assets/example_generations/guy-1.png and /dev/null differ
diff --git a/sample_assets/example_generations/guy-2.png b/sample_assets/example_generations/guy-2.png
deleted file mode 100644
index c0872778..00000000
Binary files a/sample_assets/example_generations/guy-2.png and /dev/null differ
diff --git a/sample_assets/example_generations/lady-1.png b/sample_assets/example_generations/lady-1.png
deleted file mode 100644
index 49c8126a..00000000
Binary files a/sample_assets/example_generations/lady-1.png and /dev/null differ
diff --git a/sample_assets/example_generations/world-1.png b/sample_assets/example_generations/world-1.png
deleted file mode 100644
index 4a2d7204..00000000
Binary files a/sample_assets/example_generations/world-1.png and /dev/null differ
diff --git a/sample_assets/example_generations/world-2.png b/sample_assets/example_generations/world-2.png
deleted file mode 100644
index 9902036d..00000000
Binary files a/sample_assets/example_generations/world-2.png and /dev/null differ
diff --git a/sample_assets/example_generations/world-3.png b/sample_assets/example_generations/world-3.png
deleted file mode 100644
index 74715ba9..00000000
Binary files a/sample_assets/example_generations/world-3.png and /dev/null differ
diff --git a/sample_assets/example_generations/world-4.png b/sample_assets/example_generations/world-4.png
deleted file mode 100644
index 702db9bb..00000000
Binary files a/sample_assets/example_generations/world-4.png and /dev/null differ
diff --git a/sample_assets/sample_images/1bd04ac9d2e347c2f84e299a7b8f5bd9a3eca4e097638cfcdbe1bf8c.jpg b/sample_assets/sample_images/1bd04ac9d2e347c2f84e299a7b8f5bd9a3eca4e097638cfcdbe1bf8c.jpg
deleted file mode 100644
index 77baed27..00000000
Binary files a/sample_assets/sample_images/1bd04ac9d2e347c2f84e299a7b8f5bd9a3eca4e097638cfcdbe1bf8c.jpg and /dev/null differ
diff --git a/sample_assets/sample_images/cat_walking.gif b/sample_assets/sample_images/cat_walking.gif
deleted file mode 100644
index baaf7a5a..00000000
Binary files a/sample_assets/sample_images/cat_walking.gif and /dev/null differ
diff --git a/sample_assets/sample_images/complex.gif b/sample_assets/sample_images/complex.gif
deleted file mode 100644
index 33ea1c72..00000000
Binary files a/sample_assets/sample_images/complex.gif and /dev/null differ
diff --git a/sample_assets/sample_images/generation_example.png b/sample_assets/sample_images/generation_example.png
deleted file mode 100644
index 4e4633ab..00000000
Binary files a/sample_assets/sample_images/generation_example.png and /dev/null differ
diff --git a/sample_assets/sample_images/just_images.gif b/sample_assets/sample_images/just_images.gif
deleted file mode 100644
index 79cdf630..00000000
Binary files a/sample_assets/sample_images/just_images.gif and /dev/null differ
diff --git a/sample_assets/sample_images/main.png b/sample_assets/sample_images/main.png
deleted file mode 100644
index c99f6ad7..00000000
Binary files a/sample_assets/sample_images/main.png and /dev/null differ
diff --git a/sample_assets/sample_images/main_example.gif b/sample_assets/sample_images/main_example.gif
deleted file mode 100644
index c9d31476..00000000
Binary files a/sample_assets/sample_images/main_example.gif and /dev/null differ
diff --git a/sample_assets/sample_images/motion.png b/sample_assets/sample_images/motion.png
deleted file mode 100644
index 5c9c98d0..00000000
Binary files a/sample_assets/sample_images/motion.png and /dev/null differ
diff --git a/sample_assets/sample_images/runpod_1.png b/sample_assets/sample_images/runpod_1.png
deleted file mode 100644
index 0f083613..00000000
Binary files a/sample_assets/sample_images/runpod_1.png and /dev/null differ
diff --git a/sample_assets/sample_images/runpod_2.png b/sample_assets/sample_images/runpod_2.png
deleted file mode 100644
index 6edba8e1..00000000
Binary files a/sample_assets/sample_images/runpod_2.png and /dev/null differ
diff --git a/sample_assets/sample_images/runpod_3.png b/sample_assets/sample_images/runpod_3.png
deleted file mode 100644
index 728c9de1..00000000
Binary files a/sample_assets/sample_images/runpod_3.png and /dev/null differ
diff --git a/sample_assets/sample_images/runpod_4.png b/sample_assets/sample_images/runpod_4.png
deleted file mode 100644
index 38671eb9..00000000
Binary files a/sample_assets/sample_images/runpod_4.png and /dev/null differ
diff --git a/sample_assets/sample_images/runpod_5.png b/sample_assets/sample_images/runpod_5.png
deleted file mode 100644
index 7e5f9c48..00000000
Binary files a/sample_assets/sample_images/runpod_5.png and /dev/null differ
diff --git a/sample_assets/sample_images/runpod_6.png b/sample_assets/sample_images/runpod_6.png
deleted file mode 100644
index db3a8e97..00000000
Binary files a/sample_assets/sample_images/runpod_6.png and /dev/null differ
diff --git a/sample_assets/sample_images/runpod_7.png b/sample_assets/sample_images/runpod_7.png
deleted file mode 100644
index e6f216da..00000000
Binary files a/sample_assets/sample_images/runpod_7.png and /dev/null differ
diff --git a/sample_assets/sample_images/shot_example.png b/sample_assets/sample_images/shot_example.png
deleted file mode 100644
index 52ebf09e..00000000
Binary files a/sample_assets/sample_images/shot_example.png and /dev/null differ
diff --git a/sample_assets/sample_images/tweak_settings.gif b/sample_assets/sample_images/tweak_settings.gif
deleted file mode 100644
index 11efcca8..00000000
Binary files a/sample_assets/sample_images/tweak_settings.gif and /dev/null differ
diff --git a/sample_assets/sample_videos/sample.mp4 b/sample_assets/sample_videos/sample.mp4
deleted file mode 100644
index ff006eb8..00000000
Binary files a/sample_assets/sample_videos/sample.mp4 and /dev/null differ
diff --git a/scripts/app_version.txt b/scripts/app_version.txt
index ee94dd83..ac39a106 100644
--- a/scripts/app_version.txt
+++ b/scripts/app_version.txt
@@ -1 +1 @@
-0.8.3
+0.9.0
diff --git a/ui_components/components/animate_shot_page.py b/ui_components/components/animate_shot_page.py
index d24471e5..50488bef 100644
--- a/ui_components/components/animate_shot_page.py
+++ b/ui_components/components/animate_shot_page.py
@@ -1,7 +1,7 @@
 import json
 import time
 import streamlit as st
-from shared.constants import InferenceParamType, InferenceStatus, InferenceType, InternalFileType
+from shared.constants import AnimationStyleType, InferenceParamType, InferenceStatus, InferenceType, InternalFileType
 from ui_components.components.video_rendering_page import sm_video_rendering_page, two_img_realistic_interpolation_page
 from ui_components.models import InternalShotObject
 from ui_components.widgets.frame_selector import frame_selector_widget
@@ -49,6 +49,7 @@ def video_rendering_page(shot_uuid, selected_variant):
         log = data_repo.get_inference_log_from_uuid(selected_variant)
         shot_data = json.loads(log.input_params)
         file_uuid_list = shot_data.get('origin_data', json.dumps({})).get('settings', {}).get('file_uuid_list', [])
+        st.session_state[f"{shot_uuid}_selected_variant_log_uuid"] = None
     else:
         # hackish sol, will fix later
@@ -65,6 +66,34 @@ def video_rendering_page(shot_uuid, selected_variant):
         for timing in shot.timing_list:
             if timing.primary_image and timing.primary_image.location:
                 file_uuid_list.append(timing.primary_image.uuid)
+    else:
+        # updating the shot timing images
+        shot_timing_list = shot.timing_list
+        img_mismatch = False    # flag to check if shot images need to be updated
+        if len(file_uuid_list) == len(shot_timing_list):
+            for file_uuid, timing in zip(file_uuid_list, shot_timing_list):
+                if timing.primary_image and timing.primary_image.uuid != file_uuid:
+                    img_mismatch = True
+                    break
+        else:
+            img_mismatch = True
+
+        if img_mismatch:
+            # deleting all the current timings
+            data_repo.update_bulk_timing([timing.uuid for timing in shot_timing_list], [{'is_disabled': True}] * len(shot_timing_list))
+            # adding new timings
+            new_timing_data = []
+            for idx, file_uuid in enumerate(file_uuid_list):
+                new_timing_data.append(
+                    {
+                        'aux_frame_index': idx,
+                        'shot_id': shot_uuid,
+                        'primary_image_id': file_uuid,
+                        'is_disabled': False
+                    }
+                )
+
+            data_repo.bulk_create_timing(new_timing_data)
 
     img_list = data_repo.get_all_file_list(uuid__in=file_uuid_list, file_type=InternalFileType.IMAGE.value)[0]
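A note on the sync block in animate_shot_page.py above: rather than diffing and patching timings in place, it disables the old rows and recreates the list, which keeps aux_frame_index values dense by construction. The mismatch check reduces to the following sketch (names mirror the diff; the helper itself is illustrative, not part of the patch):

```python
def images_out_of_sync(file_uuid_list, shot_timing_list) -> bool:
    # A length difference always means the shot must be rebuilt.
    if len(file_uuid_list) != len(shot_timing_list):
        return True
    # Otherwise, compare each timing's primary image against the expected uuid.
    return any(
        timing.primary_image and timing.primary_image.uuid != file_uuid
        for file_uuid, timing in zip(file_uuid_list, shot_timing_list)
    )
```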
diff --git a/ui_components/components/explorer_page.py b/ui_components/components/explorer_page.py
index 95baa908..719aa476 100644
--- a/ui_components/components/explorer_page.py
+++ b/ui_components/components/explorer_page.py
@@ -491,7 +491,7 @@ def gallery_image_view(project_uuid, shortlist=False, view=["main"], shot=None,
     else:
         with h1:
             project_setting = data_repo.get_project_setting(project_uuid)
-            page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery")
+            page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True, key="shortlist_gallery")
             open_detailed_view_for_all = False
 
     else:
@@ -620,7 +620,9 @@ def gallery_image_view(project_uuid, shortlist=False, view=["main"], shot=None,
                     log = gallery_image_list[i + j].inference_log    # data_repo.get_inference_log_from_uuid(gallery_image_list[i + j].inference_log.uuid)
                     if log:
                         input_params = json.loads(log.input_params)
-                        prompt = input_params.get('prompt', 'No prompt found')
+                        prompt = input_params.get('prompt', None)
+                        if not prompt:
+                            prompt = input_params.get("query_dict", {}).get("prompt", "Prompt not found")
                         model = json.loads(log.output_details)['model_name'].split('/')[-1]
                         if 'view_inference_details' in view:
                             with st.expander("Prompt Details", expanded=open_detailed_view_for_all):
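The shortlist pagination change above fixes a classic off-by-one: Python's range excludes its stop value, so the last page was never offered as an option. For instance:

```python
total_pages = 3
list(range(1, total_pages))      # [1, 2]: page 3 was unreachable
list(range(1, total_pages + 1))  # [1, 2, 3]: every page selectable
```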
size:", options=frame_sizes, index=current_index, key="frame_size", horizontal=True) + width, height = map(int, frame_size.split('x')) + else: + st.info("This is an experimental feature") + width = st.text_input("Width", value=512) + height = st.text_input("Height", value=512) + try: + width, height = int(width), int(height) + err = False + except Exception as e: + st.error("Please input integer values") + err = True - - img = Image.new('RGB', (width, height), color = (73, 109, 137)) - st.image(img, width=70) - - if st.button("Save"): - data_repo.update_project_setting(project_uuid, width=width) - data_repo.update_project_setting(project_uuid, height=height) - st.experimental_rerun() + if not err: + img = Image.new('RGB', (width, height), color = (73, 109, 137)) + st.image(img, width=70) + + if st.button("Save"): + data_repo.update_project_setting(project_uuid, width=width) + data_repo.update_project_setting(project_uuid, height=height) + st.experimental_rerun() st.write("") st.write("") diff --git a/ui_components/models.py b/ui_components/models.py index 0c051278..fefd7bcf 100644 --- a/ui_components/models.py +++ b/ui_components/models.py @@ -179,7 +179,7 @@ def __init__(self, **kwargs): self.shot_idx = kwargs['shot_idx'] if key_present('shot_idx', kwargs) else 0 self.duration = kwargs['duration'] if key_present('duration', kwargs) else 0 self.meta_data = kwargs['meta_data'] if key_present('meta_data', kwargs) else {} - self.timing_list = [InternalFrameTimingObject(**timing) for timing in kwargs["timing_list"]] \ + self.timing_list = [InternalFrameTimingObject(**timing) for timing in sorted(kwargs["timing_list"], key=lambda x: x['aux_frame_index'])] \ if key_present('timing_list', kwargs) and kwargs["timing_list"] else [] self.interpolated_clip_list = [InternalFileObject(**vid) for vid in kwargs['interpolated_clip_list']] if key_present('interpolated_clip_list', kwargs) \ else [] diff --git a/ui_components/widgets/variant_comparison_grid.py b/ui_components/widgets/variant_comparison_grid.py index 3581f499..b91e23e8 100644 --- a/ui_components/widgets/variant_comparison_grid.py +++ b/ui_components/widgets/variant_comparison_grid.py @@ -62,11 +62,11 @@ def variant_comparison_grid(ele_uuid, stage=CreativeProcessType.MOTION.value): variants: List[InternalFileObject] = shot.interpolated_clip_list timing_list = data_repo.get_timing_list_from_shot(shot.uuid) - if not (f"{shot_uuid}_selected_variant_log_uuid" in st.session_state and st.session_state[f"{shot_uuid}_selected_variant_log_uuid"]): - # if variants and len(variants): - # st.session_state[f"{shot_uuid}_selected_variant_log_uuid"] = variants[-1].inference_log.uuid - # else: - st.session_state[f"{shot_uuid}_selected_variant_log_uuid"] = None + # if not (f"{shot_uuid}_selected_variant_log_uuid" in st.session_state and st.session_state[f"{shot_uuid}_selected_variant_log_uuid"]): + # # if variants and len(variants): + # # st.session_state[f"{shot_uuid}_selected_variant_log_uuid"] = variants[-1].inference_log.uuid + # # else: + # st.session_state[f"{shot_uuid}_selected_variant_log_uuid"] = None else: timing_uuid = ele_uuid timing = data_repo.get_timing_from_uuid(timing_uuid) @@ -201,11 +201,14 @@ def is_upscaled_video(variant: InternalFileObject): def image_variant_details(variant: InternalFileObject): with st.expander("Inference Details", expanded=False): if variant.inference_params and 'query_dict' in variant.inference_params: - query_dict = json.loads(variant.inference_params['query_dict']) + query_dict = 
-            query_dict = json.loads(variant.inference_params['query_dict'])
+            query_dict = json.loads(variant.inference_params['query_dict']) if \
+                isinstance(variant.inference_params['query_dict'], str) else variant.inference_params['query_dict']
             st.markdown(f"Prompt: {query_dict['prompt']}", unsafe_allow_html=True)
             st.markdown(f"Negative Prompt: {query_dict['negative_prompt']}", unsafe_allow_html=True)
-            st.markdown(f"Dimension: {query_dict['width']}x{query_dict['height']}", unsafe_allow_html=True)
-            st.markdown(f"Guidance scale: {query_dict['guidance_scale']}", unsafe_allow_html=True)
+            if 'width' in query_dict:
+                st.markdown(f"Dimension: {query_dict['width']}x{query_dict['height']}", unsafe_allow_html=True)
+            if 'guidance_scale' in query_dict:
+                st.markdown(f"Guidance scale: {query_dict['guidance_scale']}", unsafe_allow_html=True)
             model_name = variant.inference_log.model_name
             st.markdown(f"Model name: {model_name}", unsafe_allow_html=True)
             if model_name in []:
diff --git a/utils/app_update_utils.py b/utils/app_update_utils.py
index 932d45ba..dc8548fd 100644
--- a/utils/app_update_utils.py
+++ b/utils/app_update_utils.py
@@ -7,11 +7,16 @@
 import streamlit as st
 import threading
 import sys
+from git import Repo
 from streamlit_server_state import server_state_lock
 from utils.data_repo.data_repo import DataRepo
 
 update_event = threading.Event()
 
+dough_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+comfy_runner_dir = os.path.join(dough_dir, "comfy_runner")
+comfy_ui_dir = os.path.join(dough_dir, 'ComfyUI')
+
 def check_for_updates():
     if not os.path.exists('banodoco_local.db'):
         return
@@ -32,6 +37,7 @@ def check_for_updates():
         update_thread.join()
         update_event.wait()
         st.session_state['update_in_progress'] = False
+        st.session_state['first_load'] = True
         st.rerun()
 
 def update_app():
@@ -39,53 +45,52 @@ def update_app():
     try:
         update_dough()
         update_comfy_runner()
         update_comfy_ui()
-    except subprocess.CalledProcessError as e:
+    except Exception as e:
         print("Update failed:", str(e))
 
 def update_comfy_runner():
-    if os.path.exists("comfy_runner/"):
-        os.chdir("comfy_runner/")
-        subprocess.run(["git", "stash"], check=True)
-        completed_process = subprocess.run(["git", "pull", "origin", "feature/package"], check=True)
-        if completed_process.returncode == 0:
-            print("Comfy runner updated")
+    if os.path.exists(comfy_runner_dir):
+        os.chdir(comfy_runner_dir)
+        try:
+            update_git_repo(comfy_runner_dir)
+        except Exception as e:
+            print(f"Error occurred: {str(e)}")
+
+        print("Comfy runner updated")
 
     move_to_root()
 
 def update_dough():
     print("Updating the app...")
-    subprocess.run(["git", "stash"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-
-    completed_process = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], check=True, capture_output=True, text=True)
-    current_branch = completed_process.stdout.strip()
-    subprocess.run(["git", "pull", "origin", current_branch], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-
+
+    try:
+        update_git_repo(dough_dir)
+    except Exception as e:
+        print(f"Error occurred: {str(e)}")
+
     if os.path.exists("banodoco_local.db"):
         python_executable = sys.executable
         completed_process = subprocess.run([python_executable, 'manage.py', 'migrate'], capture_output=True, text=True)
         if completed_process.returncode == 0:
             print("Database migration successful")
 
     if os.path.exists(".env"):
         sample_env = dotenv_values('.env.sample')
         env = dotenv_values('.env')
         missing_keys = [key for key in sample_env if key not in env]
-
         for key in missing_keys:
             env[key] = sample_env[key]
-
         with open(".env", 'w') as f:
            for key, value in env.items():
                 f.write(f"{key}={value}\n")
-
         print("env update successful")
-
+
     move_to_root()
 
 def update_comfy_ui():
     global update_event
-    custom_nodes_dir = "ComfyUI/custom_nodes"
+    custom_nodes_dir = os.path.join(comfy_ui_dir, "custom_nodes")
     if os.path.exists(custom_nodes_dir):
-        initial_dir = os.getcwd()
+        initial_dir = dough_dir
         for folder in os.listdir(custom_nodes_dir):
             folder_path = os.path.join(custom_nodes_dir, folder)
             if os.path.isdir(folder_path) and os.path.exists(os.path.join(folder_path, ".git")):
@@ -122,7 +127,13 @@ def update_comfy_ui():
 
     update_event.set()
     move_to_root()
-
+
+def update_git_repo(git_dir):
+    repo = Repo(git_dir)
+    current_branch = repo.active_branch
+    repo.git.stash()
+    repo.remotes.origin.pull(current_branch.name)
+
 def get_local_version():
     file_path = "./scripts/app_version.txt"
     try:
@@ -130,10 +141,23 @@ def get_local_version():
             return file.read().strip()
     except Exception as e:
         return None
+
+def get_current_branch(git_dir):
+    if not is_git_initialized(git_dir):
+        init_git("../", "https://github.com/banodoco/Dough.git")
+
+    try:
+        completed_process = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], check=True, capture_output=True, text=True, cwd=git_dir)
+        current_branch = completed_process.stdout.strip()
+    except Exception as e:
+        print("------ exception occurred: ", str(e))
+        current_branch = "main"
+
+    return current_branch
+
 def get_remote_version():
-    completed_process = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], check=True, capture_output=True, text=True)
-    current_branch = completed_process.stdout.strip()
+    current_branch = get_current_branch(dough_dir)
+
     url = f"https://raw.githubusercontent.com/banodoco/Dough/{current_branch}/scripts/app_version.txt"
     try:
         response = requests.get(url)
@@ -141,6 +165,26 @@ def get_remote_version():
             return response.text.strip()
     except Exception as e:
         return None
+
+def run_command(command):
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
+    output, error = process.communicate()
+    return output.decode('utf-8'), error.decode('utf-8')
+
+def is_git_initialized(repo_folder):
+    os.chdir(repo_folder)
+    output, error = run_command("git rev-parse --is-inside-work-tree")
+    os.chdir(dough_dir)
+    return output.strip() == "true"
+
+def init_git(repo_folder, repo_url):
+    os.chdir(repo_folder)
+    run_command("git init")
+    run_command(f"git remote add origin {repo_url}")
+    run_command("git fetch origin")
+    run_command("git reset --hard origin/main")
+
+    os.chdir("..")
 
 def compare_versions(version1, version2):
     ver1 = [int(x) for x in version1.split(".")]
@@ -164,7 +208,8 @@ def compare_versions(version1, version2):
     return 0
 
 def move_to_root():
-    current_dir = os.getcwd()
-    while os.path.basename(current_dir) != "Dough":
-        current_dir = os.path.dirname(current_dir)
-    os.chdir(current_dir)
\ No newline at end of file
+    os.chdir(dough_dir)
+    # current_dir = os.getcwd()
+    # while os.path.basename(current_dir) != "Dough":
+    #     current_dir = os.path.dirname(current_dir)
+    #     os.chdir(current_dir)
\ No newline at end of file
diff --git a/utils/cache/cache_methods.py b/utils/cache/cache_methods.py
index e828bce1..509bde99 100644
--- a/utils/cache/cache_methods.py
+++ b/utils/cache/cache_methods.py
@@ -668,6 +668,30 @@ def _cache_update_shot(self, *args, **kwargs):
 
     setattr(cls, '_original_update_shot', cls.update_shot)
     setattr(cls, "update_shot", _cache_update_shot)
+
+    def _cache_update_bulk_timing(self, *args, **kwargs):
+        original_func = getattr(cls, '_original_update_bulk_timing')
+        status = original_func(self, *args, **kwargs)
+
+        if status:
+            StCache.delete_all(CacheKey.SHOT.value)
+
+        return status
+
+    setattr(cls, '_original_update_bulk_timing', cls.update_bulk_timing)
+    setattr(cls, "update_bulk_timing", _cache_update_bulk_timing)
+
+    def _cache_bulk_create_timing(self, *args, **kwargs):
+        original_func = getattr(cls, '_original_bulk_create_timing')
+        status = original_func(self, *args, **kwargs)
+
+        if status:
+            StCache.delete_all(CacheKey.SHOT.value)
+
+        return status
+
+    setattr(cls, '_original_bulk_create_timing', cls.bulk_create_timing)
+    setattr(cls, "bulk_create_timing", _cache_bulk_create_timing)
 
     def _cache_delete_shot(self, *args, **kwargs):
         original_func = getattr(cls, '_original_delete_shot')
diff --git a/utils/common_utils.py b/utils/common_utils.py
index 0292f618..7b029d6d 100644
--- a/utils/common_utils.py
+++ b/utils/common_utils.py
@@ -104,8 +104,8 @@ def create_working_assets(project_uuid):
             os.makedirs(directory)
 
     # copying sample assets for new project
-    if new_project:
-        copy_sample_assets(project_uuid)
+    # if new_project:
+    #     copy_sample_assets(project_uuid)
 
 def truncate_decimal(num: float, n: int = 2) -> float:
     return int(num * 10 ** n) / 10 ** n
diff --git a/utils/data_repo/api_repo.py b/utils/data_repo/api_repo.py
index 72121dcd..94638ccc 100644
--- a/utils/data_repo/api_repo.py
+++ b/utils/data_repo/api_repo.py
@@ -385,6 +385,11 @@ def update_specific_timing(self, uuid, **kwargs):
         kwargs['uuid'] = uuid
         res = self.http_put(url=self.TIMING_URL, data=kwargs)
         return InternalResponse(res['payload'], 'success', res['status'])
+
+    # TODO: complete this
+    def update_bulk_timing(self, timing_uuid_list, data_list):
+        res = self.http_put(url=self.TIMING_LIST_URL, data=data_list)
+        return InternalResponse(res['payload'], 'success', res['status'])
 
     def delete_timing_from_uuid(self, uuid):
         res = self.http_delete(self.TIMING_URL, params={'uuid': uuid})
diff --git a/utils/data_repo/data_repo.py b/utils/data_repo/data_repo.py
index 99bc2e95..d42cfda8 100644
--- a/utils/data_repo/data_repo.py
+++ b/utils/data_repo/data_repo.py
@@ -315,6 +315,18 @@ def update_specific_timing(self, uuid, **kwargs):
         res = self.db_repo.update_specific_timing(uuid, **kwargs)
         return res.status
 
+    # NOTE: this method favors speed and therefore bypasses the aux_frame update performed on individual saves;
+    # only use it for updating timings whose relative position is not affected
+    def update_bulk_timing(self, timing_uuid_list, data_list):
+        res = self.db_repo.update_bulk_timing(timing_uuid_list, data_list)
+        return res.status
+
+    # NOTE: this method favors speed and therefore bypasses the aux_frame update performed on individual saves;
+    # only use it for creating timings whose relative position is not affected
+    def bulk_create_timing(self, data_list):
+        res = self.db_repo.bulk_create_timing(data_list)
+        return res.status
+
     def delete_timing_from_uuid(self, uuid):
         res = self.db_repo.delete_timing_from_uuid(uuid)
         return res.status
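To make the NOTE comments above concrete, here is a hedged usage sketch of the new DataRepo methods; the surrounding variables are stand-ins, and the pattern mirrors the animate_shot_page.py caller earlier in this diff:

```python
from utils.data_repo.data_repo import DataRepo

data_repo = DataRepo()

# Safe: disabling rows does not change relative frame order.
timing_uuids = [timing.uuid for timing in shot_timing_list]  # shot_timing_list: stand-in
data_repo.update_bulk_timing(timing_uuids, [{'is_disabled': True}] * len(timing_uuids))

# Safe: freshly created timings carry their own aux_frame_index values.
data_repo.bulk_create_timing([
    {'aux_frame_index': idx, 'shot_id': shot_uuid, 'primary_image_id': file_uuid, 'is_disabled': False}
    for idx, file_uuid in enumerate(file_uuid_list)  # shot_uuid, file_uuid_list: stand-ins
])

# Reordering frames should still go through update_specific_timing, which
# performs the per-save aux_frame_index bookkeeping these methods skip.
```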
diff --git a/utils/ml_processor/comfy_data_transform.py b/utils/ml_processor/comfy_data_transform.py
index e9180099..8858bb58 100644
--- a/utils/ml_processor/comfy_data_transform.py
+++ b/utils/ml_processor/comfy_data_transform.py
@@ -262,6 +262,8 @@ def transform_sdxl_inpainting_workflow(query: MLQueryObject):
     workflow["50"]["inputs"]["width"] = width
     workflow["52"]["inputs"]["height"] = height
     workflow["52"]["inputs"]["width"] = width
+    workflow["59"]["inputs"]["value"] = width
+    workflow["58"]["inputs"]["value"] = height
 
     return json.dumps(workflow), output_node_ids, [], []
diff --git a/utils/ml_processor/comfy_workflows/sdxl_inpainting_workflow_api.json b/utils/ml_processor/comfy_workflows/sdxl_inpainting_workflow_api.json
index fdd76cbc..090e85a6 100644
--- a/utils/ml_processor/comfy_workflows/sdxl_inpainting_workflow_api.json
+++ b/utils/ml_processor/comfy_workflows/sdxl_inpainting_workflow_api.json
@@ -94,37 +94,43 @@
     "title": "UNETLoader"
     }
   },
-  "33": {
+  "58": {
     "inputs": {
-      "image": [
-        "20",
-        0
-      ]
+      "value": 512
+    },
+    "class_type": "JWInteger",
+    "_meta": {
+      "title": "Height"
+    }
+  },
+  "59": {
+    "inputs": {
+      "value": 512
     },
-    "class_type": "Get image size",
+    "class_type": "JWInteger",
     "_meta": {
-      "title": "Get image size"
+      "title": "Width"
     }
   },
   "34": {
     "inputs": {
       "width": [
-        "33",
+        "59",
         0
       ],
       "height": [
-        "33",
-        1
+        "58",
+        0
       ],
       "crop_w": 0,
       "crop_h": 0,
       "target_width": [
-        "33",
+        "59",
         0
       ],
       "target_height": [
-        "33",
-        1
+        "58",
+        0
       ],
       "text_g": "man fishing, ZipRealism, Zip2D",
       "text_l": "man fishing, ZipRealism, Zip2D",
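For context on the workflow rewiring above: node 33 ("Get image size") derived dimensions from the input image, while the new JWInteger nodes 58 ("Height") and 59 ("Width") hold directly settable values that node 34 reads over links. A reduced sketch of how transform_sdxl_inpainting_workflow is expected to drive them, assuming the node shapes shown in this diff:

```python
import json

# Trimmed stand-in for the workflow JSON carried by this diff.
workflow = {
    "58": {"inputs": {"value": 512}, "class_type": "JWInteger", "_meta": {"title": "Height"}},
    "59": {"inputs": {"value": 512}, "class_type": "JWInteger", "_meta": {"title": "Width"}},
    "34": {"inputs": {"width": ["59", 0], "height": ["58", 0]}},
}

def set_dimensions(workflow: dict, width: int, height: int) -> str:
    # JWInteger nodes expose a single "value" input; node 34 resolves the links.
    workflow["59"]["inputs"]["value"] = width
    workflow["58"]["inputs"]["value"] = height
    return json.dumps(workflow)

print(set_dimensions(workflow, 768, 1024))
```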