Skip to content

Commit

Permalink
Reformat: black
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjininfo committed Aug 19, 2024
1 parent 2af1e2a commit 9376200
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/cnlpt/cnlp_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

logger = logging.getLogger(__name__)


def simple_softmax(x: list):
"""Softmax values for 1-D score array"""
return np.exp(x) / np.sum(np.exp(x), axis=0)


def restructure_prediction(
task_names: List[str],
raw_prediction: EvalPrediction,
Expand Down Expand Up @@ -79,7 +81,9 @@ def structure_labels(
else:
preds = np.argmax(p.predictions[task_ind], axis=1)
if output_prob:
prob_values = np.max([simple_softmax(logits) for logits in p.predictions[task_ind]], axis=1)
prob_values = np.max(
[simple_softmax(logits) for logits in p.predictions[task_ind]], axis=1
)

# for inference
if not hasattr(p, "label_ids") or p.label_ids is None:
Expand Down

0 comments on commit 9376200

Please sign in to comment.