Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Support fleiss kappa (#126)
Browse files Browse the repository at this point in the history
* Refactor runner

* Add fleiss kappa

* Refactor None handling
  • Loading branch information
yujonglee authored Sep 20, 2023
1 parent 136821b commit d6796d1
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 37 deletions.
18 changes: 4 additions & 14 deletions fastrepl/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,28 @@ def __init__(
dataset: Dataset,
output_feature="result",
) -> None:
self._evaluator = evaluator
self._dataset = dataset

self._output_feature = output_feature
self._input_features = [
param for param in inspect.signature(evaluator.run).parameters.keys()
]
self._output_feature = output_feature

def _validate(
    self,
    evaluator: fastrepl.Evaluator,
    dataset: Dataset,
) -> None:
    """Ensure the dataset exposes every input feature the evaluator needs.

    Each name in ``self._input_features`` must appear in
    ``dataset.column_names``.

    Raises:
        ValueError: If at least one required feature is not a dataset column.
    """
    # Guard clause: nothing to report when every feature is present.
    if all(feature in dataset.column_names for feature in self._input_features):
        return

    eval_name = type(evaluator).__name__
    raise ValueError(  # TODO: custom error
        f"{eval_name} requires {self._input_features}, but the provided dataset has {dataset.column_names}"
    )

def _run_eval(self, **kwargs) -> Optional[Any]:
return self._evaluator.run(**kwargs)
self._evaluator = evaluator
self._dataset = dataset

def _run(self, progress: Progress, task_id: TaskID) -> List[Optional[Any]]:
results = []

with ThreadPool(NUM_THREADS) as pool:
futures = [
pool.apply_async(
self._run_eval,
self._evaluator.run,
kwds={
feature: value
for feature, value in zip(self._input_features, values)
Expand All @@ -71,8 +63,6 @@ def _run(self, progress: Progress, task_id: TaskID) -> List[Optional[Any]]:
return results

def run(self, num=1) -> Dataset:
self._validate(self._evaluator, self._dataset)

with Progress() as progress:
msg = "[cyan]Processing..."
task_id = progress.add_task(msg, total=len(self._dataset) * num)
Expand Down
45 changes: 26 additions & 19 deletions fastrepl/utils/kappa.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,41 @@
from typing import List, Any
from typing import List, Any, cast

from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import LabelEncoder
from statsmodels.stats.inter_rater import cohens_kappa
from statsmodels.stats.inter_rater import cohens_kappa, fleiss_kappa, aggregate_raters


def kappa(predictions: List[List[Any]]) -> float:
    """Compute inter-rater agreement for two or more raters.

    Cohen's kappa is used when exactly two raters are given; Fleiss' kappa is
    used for three or more.

    Args:
        predictions: One sequence of labels per rater. All raters must label
            the same items in the same order. ``None`` entries are treated as
            a category of their own (``""`` for string labels, ``-1``
            otherwise), so a genuine ``""`` / ``-1`` label is
            indistinguishable from a missing one.

    Returns:
        The kappa statistic.

    Raises:
        ValueError: If fewer than two raters are given, if any rater has no
            predictions, or if the raters labelled a different number of items.
    """
    if len(predictions) < 2:
        raise ValueError("kappa requires predictions from at least two raters")
    if any(len(ps) == 0 for ps in predictions):
        raise ValueError("every rater must provide at least one prediction")
    if len({len(ps) for ps in predictions}) != 1:
        # Without this check the Fleiss path would silently truncate to the
        # shortest rater when transposing with zip().
        raise ValueError("every rater must provide the same number of predictions")

    # Look at every value rather than only predictions[0][0]: the first entry
    # may itself be None even though the labels are strings.
    has_str_labels = any(isinstance(p, str) for ps in predictions for p in ps)

    encoded: List[Any]
    if has_str_labels:
        # Build new lists; rebinding the loop variable inside a ``for`` would
        # leave ``predictions`` untouched and let None reach the encoder.
        filled = [["" if p is None else p for p in ps] for ps in predictions]

        encoder = LabelEncoder()
        encoder.fit(list({p for ps in filled for p in ps}))
        encoded = [encoder.transform(ps) for ps in filled]
    else:
        encoded = [[-1 if p is None else p for p in ps] for ps in predictions]

    if len(encoded) == 2:
        return _cohens_kappa(encoded[0], encoded[1])
    return _fleiss_kappa(encoded)


def _cohens_kappa(a: List[Any], b: List[Any]) -> float:
    """Return Cohen's kappa for the label sequences ``a`` and ``b`` of two raters."""
    # Cross-tabulate the two raters' labels; the resulting square table is
    # what statsmodels' cohens_kappa takes as input.
    contingency = confusion_matrix(a, b)
    return cohens_kappa(table=contingency, return_results=False)


return cohens_kappa(table=confusion_matrix(a, b), return_results=False)
def _fleiss_kappa(predictions: List[List[Any]]) -> float:
    """Return Fleiss' kappa for the label sequences of several raters.

    Args:
        predictions: One sequence of labels per rater (raters x items).

    Returns:
        The Fleiss' kappa statistic computed by statsmodels.
    """
    # Transpose to items x raters, the row-per-item layout passed to
    # aggregate_raters. Named ``ratings_by_item`` rather than ``input`` to
    # avoid shadowing the builtin.
    ratings_by_item = list(zip(*predictions))
    table, _ = aggregate_raters(ratings_by_item)
    return fleiss_kappa(table)
55 changes: 52 additions & 3 deletions tests/unit/test_kappa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest

from sklearn.metrics import confusion_matrix
from statsmodels.stats.inter_rater import cohens_kappa, fleiss_kappa, aggregate_raters

from fastrepl.utils import kappa
Expand All @@ -21,14 +20,37 @@ def test_basic2(self):


class TestFleissKappa:
    """Sanity checks for statsmodels' ``aggregate_raters`` and ``fleiss_kappa``."""

    def test_aggregate_raters_1(self):
        # Two rows of three ratings each; every row is tallied per category.
        ratings = [[0, 1, 2], [1, 0, 1]]

        table, categories = aggregate_raters(ratings)

        assert (table == [[1, 1, 1], [1, 2, 0]]).all()
        assert (categories == [0, 1, 2]).all()

    def test_aggregate_raters_2(self):
        # Same as above with two additional rows.
        ratings = [[0, 1, 2], [1, 0, 1], [2, 2, 0], [1, 0, 2]]

        table, categories = aggregate_raters(ratings)

        assert (table == [[1, 1, 1], [1, 2, 0], [1, 0, 2], [1, 1, 1]]).all()
        assert (categories == [0, 1, 2]).all()

    def test_basic(self):
        # Note that this is result of 3 raters: one row per item, one column
        # per rater.
        ratings = [
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [3, 3, 2],
            [0, 0, 0],
            [0, 0, 0],
            [1, 1, 1],
        ]

        table, _ = aggregate_raters(ratings)
        result = fleiss_kappa(table)

        assert result == pytest.approx(0.84516, abs=1e-5)


@pytest.mark.parametrize(
"predictions, result",
Expand All @@ -55,7 +77,34 @@ def test_aggregate_raters(self):
],
0.499,
),
]
+ [
([[1, None], [1, 2], [None, 2]], 0),
([[1, 2, 3], [1, 2, 3], [1, 2, 3]], 1.0),
([[1, 2, 3], [1, 1, 3], [1, 3, 3]], 0.437),
(
[
[1, 1, 1, 1, 3, 0, 0, 1],
[1, 1, 1, 1, 3, 0, 0, 1],
[1, 1, 1, 1, 2, 0, 0, 1],
],
0.845,
),
(
[
["POSITIVE", "NEGATIVE", "POSITIVE"],
["NEGATIVE", "POSITIVE", "NEGATIVE"],
],
-0.799,
),
(
[
["POSITIVE", "NEGATIVE", "POSITIVE"],
["POSITIVE", "NEGATIVE", "POSITIVE"],
],
1.0,
),
],
)
def test_kappa(predictions, result):
assert kappa(*predictions) == pytest.approx(result, abs=1e-3)
assert kappa(predictions) == pytest.approx(result, abs=1e-3)
2 changes: 1 addition & 1 deletion tests/unit/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ def test_validation(self):
)

with pytest.raises(ValueError):
fastrepl.LocalRunner(evaluator=eval, dataset=ds).run()
fastrepl.LocalRunner(evaluator=eval, dataset=ds)

0 comments on commit d6796d1

Please sign in to comment.