From 73f6cdc9e58941363f35586305408d180d5dd4d4 Mon Sep 17 00:00:00 2001
From: Tim Miller
Date: Mon, 17 Jul 2023 19:12:56 -0400
Subject: [PATCH] Fix fine-tuning of hierarchically pre-trained transformers
 so they can be loaded later for prediction.

---
 .gitignore                           |  1 +
 src/cnlpt/HierarchicalTransformer.py | 21 ++++++++++++++++++---
 src/cnlpt/api/cnlp_rest.py           |  1 +
 src/cnlpt/train_system.py            |  8 ++++++--
 4 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4f6b83e8..f87f8002 100644
--- a/.gitignore
+++ b/.gitignore
@@ -156,3 +156,4 @@ cache/
 outputs/
 temp/
 wandb/
+logs/
diff --git a/src/cnlpt/HierarchicalTransformer.py b/src/cnlpt/HierarchicalTransformer.py
index d3e1abed..b20f08ed 100644
--- a/src/cnlpt/HierarchicalTransformer.py
+++ b/src/cnlpt/HierarchicalTransformer.py
@@ -228,7 +228,19 @@ def __init__(
                     self.config.layer,
                     self.config.hier_head_config["n_layers"]
                 ))
-        self.layer = self.config.layer
+
+        if self.config.layer < 0:
+            self.layer = self.config.hier_head_config["n_layers"] + self.config.layer
+            if self.layer < 0:
+                raise ValueError("The layer specified (%d) is a negative value whose magnitude is larger than the actual number of layers (%d)" % (
+                    self.config.layer,
+                    self.config.hier_head_config["n_layers"]
+                ))
+        else:
+            self.layer = self.config.layer
+
+        if self.layer == 0:
+            raise ValueError("The classifier layer derived is 0, which is ambiguous -- there is no usable 0th layer in a hierarchical model. Enter a value for the layer argument that is at least 1 (use one layer) or -1 (use the final layer).")
 
         # This would seem to be redundant with the label list, which maps from tasks to labels,
         # but this version is ordered. This will allow the user to specify an order for any methods
@@ -397,7 +409,7 @@ def forward(
         # extract first Documents as rep. (B, hidden_size)
         doc_rep = chunks_reps[:, 0, :]
 
-        total_loss = 0
+        total_loss = None
         for task_ind, task_name in enumerate(self.tasks):
             if not self.class_weights[task_name] is None:
                 class_weights = torch.FloatTensor(self.class_weights[task_name]).to(self.device)
@@ -412,7 +424,10 @@ def forward(
             if labels is not None:
                 task_labels = labels[:, task_ind]
                 task_loss = loss_fct(task_logits, task_labels.type(torch.LongTensor).to(labels.device))
-                total_loss += task_loss
+                if total_loss is None:
+                    total_loss = task_loss
+                else:
+                    total_loss += task_loss
 
         if self.training:
 
diff --git a/src/cnlpt/api/cnlp_rest.py b/src/cnlpt/api/cnlp_rest.py
index 87b7a425..24f3ca26 100644
--- a/src/cnlpt/api/cnlp_rest.py
+++ b/src/cnlpt/api/cnlp_rest.py
@@ -74,6 +74,7 @@ def initialize_cnlpt_model(app, model_name, cuda=True, batch_size=8):
     AutoModel.register(CnlpConfig, CnlpModelForClassification)
 
     config = AutoConfig.from_pretrained(model_name)
+    app.state.config = config
     app.state.tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
     model = CnlpModelForClassification.from_pretrained(model_name, cache_dir=os.getenv('HF_CACHE'), config=config)
 
diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py
index 60ff0f4c..e92fb247 100644
--- a/src/cnlpt/train_system.py
+++ b/src/cnlpt/train_system.py
@@ -234,7 +234,12 @@ def main(
         config = AutoConfig.from_pretrained(
             encoder_name,
             cache_dir=model_args.cache_dir,
+            layer=model_args.layer
         )
+        config.finetuning_task = data_args.task_name
+        config.relations = relations
+        config.tagger = tagger
+        config.label_dictionary = {}  # this gets filled in later
 
         ## TODO: check if user overwrote parameters in command line that could change behavior of the model and warn
         #if data_args.chunk_len is not None:
@@ -244,8 +249,7 @@ def main(
 
             model.remove_task_classifiers()
             for task in data_args.task_name:
-                if task not in config.finetuning_task:
-                    model.add_task_classifier(task, dataset.get_labels()[task])
+                model.add_task_classifier(task, dataset.get_labels()[task])
             model.set_class_weights(dataset.class_weights)
 
         else:
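
Note (not part of the patch): with finetuning_task, relations, tagger, and label_dictionary now written into the saved config at fine-tuning time, a hierarchical checkpoint can later be reloaded for prediction along the same lines as initialize_cnlpt_model() in cnlp_rest.py. The sketch below is a minimal illustration of that loading path; the cnlpt.CnlpModelForClassification import path, the "cnlpt" model-type registration string, and the output directory name are assumptions for illustration, not something this patch establishes.

    import os

    from transformers import AutoConfig, AutoModel, AutoTokenizer

    from cnlpt.CnlpModelForClassification import CnlpConfig, CnlpModelForClassification

    # Register the custom config/model classes with the Auto* factories,
    # mirroring initialize_cnlpt_model() in cnlp_rest.py (model-type string assumed).
    AutoConfig.register("cnlpt", CnlpConfig)
    AutoModel.register(CnlpConfig, CnlpModelForClassification)

    model_dir = "outputs/my_hier_run"  # hypothetical fine-tuning output directory
    config = AutoConfig.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, config=config)
    model = CnlpModelForClassification.from_pretrained(
        model_dir, cache_dir=os.getenv("HF_CACHE"), config=config
    )
    model.eval()  # ready for prediction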