diff --git a/src/cnlpt/cnlp_predict.py b/src/cnlpt/cnlp_predict.py index 9896823..b8f43ac 100644 --- a/src/cnlpt/cnlp_predict.py +++ b/src/cnlpt/cnlp_predict.py @@ -14,10 +14,12 @@ logger = logging.getLogger(__name__) + def simple_softmax(x: list): """Softmax values for 1-D score array""" return np.exp(x) / np.sum(np.exp(x), axis=0) + def restructure_prediction( task_names: List[str], raw_prediction: EvalPrediction, @@ -79,7 +81,9 @@ def structure_labels( else: preds = np.argmax(p.predictions[task_ind], axis=1) if output_prob: - prob_values = np.max([simple_softmax(logits) for logits in p.predictions[task_ind]], axis=1) + prob_values = np.max( + [simple_softmax(logits) for logits in p.predictions[task_ind]], axis=1 + ) # for inference if not hasattr(p, "label_ids") or p.label_ids is None: