Skip to content

Commit

Permalink
Add option to CNN baseline to return representation in addition to cl…
Browse files Browse the repository at this point in the history
…assifier output.
  • Loading branch information
tmills committed Sep 24, 2024
1 parent b0de26e commit 8dfcba0
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/cnlpt/BaselineModels.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def forward(
input_ids=None,
event_tokens=None,
labels=None,
output_hidden_states=False,
**kwargs,
):
embeddings = self.embed(input_ids)
Expand Down Expand Up @@ -103,8 +104,10 @@ def forward(
loss += self.loss_fn[self.task_names[task_ind]](
task_logits, task_labels.type(torch.LongTensor).to(labels.device)
)

return loss, logits
if output_hidden_states:
return loss, logits, fc_in
else:
return loss, logits


class LstmSentenceClassifier(nn.Module):
Expand Down Expand Up @@ -160,5 +163,4 @@ def forward(
loss += self.loss_fn(
task_logits, task_labels.type(torch.LongTensor).to(labels.device)
)

return loss, logits

0 comments on commit 8dfcba0

Please sign in to comment.