Commit

Merge pull request #164 from Machine-Learning-for-Medical-Language/save_load_fix

save config.json in training; actually load best model before prediction
tmills committed Aug 3, 2023
2 parents 7c74766 + 7b69e58, commit cfa8959
Showing 2 changed files with 34 additions and 2 deletions.
src/cnlpt/cnlp_args.py (15 additions, 1 deletion)
@@ -3,7 +3,8 @@
"""

from typing import Callable, Dict, Optional, List, Union, Any
from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields
from enum import Enum
from transformers import TrainingArguments

@dataclass
@@ -163,3 +164,16 @@ class ModelArguments:
"document-level transformer layers"
}
)
def to_dict(self):
# adapted from transformers.TrainingArguments.to_dict()
# filter out fields that are defined as field(init=False)
d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}

for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
return d
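
The conversions in to_dict() are easy to exercise in isolation. Here is a minimal, runnable sketch of the same pattern; the Color enum and DemoArgs dataclass are illustrative stand-ins, not cnlpt's actual types.

import json
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import List

class Color(Enum):
    RED = "red"
    BLUE = "blue"

@dataclass
class DemoArgs:
    encoder: str = "cnn"
    colors: List[Color] = field(default_factory=lambda: [Color.RED])
    pad_token: str = "<pad>"
    _cache: dict = field(init=False, default_factory=dict)  # init=False, so filtered out

    def to_dict(self):
        # same shape as the method added in the diff above
        d = {f.name: getattr(self, f.name) for f in fields(self) if f.init}
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d

print(json.dumps(DemoArgs().to_dict()))
# {"encoder": "cnn", "colors": ["red"], "pad_token": "<PAD_TOKEN>"}
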
src/cnlpt/train_system.py (19 additions, 1 deletion)

@@ -40,6 +40,8 @@
 from transformers.file_utils import CONFIG_NAME
 from huggingface_hub import hf_hub_url
 
+import sys
+sys.path.append(os.path.join(os.getcwd()))
 from .cnlp_processors import tagging, relex, classification
 from .cnlp_data import ClinicalNlpDataset, DataTrainingArguments
 from .cnlp_metrics import cnlp_compute_metrics
@@ -56,6 +58,7 @@
     Trainer,
     set_seed,
 )
+import json
 
 AutoConfig.register("cnlpt", CnlpConfig)
 
@@ -431,6 +434,9 @@ def compute_metrics_fn(p: EvalPrediction):
     if training_args.do_train:
         trainer.save_model()
         tokenizer.save_pretrained(training_args.output_dir)
+        if model_name == 'cnn' or model_name == 'lstm':
+            with open(os.path.join(training_args.output_dir, 'config.json'), 'w') as f:
+                json.dump(model_args.to_dict(), f)
         for task_ind,task_name in enumerate(metrics):
             with open(output_eval_file, "w") as writer:
                 # logger.info("***** Eval results for task %s *****" % (task_name))
@@ -469,6 +475,9 @@ def compute_metrics_fn(p: EvalPrediction):
     if trainer.is_world_process_zero():
         trainer.save_model()
         tokenizer.save_pretrained(training_args.output_dir)
+        if model_name == 'cnn' or model_name == 'lstm':
+            with open(os.path.join(training_args.output_dir, 'config.json'), 'w') as f:
+                json.dump(model_args.to_dict(), f)
 
     # Evaluation
     eval_results = {}
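
Both save paths now write the serialized ModelArguments to config.json alongside the model weights, presumably because the cnn and lstm baselines are plain PyTorch modules, so trainer.save_model() does not emit a Hugging Face config.json for them. A small sketch of reading the file back at load time; load_model_args is a hypothetical helper, not part of this commit:

import json
import os

def load_model_args(output_dir):
    # read back the dict produced by ModelArguments.to_dict()
    with open(os.path.join(output_dir, 'config.json')) as f:
        return json.load(f)

saved = load_model_args('/path/to/output_dir')  # the --output_dir used in training
print(saved)  # plain dict of the training-time model arguments
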
@@ -493,7 +502,16 @@
             writer.write("%s = %s\n" % (key, value))
 
     # here we probably want separate predictions for each dataset:
-
+    if training_args.load_best_model_at_end:
+        model.load_state_dict(torch.load(join(training_args.output_dir, 'pytorch_model.bin')))  # load best model
+        trainer = Trainer(  # make trainer from best model
+            model=model,
+            args=training_args,
+            train_dataset=dataset.processed_dataset.get('train', None),
+            eval_dataset=dataset.processed_dataset.get('validation', None),
+            compute_metrics=build_compute_metrics_fn(task_names, model, dataset),
+        )
+    # use trainer to predict
     for dataset_ind,dataset_path in enumerate(data_args.data_dir):
         subdir = os.path.split(dataset_path.rstrip('/'))[1]
         output_eval_predictions_file = os.path.join(training_args.output_dir, f'eval_predictions_%s_%d.txt' % (subdir, dataset_ind))
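The final hunk is the "actually load best model" half of the fix: when load_best_model_at_end is set, the checkpoint in output_dir holds the best weights seen during training, so reloading it and rebuilding the Trainer around the reloaded model makes the prediction loop use the best checkpoint rather than whatever the last training step left in memory (the bare join in the diff presumably resolves to os.path.join via an existing import). A condensed sketch of the pattern, assuming model, training_args, and test_dataset already exist:

import os
import torch
from transformers import Trainer

# pull the best checkpoint's weights back into the live model
state = torch.load(os.path.join(training_args.output_dir, 'pytorch_model.bin'),
                   map_location='cpu')
model.load_state_dict(state)

# rebuild the Trainer around the reloaded model, then predict with it
trainer = Trainer(model=model, args=training_args)
predictions = trainer.predict(test_dataset)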
