Skip to content

Commit

Permalink
Slightly more well-formed
Browse files Browse the repository at this point in the history
  • Loading branch information
etgld committed Aug 10, 2023
1 parent 1f3bd87 commit ef65f5f
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/cnlpt/cnlp_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def get_pred_out_string(
instance_tokens = [*filter(None, instance.split())]

return dict_to_str(type2spans, instance_tokens)

pred_span_dictionaries = (
types2spans(pred, torch_label)
for pred, torch_label in zip(resolved_predictions, torch_labels)
Expand Down Expand Up @@ -378,13 +378,13 @@ def get_pred_out_string(
def get_relex_prints(
task_name: str,
relex_labels: List[str],
ground_truths: Union[None, np.ndarray], # List[str],
ground_truths: Union[None, np.ndarray],
task_predictions: np.ndarray,
torch_labels: np.ndarray,
) -> List[str]:
Cell = Tuple[int, int, int]

resolved_predictions = task_predictions # np.argmax(task_predictions, axis=3)
resolved_predictions = task_predictions
none_index = relex_labels.index("None") if "None" in relex_labels else -1

# thought we'd filtered them out but apparently not
Expand Down Expand Up @@ -418,12 +418,18 @@ def normalize_cells(
)

# adding the diagonal back in...
final_reduced_matrix = np.array(
[
np.insert(row, row_idx, none_index, axis=0)
for row_idx, row in enumerate(reduced_matrix)
]
final_reduced_matrix = (
np.array(
[
np.insert(row, row_idx, none_index, axis=0)
for row_idx, row in enumerate(reduced_matrix)
]
)
if len(reduced_matrix) > 0
else np.zeros((1, 1)) + none_index
)

assert final_reduced_matrix.shape[0] == final_reduced_matrix.shape[1]
return invalid_inds, final_reduced_matrix

def find_disagreements(
Expand Down

0 comments on commit ef65f5f

Please sign in to comment.