diff --git a/src/cnlpt/BaselineModels.py b/src/cnlpt/BaselineModels.py index 04e6ede..6e9a88e 100644 --- a/src/cnlpt/BaselineModels.py +++ b/src/cnlpt/BaselineModels.py @@ -67,6 +67,7 @@ def forward( input_ids=None, event_tokens=None, labels=None, + output_hidden_states=False, **kwargs, ): embeddings = self.embed(input_ids) @@ -103,8 +104,10 @@ def forward( loss += self.loss_fn[self.task_names[task_ind]]( task_logits, task_labels.type(torch.LongTensor).to(labels.device) ) - - return loss, logits + if output_hidden_states: + return loss, logits, fc_in + else: + return loss, logits class LstmSentenceClassifier(nn.Module): @@ -160,5 +163,4 @@ def forward( loss += self.loss_fn( task_logits, task_labels.type(torch.LongTensor).to(labels.device) ) - return loss, logits