Skip to content

Commit

Permalink
Merge pull request #175 from Machine-Learning-for-Medical-Language/ou…
Browse files Browse the repository at this point in the history
…tput-prediction-probs

Output prediction probs
  • Loading branch information
tmills committed Aug 24, 2023
2 parents 3eadd58 + 73a3fad commit 7cd79ea
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
6 changes: 6 additions & 0 deletions src/cnlpt/cnlp_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class CnlpTrainingArguments(TrainingArguments):
"help": "Only optimize the bias parameters of the encoder (and the weights of the classifier heads), as proposed in the BitFit paper by Ben Zaken et al. 2021 (https://arxiv.org/abs/2106.10199)"
},
)
output_prob: Optional[bool] = field(
default=False,
metadata={
"help": "If selected, probability scores will be added to the output prediction file for test data."
},
)


cnlpt_models = ["cnn", "lstm", "hier", "cnlpt"]
Expand Down
28 changes: 21 additions & 7 deletions src/cnlpt/cnlp_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def write_predictions_for_dataset(
dataset_ind: int,
output_mode: Dict[str, str],
tokenizer: PreTrainedTokenizer,
output_prob: bool = False,
):
task_labels = dataset.get_labels()
start_ind = end_ind = 0
Expand All @@ -29,14 +30,27 @@ def write_predictions_for_dataset(
)
predictions = trainer.predict(test_dataset=eval_dataset).predictions
for task_ind, task_name in enumerate(dataset.tasks):
if output_prob and output_mode[task_name] != classification:
raise NotImplementedError(
"Writing predictions is not implemented for this output_mode!"
)

if output_mode[task_name] == classification:
task_predictions = np.argmax(predictions[task_ind], axis=1)
for index, item in enumerate(task_predictions):
item = task_labels[task_name][item]
writer.write(
"Task %d (%s) - Index %d - %s\n"
% (task_ind, task_name, index, item)
)
task_predictions = predictions[task_ind]
for index, logits in enumerate(task_predictions):
task_prediction_idx = np.argmax(logits, axis=1)
item = task_labels[task_name][task_prediction_idx]
prob_value = logits[task_prediction_idx]
if output_prob:
writer.write(
"Task %d (%s) - Index %d - %s - %.6f\n"
% (task_ind, task_name, index, item, prob_value)
)
else:
writer.write(
"Task %d (%s) - Index %d - %s\n"
% (task_ind, task_name, index, item)
)
elif output_mode[task_name] == tagging:
task_predictions = np.argmax(predictions[task_ind], axis=2)
tagging_labels = task_labels[task_name]
Expand Down
1 change: 1 addition & 0 deletions src/cnlpt/train_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ def compute_metrics_fn(p: EvalPrediction):
dataset_ind,
output_mode,
tokenizer,
output_prob=training_args.output_prob,
)

return eval_results
Expand Down

0 comments on commit 7cd79ea

Please sign in to comment.