Skip to content

Commit

Permalink
Merge pull request #2690 from b-fission/autodetect-modeltype
Browse files Browse the repository at this point in the history
Auto-detect model type for safetensors files
  • Loading branch information
bmaltais committed Aug 9, 2024
2 parents 93b1c07 + c0966bc commit 630cec8
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
8 changes: 8 additions & 0 deletions kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from easygui import msgbox, ynbox
from typing import Optional
from .custom_logging import setup_logging
from .sd_modeltype import SDModelType

import os
import re
Expand Down Expand Up @@ -1014,6 +1015,13 @@ def set_pretrained_model_name_or_path_input(
v_parameterization = gr.Checkbox(visible=True)
sdxl = gr.Checkbox(visible=True)

# Auto-detect model type if safetensors file path is given
if pretrained_model_name_or_path.lower().endswith(".safetensors"):
detect = SDModelType(pretrained_model_name_or_path)
v2 = gr.Checkbox(value=detect.Is_SD2(), visible=True)
sdxl = gr.Checkbox(value=detect.Is_SDXL(), visible=True)
#TODO: v_parameterization

# If a refresh method is provided, use it to update the choices for the Dropdown widget
if refresh_method is not None:
args = dict(
Expand Down
14 changes: 14 additions & 0 deletions kohya_gui/extract_lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)

from .custom_logging import setup_logging
from .sd_modeltype import SDModelType

# Set up logging
log = setup_logging()
Expand Down Expand Up @@ -337,6 +338,19 @@ def change_sdxl(sdxl):
outputs=[load_tuned_model_to, load_original_model_to],
)

#secondary event on model_tuned for auto-detection of v2/SDXL
def change_modeltype_model_tuned(path):
detect = SDModelType(path)
v2 = gr.Checkbox(value=detect.Is_SD2())
sdxl = gr.Checkbox(value=detect.Is_SDXL())
return v2, sdxl

model_tuned.change(
change_modeltype_model_tuned,
inputs=model_tuned,
outputs=[v2, sdxl]
)

extract_button = gr.Button("Extract LoRA model")

extract_button.click(
Expand Down
8 changes: 8 additions & 0 deletions kohya_gui/merge_lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
create_refresh_button, setup_environment
)
from .custom_logging import setup_logging
from .sd_modeltype import SDModelType

# Set up logging
log = setup_logging()
Expand Down Expand Up @@ -145,6 +146,13 @@ def list_save_to(path):
show_progress=False,
)

#secondary event on sd_model for auto-detection of SDXL
sd_model.change(
lambda path: gr.Checkbox(value=SDModelType(path).Is_SDXL()),
inputs=sd_model,
outputs=sdxl_model
)

with gr.Group(), gr.Row():
lora_a_model = gr.Dropdown(
label='LoRA model "A" (path to the LoRA A model)',
Expand Down
47 changes: 47 additions & 0 deletions kohya_gui/sd_modeltype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from os.path import isfile
from safetensors import safe_open
import enum

# methodology is based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/sd_models.py#L379-L403

class ModelType(enum.Enum):
UNKNOWN = 0
SD1 = 1
SD2 = 2
SDXL = 3
SD3 = 4

class SDModelType:
def __init__(self, safetensors_path):
self.model_type = ModelType.UNKNOWN

if not isfile(safetensors_path):
return

try:
st = safe_open(filename=safetensors_path, framework="numpy", device="cpu")
def hasKeyPrefix(pfx):
return any(k.startswith(pfx) for k in st.keys())

if "model.diffusion_model.x_embedder.proj.weight" in st.keys():
self.model_type = ModelType.SD3
elif hasKeyPrefix("conditioner."):
self.model_type = ModelType.SDXL
elif hasKeyPrefix("cond_stage_model.model."):
self.model_type = ModelType.SD2
elif hasKeyPrefix("model."):
self.model_type = ModelType.SD1
except:
pass

def Is_SD1(self):
return self.model_type == ModelType.SD1

def Is_SD2(self):
return self.model_type == ModelType.SD2

def Is_SDXL(self):
return self.model_type == ModelType.SDXL

def Is_SD3(self):
return self.model_type == ModelType.SD3

0 comments on commit 630cec8

Please sign in to comment.