From b4f9a3536781cd8d5eaec312bf74078e01cffcbf Mon Sep 17 00:00:00 2001
From: Tim Miller
Date: Tue, 24 Sep 2024 16:50:20 -0400
Subject: [PATCH] Fix cnn_rest and associated files to use config and use from_pretrained.

---
 src/cnlpt/BaselineModels.py |  5 +++--
 src/cnlpt/api/cnn_rest.py   | 15 ++++++---------
 src/cnlpt/train_system.py   | 30 ++++++++++++++++++++++++------
 3 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/src/cnlpt/BaselineModels.py b/src/cnlpt/BaselineModels.py
index 6e9a88e..a490793 100644
--- a/src/cnlpt/BaselineModels.py
+++ b/src/cnlpt/BaselineModels.py
@@ -2,10 +2,11 @@
 
 import torch
 import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
 from torch import nn
 
 
-class CnnSentenceClassifier(nn.Module):
+class CnnSentenceClassifier(nn.Module, PyTorchModelHubMixin):
     def __init__(
         self,
         vocab_size,
@@ -110,7 +111,7 @@ def forward(
         return loss, logits
 
 
-class LstmSentenceClassifier(nn.Module):
+class LstmSentenceClassifier(nn.Module, PyTorchModelHubMixin):
     def __init__(
         self,
         vocab_size,
diff --git a/src/cnlpt/api/cnn_rest.py b/src/cnlpt/api/cnn_rest.py
index 34e4548..0a1f6e5 100644
--- a/src/cnlpt/api/cnn_rest.py
+++ b/src/cnlpt/api/cnn_rest.py
@@ -38,6 +38,8 @@
 logger = logging.getLogger("CNN_REST_Processor")
 logger.setLevel(logging.DEBUG)
 
+max_seq_length = 128
+
 
 @app.on_event("startup")
 async def startup_event():
@@ -49,15 +51,12 @@ async def startup_event():
     num_labels_dict = {
         task: len(values) for task, values in conf_dict["label_dictionary"].items()
     }
-    model = CnnSentenceClassifier(
-        len(tokenizer),
+    model = CnnSentenceClassifier.from_pretrained(
+        model_name,
+        vocab_size=len(tokenizer),
         task_names=conf_dict["task_names"],
         num_labels_dict=num_labels_dict,
-        embed_dims=conf_dict["cnn_embed_dim"],
-        num_filters=conf_dict["num_filters"],
-        filters=conf_dict["filters"],
     )
-    model.load_state_dict(torch.load(join(model_name, "pytorch_model.bin")))
 
     app.state.model = model.to("cuda")
     app.state.tokenizer = tokenizer
@@ -67,9 +66,7 @@ async def startup_event():
 @app.post("/cnn/classify")
 async def process(doc: UnannotatedDocument):
     instances = [doc.doc_text]
-    dataset = get_dataset(
-        instances, app.state.tokenizer, max_length=app.state.conf_dict["max_seq_length"]
-    )
+    dataset = get_dataset(instances, app.state.tokenizer, max_length=max_seq_length)
     _, logits = app.state.model.forward(
         input_ids=torch.LongTensor(dataset["input_ids"]).to("cuda"),
         attention_mask=torch.LongTensor(dataset["attention_mask"]).to("cuda"),
diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py
index f484758..2c12a44 100644
--- a/src/cnlpt/train_system.py
+++ b/src/cnlpt/train_system.py
@@ -524,6 +524,7 @@ def main(
             bias_fit=training_args.bias_fit,
         )
 
+    model_type = type(model)
     output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
     if training_args.do_train:
         # TODO: This assumes that if there are multiple training sets, they all have the same length, but
@@ -660,7 +661,12 @@ def compute_metrics_fn(p: EvalPrediction):
                         ),
                         "w",
                     ) as f:
-                        json.dump(model_args.to_dict(), f)
+                        config_dict = model_args.to_dict()
+                        config_dict[
+                            "label_dictionary"
+                        ] = dataset.get_labels()
+                        config_dict["task_names"] = task_names
+                        json.dump(config_dict, f)
                 for task_ind, task_name in enumerate(metrics):
                     with open(output_eval_file, "w") as writer:
                         # logger.info("***** Eval results for task %s *****" % (task_name))
@@ -712,10 +718,13 @@ def compute_metrics_fn(p: EvalPrediction):
             trainer.save_model()
             tokenizer.save_pretrained(training_args.output_dir)
             if model_name == "cnn" or model_name == "lstm":
+                config_dict = model_args.to_dict()
+                config_dict["label_dictionary"] = dataset.get_labels()
+                config_dict["task_names"] = task_names
                 with open(
                     os.path.join(training_args.output_dir, "config.json"), "w"
                 ) as f:
-                    json.dump(model_args, f)
+                    json.dump(config_dict, f)
 
     # Evaluation
     eval_results = {}
@@ -751,10 +760,19 @@ def compute_metrics_fn(p: EvalPrediction):
                     writer.write(f"{key} : {value} \n")
         # here we probably want separate predictions for each dataset:
         if training_args.load_best_model_at_end:
-            model.load_state_dict(
-                torch.load(join(training_args.output_dir, "pytorch_model.bin"))
-            )  # load best model
-            trainer = Trainer(  # maake trainer from best model
+            model_path = training_args.output_dir
+            if model_name == "cnn" or model_name == "lstm":
+                # non-HF models need manually passed config args
+                model = model_type.from_pretrained(
+                    model_path,
+                    vocab_size=len(tokenizer),
+                    task_names=task_names,
+                    num_labels_dict=num_labels,
+                )
+            else:
+                model = model_type.from_pretrained(model_path)
+
+            trainer = Trainer(  # make trainer from best model
                 model=model,
                 args=training_args,
                 train_dataset=dataset.processed_dataset.get("train", None),
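
Usage sketch (not part of the patch): a minimal round trip through the
PyTorchModelHubMixin methods the patch relies on. The constructor keyword
arguments (vocab_size, task_names, num_labels_dict) mirror the patch; the
concrete values and the save directory are illustrative placeholders, and the
remaining CNN hyperparameters are assumed to fall back to their defaults.

    from cnlpt.BaselineModels import CnnSentenceClassifier

    # Build and save the baseline model; the mixin adds
    # save_pretrained()/from_pretrained() to the plain nn.Module.
    model = CnnSentenceClassifier(
        vocab_size=30522,                 # e.g. len(tokenizer)
        task_names=["negation"],          # illustrative task
        num_labels_dict={"negation": 2},
    )
    model.save_pretrained("/tmp/cnn_model")  # writes the model weights to the directory

    # Reload as cnn_rest.py's startup_event() does: pass the same init kwargs,
    # which from_pretrained() forwards to __init__ before loading the weights.
    model = CnnSentenceClassifier.from_pretrained(
        "/tmp/cnn_model",
        vocab_size=30522,
        task_names=["negation"],
        num_labels_dict={"negation": 2},
    )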