Skip to content

Commit

Permalink
Report sequences with uni-valued labels when training
Browse files: browse the repository at this point in the history
  • Loading branch information
althonos committed Feb 27, 2024
1 parent 422fe7d commit f84dda5
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions gecco/crf/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -355,6 +355,11 @@ def fit(
# extract features and labels
feats: List[Dict[str, bool]] = extract_features(sequence)
labels: List[str] = extract_labels(sequence)
if all(label == "0" for label in labels):
raise ValueError(f"only negative labels found in sequence {sequence[0].source.id!r}")
elif all(label == "1" for label in labels):
raise ValueError(f"only positive labels found in sequence {sequence[0].source.id!r}")

# check we have as many observations as we have labels
if len(feats) != len(labels):
raise ValueError("different number of features and labels found, something is wrong")
Expand All @@ -366,12 +371,6 @@ def fit(
training_features.append(feats[win])
training_labels.append(labels[win])

# check labels
if all(label == "1" for y in training_labels for label in y):
raise ValueError("only positives labels found, something is wrong.")
elif all(label == "0" for y in training_labels for label in y):
raise ValueError("only negative labels found, something is wrong.")

# fit the model
self.model = model = sklearn_crfsuite.CRF(**self._options)
model.fit(training_features, training_labels)
Expand Down

0 comments on commit f84dda5

Please sign in to comment.