diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 49f22afb..daa42a6b 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -5,6 +5,7 @@ from easygui import msgbox, ynbox from typing import Optional from .custom_logging import setup_logging +from .sd_modeltype import SDModelType import os import re @@ -1014,6 +1015,13 @@ def set_pretrained_model_name_or_path_input( v_parameterization = gr.Checkbox(visible=True) sdxl = gr.Checkbox(visible=True) + # Auto-detect model type if safetensors file path is given + if pretrained_model_name_or_path.lower().endswith(".safetensors"): + detect = SDModelType(pretrained_model_name_or_path) + v2 = gr.Checkbox(value=detect.Is_SD2(), visible=True) + sdxl = gr.Checkbox(value=detect.Is_SDXL(), visible=True) + #TODO: v_parameterization + # If a refresh method is provided, use it to update the choices for the Dropdown widget if refresh_method is not None: args = dict( diff --git a/kohya_gui/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py index 62b12fd9..f1650e7f 100644 --- a/kohya_gui/extract_lora_gui.py +++ b/kohya_gui/extract_lora_gui.py @@ -12,6 +12,7 @@ ) from .custom_logging import setup_logging +from .sd_modeltype import SDModelType # Set up logging log = setup_logging() @@ -337,6 +338,19 @@ def change_sdxl(sdxl): outputs=[load_tuned_model_to, load_original_model_to], ) + #secondary event on model_tuned for auto-detection of v2/SDXL + def change_modeltype_model_tuned(path): + detect = SDModelType(path) + v2 = gr.Checkbox(value=detect.Is_SD2()) + sdxl = gr.Checkbox(value=detect.Is_SDXL()) + return v2, sdxl + + model_tuned.change( + change_modeltype_model_tuned, + inputs=model_tuned, + outputs=[v2, sdxl] + ) + extract_button = gr.Button("Extract LoRA model") extract_button.click( diff --git a/kohya_gui/merge_lora_gui.py b/kohya_gui/merge_lora_gui.py index a3337c4c..72e63212 100644 --- a/kohya_gui/merge_lora_gui.py +++ b/kohya_gui/merge_lora_gui.py @@ -16,6 +16,7 @@ create_refresh_button, setup_environment ) from .custom_logging import setup_logging +from .sd_modeltype import SDModelType # Set up logging log = setup_logging() @@ -145,6 +146,13 @@ def list_save_to(path): show_progress=False, ) + #secondary event on sd_model for auto-detection of SDXL + sd_model.change( + lambda path: gr.Checkbox(value=SDModelType(path).Is_SDXL()), + inputs=sd_model, + outputs=sdxl_model + ) + with gr.Group(), gr.Row(): lora_a_model = gr.Dropdown( label='LoRA model "A" (path to the LoRA A model)', diff --git a/kohya_gui/sd_modeltype.py b/kohya_gui/sd_modeltype.py new file mode 100755 index 00000000..11891bf8 --- /dev/null +++ b/kohya_gui/sd_modeltype.py @@ -0,0 +1,47 @@ +from os.path import isfile +from safetensors import safe_open +import enum + +# methodology is based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/sd_models.py#L379-L403 + +class ModelType(enum.Enum): + UNKNOWN = 0 + SD1 = 1 + SD2 = 2 + SDXL = 3 + SD3 = 4 + +class SDModelType: + def __init__(self, safetensors_path): + self.model_type = ModelType.UNKNOWN + + if not isfile(safetensors_path): + return + + try: + st = safe_open(filename=safetensors_path, framework="numpy", device="cpu") + def hasKeyPrefix(pfx): + return any(k.startswith(pfx) for k in st.keys()) + + if "model.diffusion_model.x_embedder.proj.weight" in st.keys(): + self.model_type = ModelType.SD3 + elif hasKeyPrefix("conditioner."): + self.model_type = ModelType.SDXL + elif hasKeyPrefix("cond_stage_model.model."): + self.model_type = ModelType.SD2 + elif hasKeyPrefix("model."): + self.model_type = ModelType.SD1 + except: + pass + + def Is_SD1(self): + return self.model_type == ModelType.SD1 + + def Is_SD2(self): + return self.model_type == ModelType.SD2 + + def Is_SDXL(self): + return self.model_type == ModelType.SDXL + + def Is_SD3(self): + return self.model_type == ModelType.SD3