
Commit

Fix to fine-tuning hierarchically pre-trained transformers so they can be loaded later for prediction.
tmills committed Jul 17, 2023
1 parent eff5a39 commit 73f6cdc
Showing 4 changed files with 26 additions and 5 deletions.
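
What the diff below adds up to: the settings a hierarchical model needs at prediction time (classifier layer, task names, label dictionary) are now stored on the Hugging Face config during fine-tuning, so a saved checkpoint can be re-opened later. Below is a minimal sketch of that reload path, mirroring the registration pattern in src/cnlpt/api/cnlp_rest.py; the checkpoint directory, the import location of CnlpConfig / CnlpModelForClassification, and the "cnlpt" model_type string are assumptions and may differ across cnlpt versions.

    # Sketch only: reload a fine-tuned cnlpt checkpoint for prediction.
    # Assumed (not from this commit): the checkpoint directory, the import path
    # for CnlpConfig / CnlpModelForClassification, and the "cnlpt" model_type string.
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    from cnlpt.CnlpModelForClassification import CnlpConfig, CnlpModelForClassification

    AutoConfig.register("cnlpt", CnlpConfig)                     # register the custom config
    AutoModel.register(CnlpConfig, CnlpModelForClassification)   # as cnlp_rest.py does

    checkpoint = "/path/to/fine-tuned/checkpoint"  # hypothetical path
    config = AutoConfig.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, config=config)
    model = CnlpModelForClassification.from_pretrained(checkpoint, config=config)
    model.eval()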
1 change: 1 addition & 0 deletions .gitignore
@@ -156,3 +156,4 @@ cache/
 outputs/
 temp/
 wandb/
+logs/
21 changes: 18 additions & 3 deletions src/cnlpt/HierarchicalTransformer.py
@@ -228,7 +228,19 @@ def __init__(
                 self.config.layer,
                 self.config.hier_head_config["n_layers"]
             ))
-        self.layer = self.config.layer
+
+        if self.config.layer < 0:
+            self.layer = self.config.hier_head_config["n_layers"] + self.config.layer
+            if self.layer < 0:
+                raise ValueError("The layer specified (%d) is negative and larger in magnitude than the actual number of layers (%d)" % (
+                    self.config.layer,
+                    self.config.hier_head_config["n_layers"]
+                ))
+        else:
+            self.layer = self.config.layer
+
+        if self.layer == 0:
+            raise ValueError("The derived classifier layer is 0, which is ambiguous -- there is no usable 0th layer in a hierarchical model. Enter a layer argument that is at least 1 (use one layer) or -1 (use the final layer).")
 
         # This would seem to be redundant with the label list, which maps from tasks to labels,
         # but this version is ordered. This will allow the user to specify an order for any methods
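
The block added above wraps negative layer values around hier_head_config["n_layers"], Python-indexing style, and rejects anything that resolves to 0 or falls outside the stack. A standalone sketch of that normalization, using hypothetical names rather than the module's own code:

    # Illustration only (hypothetical helper mirroring the logic added above).
    def resolve_classifier_layer(layer: int, n_layers: int) -> int:
        if layer > n_layers:
            raise ValueError(f"layer {layer} is larger than n_layers ({n_layers})")
        if layer < 0:
            layer = n_layers + layer  # e.g. layer=-1, n_layers=3 -> 2
            if layer < 0:
                raise ValueError(f"negative layer exceeds n_layers ({n_layers}) in magnitude")
        if layer == 0:
            raise ValueError("layer 0 is ambiguous; use >= 1, or -1 for the final layer")
        return layer

    assert resolve_classifier_layer(-1, 3) == 2
    assert resolve_classifier_layer(1, 3) == 1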
@@ -397,7 +409,7 @@ def forward(
         # extract first Documents as rep. (B, hidden_size)
         doc_rep = chunks_reps[:, 0, :]
 
-        total_loss = 0
+        total_loss = None
         for task_ind, task_name in enumerate(self.tasks):
             if not self.class_weights[task_name] is None:
                 class_weights = torch.FloatTensor(self.class_weights[task_name]).to(self.device)
@@ -412,7 +424,10 @@
             if labels is not None:
                 task_labels = labels[:, task_ind]
                 task_loss = loss_fct(task_logits, task_labels.type(torch.LongTensor).to(labels.device))
-                total_loss += task_loss
+                if total_loss is None:
+                    total_loss = task_loss
+                else:
+                    total_loss += task_loss
 
 
         if self.training:
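
Starting total_loss at None instead of 0 means a forward pass that never sees labels (the prediction path this commit is fixing) can report no loss at all rather than a bare integer 0, and the accumulator only ever becomes a tensor once a real task loss exists. A self-contained sketch of the same pattern, with a hypothetical helper name:

    # Illustration only: accumulate per-task losses, returning None when no
    # labelled task contributed a loss (hypothetical helper, not the model code).
    import torch

    def accumulate(task_losses):
        total_loss = None
        for task_loss in task_losses:
            total_loss = task_loss if total_loss is None else total_loss + task_loss
        return total_loss

    assert accumulate([]) is None
    assert accumulate([torch.tensor(0.5), torch.tensor(1.0)]).item() == 1.5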
1 change: 1 addition & 0 deletions src/cnlpt/api/cnlp_rest.py
@@ -74,6 +74,7 @@ def initialize_cnlpt_model(app, model_name, cuda=True, batch_size=8):
     AutoModel.register(CnlpConfig, CnlpModelForClassification)
 
     config = AutoConfig.from_pretrained(model_name)
+    app.state.config = config
     app.state.tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                         config=config)
     model = CnlpModelForClassification.from_pretrained(model_name, cache_dir=os.getenv('HF_CACHE'), config=config)
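
Keeping the loaded config on app.state lets request handlers read task and label metadata without touching the checkpoint again. A hedged sketch of how a handler might use it; the endpoint and the attribute it returns are illustrative, not part of this commit:

    # Illustration only: a FastAPI handler reading the config stashed by
    # initialize_cnlpt_model() (endpoint name and returned field are hypothetical).
    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.get("/tasks")
    def tasks(request: Request):
        config = request.app.state.config
        return {"finetuning_task": getattr(config, "finetuning_task", None)}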
8 changes: 6 additions & 2 deletions src/cnlpt/train_system.py
@@ -234,7 +234,12 @@ def main(
         config = AutoConfig.from_pretrained(
             encoder_name,
             cache_dir=model_args.cache_dir,
+            layer=model_args.layer
         )
+        config.finetuning_task = data_args.task_name
+        config.relations = relations
+        config.tagger = tagger
+        config.label_dictionary = {}  # this gets filled in later
 
         ## TODO: check if user overwrote parameters in command line that could change behavior of the model and warn
         #if data_args.chunk_len is not None:
@@ -244,8 +249,7 @@
 
         model.remove_task_classifiers()
         for task in data_args.task_name:
-            if task not in config.finetuning_task:
-                model.add_task_classifier(task, dataset.get_labels()[task])
+            model.add_task_classifier(task, dataset.get_labels()[task])
         model.set_class_weights(dataset.class_weights)
 
     else:
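
Assigning finetuning_task, relations, tagger, and label_dictionary to the config (and passing layer at construction) is what makes reloading possible: transformers carries extra config attributes through serialization, so they are present again when the checkpoint is opened for prediction, and the guard on config.finetuning_task above can be dropped so every requested task gets a fresh classifier head. A hedged illustration of that round-trip behaviour; the model name and field values are made up:

    # Illustration only: extra attributes set on a transformers config are
    # carried in its serialized form (model name and values are hypothetical).
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("bert-base-uncased")
    config.finetuning_task = ["polarity"]
    config.label_dictionary = {"polarity": ["-1", "1"]}

    as_dict = config.to_dict()  # custom fields appear alongside the standard ones
    assert as_dict["finetuning_task"] == ["polarity"]
    assert as_dict["label_dictionary"] == {"polarity": ["-1", "1"]}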
