Skip to content

Commit

Permalink
finetune: load template and dataset from model
Browse files Browse the repository at this point in the history
  • Loading branch information
zetavg committed Apr 21, 2023
1 parent 8177d08 commit d2eef14
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions llama_lora/ui/finetune_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ def handle_continue_from_model_change(model_name):

def handle_load_params_from_model(
model_name,
template, load_dataset_from, dataset_from_data_dir,
max_seq_length,
evaluate_data_count,
micro_batch_size,
Expand Down Expand Up @@ -654,6 +655,20 @@ def handle_load_params_from_model(
lora_model_directory_path = os.path.join(
lora_models_directory_path, model_name)

try:
with open(os.path.join(lora_model_directory_path, "info.json"), "r") as f:
info = json.load(f)
if isinstance(info, dict):
model_prompt_template = info.get("prompt_template")
if model_prompt_template:
template = model_prompt_template
model_dataset_name = info.get("dataset_name")
if model_dataset_name and isinstance(model_dataset_name, str) and not model_dataset_name.startswith("N/A"):
load_dataset_from = "Data Dir"
dataset_from_data_dir = model_dataset_name
except FileNotFoundError:
pass

data = {}
possible_files = ["finetune_params.json", "finetune_args.json"]
for file in possible_files:
Expand Down Expand Up @@ -747,6 +762,7 @@ def handle_load_params_from_model(

return (
gr.Markdown.update(value=message, visible=has_message),
template, load_dataset_from, dataset_from_data_dir,
max_seq_length,
evaluate_data_count,
micro_batch_size,
Expand Down Expand Up @@ -1231,9 +1247,9 @@ def finetune_ui():
things_that_might_timeout.append(
load_params_from_model_btn.click(
fn=handle_load_params_from_model,
inputs=[continue_from_model] + finetune_args +
inputs=[continue_from_model] + [template, load_dataset_from, dataset_from_data_dir] + finetune_args +
[lora_target_module_choices, lora_modules_to_save_choices],
outputs=[load_params_from_model_message] + finetune_args +
outputs=[load_params_from_model_message] + [template, load_dataset_from, dataset_from_data_dir] + finetune_args +
[lora_target_module_choices, lora_modules_to_save_choices]
)
)
Expand Down

0 comments on commit d2eef14

Please sign in to comment.