diff --git a/src/cnlpt/cnlp_args.py b/src/cnlpt/cnlp_args.py index 4714f0ae..426e385c 100644 --- a/src/cnlpt/cnlp_args.py +++ b/src/cnlpt/cnlp_args.py @@ -42,6 +42,12 @@ class CnlpTrainingArguments(TrainingArguments): "help": "Only optimize the bias parameters of the encoder (and the weights of the classifier heads), as proposed in the BitFit paper by Ben Zaken et al. 2021 (https://arxiv.org/abs/2106.10199)" }, ) + output_prob: Optional[bool] = field( + default=False, + metadata={ + "help": "If selected, probability scores will be added to the output prediction file for test data." + }, + ) cnlpt_models = ["cnn", "lstm", "hier", "cnlpt"] diff --git a/src/cnlpt/cnlp_predict.py b/src/cnlpt/cnlp_predict.py index 51c42d21..048e9daf 100644 --- a/src/cnlpt/cnlp_predict.py +++ b/src/cnlpt/cnlp_predict.py @@ -16,6 +16,7 @@ def write_predictions_for_dataset( dataset_ind: int, output_mode: Dict[str, str], tokenizer: PreTrainedTokenizer, + output_prob: bool = False, ): task_labels = dataset.get_labels() start_ind = end_ind = 0 @@ -29,14 +30,27 @@ def write_predictions_for_dataset( ) predictions = trainer.predict(test_dataset=eval_dataset).predictions for task_ind, task_name in enumerate(dataset.tasks): + if output_prob and output_mode[task_name] != classification: + raise NotImplementedError( + "Writing predictions is not implemented for this output_mode!" + ) + if output_mode[task_name] == classification: - task_predictions = np.argmax(predictions[task_ind], axis=1) - for index, item in enumerate(task_predictions): - item = task_labels[task_name][item] - writer.write( - "Task %d (%s) - Index %d - %s\n" - % (task_ind, task_name, index, item) - ) + task_predictions = predictions[task_ind] + for index, logits in enumerate(task_predictions): + task_prediction_idx = np.argmax(logits, axis=1) + item = task_labels[task_name][task_prediction_idx] + prob_value = logits[task_prediction_idx] + if output_prob: + writer.write( + "Task %d (%s) - Index %d - %s - %.6f\n" + % (task_ind, task_name, index, item, prob_value) + ) + else: + writer.write( + "Task %d (%s) - Index %d - %s\n" + % (task_ind, task_name, index, item) + ) elif output_mode[task_name] == tagging: task_predictions = np.argmax(predictions[task_ind], axis=2) tagging_labels = task_labels[task_name] diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py index 197896c8..1d9a0641 100644 --- a/src/cnlpt/train_system.py +++ b/src/cnlpt/train_system.py @@ -652,6 +652,7 @@ def compute_metrics_fn(p: EvalPrediction): dataset_ind, output_mode, tokenizer, + output_prob=training_args.output_prob, ) return eval_results