
Commit 2561b52

Update trainer to ensure type consistency for `train_args` and `lora_config` (#2181)
* update-trainer

Signed-off-by: helenxie-bit <[email protected]>

* fix typo

Signed-off-by: helenxie-bit <[email protected]>

* reformat with black

Signed-off-by: helenxie-bit <[email protected]>

---------

Signed-off-by: helenxie-bit <[email protected]>
helenxie-bit authored Aug 12, 2024
1 parent 725b09e commit 2561b52
Showing 1 changed file with 16 additions and 0 deletions: sdk/python/kubeflow/trainer/hf_llm_training.py
@@ -109,6 +109,13 @@ def load_and_preprocess_data(dataset_dir, transformer_type, tokenizer):
def setup_peft_model(model, lora_config):
    # Set up the PEFT model
    lora_config = LoraConfig(**json.loads(lora_config))
    reference_lora_config = LoraConfig()
    for key, val in lora_config.__dict__.items():
        old_attr = getattr(reference_lora_config, key, None)
        if old_attr is not None:
            val = type(old_attr)(val)
        setattr(lora_config, key, val)

    model.enable_input_require_grads()
    model = get_peft_model(model, lora_config)
    return model
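
The loop above coerces every JSON-decoded value back to the type of the corresponding LoraConfig default, since json.loads can return an int where a float is expected (or vice versa); fields whose default is None are left untouched by the old_attr guard. A minimal, self-contained sketch of the same pattern, using an illustrative Config dataclass and JSON payload rather than the repository's code:

    import json
    from dataclasses import dataclass

    @dataclass
    class Config:
        r: int = 8                 # stands in for a field like LoraConfig.r
        lora_dropout: float = 0.0  # stands in for LoraConfig.lora_dropout

    # JSON decodes 8.0 as float and 0 as int, mismatching the defaults above.
    config = Config(**json.loads('{"r": 8.0, "lora_dropout": 0}'))
    reference = Config()  # defaults carry the intended types

    for key, val in config.__dict__.items():
        old_attr = getattr(reference, key, None)
        if old_attr is not None:
            val = type(old_attr)(val)  # cast to the default's type
        setattr(config, key, val)

    assert isinstance(config.r, int)               # 8.0 -> 8
    assert isinstance(config.lora_dropout, float)  # 0 -> 0.0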
@@ -158,6 +165,15 @@ def parse_arguments():
logger.info("Starting HuggingFace LLM Trainer")
args = parse_arguments()
train_args = TrainingArguments(**json.loads(args.training_parameters))
reference_train_args = transformers.TrainingArguments(
output_dir=train_args.output_dir
)
for key, val in train_args.to_dict().items():
old_attr = getattr(reference_train_args, key, None)
if old_attr is not None:
val = type(old_attr)(val)
setattr(train_args, key, val)

transformer_type = getattr(transformers, args.transformer_type)

logger.info("Setup model and tokenizer")
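This second hunk applies the same coercion to train_args, with one wrinkle: TrainingArguments requires an output_dir, so the reference instance is built from the output_dir already decoded from JSON, and to_dict() is used to enumerate the fields. A hedged sketch of the end-to-end effect, assuming transformers is installed (the learning_rate payload and output path here are illustrative, not from the commit):

    import json
    from transformers import TrainingArguments

    # Suppose the JSON payload carried learning_rate as an int.
    train_args = TrainingArguments(
        **json.loads('{"output_dir": "out", "learning_rate": 1}')
    )
    # Reference defaults, constructed with the one required argument.
    reference = TrainingArguments(output_dir=train_args.output_dir)

    for key, val in train_args.to_dict().items():
        old_attr = getattr(reference, key, None)
        if old_attr is not None:
            val = type(old_attr)(val)  # cast to the default's type
        setattr(train_args, key, val)

    assert isinstance(train_args.learning_rate, float)  # 1 -> 1.0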
