diff --git a/src/cnlpt/cnlp_args.py b/src/cnlpt/cnlp_args.py
index bbb84971..5fbadba4 100644
--- a/src/cnlpt/cnlp_args.py
+++ b/src/cnlpt/cnlp_args.py
@@ -26,6 +26,12 @@ class CnlpTrainingArguments(TrainingArguments):
     bias_fit: bool = field(
         default=False, metadata={"help": "Only optimize the bias parameters of the encoder (and the weights of the classifier heads), as proposed in the BitFit paper by Ben Zaken et al. 2021 (https://arxiv.org/abs/2106.10199)"}
     )
+    model_selection_score: str = field(
+        default=None, metadata={"help": "Score to use in evaluation. Should be one of acc, f1, acc_and_f1, recall, or precision."}
+    )
+    model_selection_label: Union[int, str, List[int], List[str]] = field(
+        default=None, metadata={"help": "Class whose score should be used in evaluation. Should be an integer if scores are indexed, or a string if they are labeled by name."}
+    )
 
 
 cnlpt_models = ['cnn', 'lstm', 'hier', 'cnlpt']
diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py
index bfa09762..3f91acc5 100644
--- a/src/cnlpt/train_system.py
+++ b/src/cnlpt/train_system.py
@@ -421,7 +421,33 @@ def compute_metrics_fn(p: EvalPrediction):
                     dataset.output_modes[task_name],
                     dataset.tasks_to_labels[task_name])
                 # FIXME - Defaulting to accuracy for model selection score, when it should be task-specific
-                task_scores.append( metrics[task_name].get('one_score', np.mean(metrics[task_name].get('f1'))))
+                if training_args.model_selection_score is not None:
+                    score = metrics[task_name].get('one_score', metrics[task_name].get(training_args.model_selection_score))
+                    if isinstance(training_args.model_selection_label, int):
+                        task_scores.append(score[training_args.model_selection_label])
+                    # we can only get the scores in list form,
+                    # so we have to maneuver a bit to get the score
+                    # if the label is provided in string form
+                    elif isinstance(training_args.model_selection_label, str):
+                        index = dataset.tasks_to_labels[task_name].index(training_args.model_selection_label)
+                        task_scores.append(score[index])
+                    elif isinstance(training_args.model_selection_label, (list, tuple)):
+                        scores = []
+                        for label in training_args.model_selection_label:
+                            if isinstance(label, int):
+                                scores.append(score[label])
+                            elif isinstance(label, str):
+                                index = dataset.tasks_to_labels[task_name].index(label)
+                                scores.append(score[index])
+                            else:
+                                raise RuntimeError(f"Unrecognized label type: {type(label)}")
+                        task_scores.append(np.mean(scores))
+                    elif training_args.model_selection_label is None:
+                        task_scores.append(metrics[task_name].get('one_score', np.mean(score)))
+                    else:
+                        raise RuntimeError(f"Unrecognized label type: {type(training_args.model_selection_label)}")
+                else:  # same default as in 0.6.0
+                    task_scores.append(metrics[task_name].get('one_score', np.mean(metrics[task_name].get('f1'))))
                 #task_scores.append(processor.get_one_score(metrics.get(task_name, metrics.get(task_name.split('-')[0], None))))
 
             one_score = sum(task_scores) / len(task_scores)
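
For context, a minimal usage sketch (not part of the diff) of how the new selection arguments might be set, assuming CnlpTrainingArguments is imported from cnlpt.cnlp_args; the output directory and label names below are hypothetical placeholders:

    # Hypothetical example: select checkpoints by the mean F1 of two named classes.
    # The label strings must match entries in dataset.tasks_to_labels[task_name].
    from cnlpt.cnlp_args import CnlpTrainingArguments

    training_args = CnlpTrainingArguments(
        output_dir="/tmp/cnlpt_example",            # placeholder path
        model_selection_score="f1",                 # one of: acc, f1, acc_and_f1, recall, precision
        model_selection_label=["CURRENT", "PAST"],  # or a single int/str, or a list of ints
    )

With a list of labels, the patched compute_metrics_fn averages the selected per-class scores with np.mean; with model_selection_label left as None, it falls back to averaging the full score list for model_selection_score.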