From d2eef147b5e7b811a6404c86233968c1213a44a0 Mon Sep 17 00:00:00 2001 From: zetavg Date: Sat, 22 Apr 2023 04:59:59 +0800 Subject: [PATCH] finetune: load template and dataset from model --- llama_lora/ui/finetune_ui.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/llama_lora/ui/finetune_ui.py b/llama_lora/ui/finetune_ui.py index 54e14c2..3928a32 100644 --- a/llama_lora/ui/finetune_ui.py +++ b/llama_lora/ui/finetune_ui.py @@ -621,6 +621,7 @@ def handle_continue_from_model_change(model_name): def handle_load_params_from_model( model_name, + template, load_dataset_from, dataset_from_data_dir, max_seq_length, evaluate_data_count, micro_batch_size, @@ -654,6 +655,20 @@ def handle_load_params_from_model( lora_model_directory_path = os.path.join( lora_models_directory_path, model_name) + try: + with open(os.path.join(lora_model_directory_path, "info.json"), "r") as f: + info = json.load(f) + if isinstance(info, dict): + model_prompt_template = info.get("prompt_template") + if model_prompt_template: + template = model_prompt_template + model_dataset_name = info.get("dataset_name") + if model_dataset_name and isinstance(model_dataset_name, str) and not model_dataset_name.startswith("N/A"): + load_dataset_from = "Data Dir" + dataset_from_data_dir = model_dataset_name + except FileNotFoundError: + pass + data = {} possible_files = ["finetune_params.json", "finetune_args.json"] for file in possible_files: @@ -747,6 +762,7 @@ def handle_load_params_from_model( return ( gr.Markdown.update(value=message, visible=has_message), + template, load_dataset_from, dataset_from_data_dir, max_seq_length, evaluate_data_count, micro_batch_size, @@ -1231,9 +1247,9 @@ def finetune_ui(): things_that_might_timeout.append( load_params_from_model_btn.click( fn=handle_load_params_from_model, - inputs=[continue_from_model] + finetune_args + + inputs=[continue_from_model] + [template, load_dataset_from, dataset_from_data_dir] + finetune_args + [lora_target_module_choices, lora_modules_to_save_choices], - outputs=[load_params_from_model_message] + finetune_args + + outputs=[load_params_from_model_message] + [template, load_dataset_from, dataset_from_data_dir] + finetune_args + [lora_target_module_choices, lora_modules_to_save_choices] ) )