This repository has been archived by the owner on May 17, 2024. It is now read-only.

Add dataset validation in runner (#125)
* Remove output_feature_multiple

* Add dataset validation

* fix
yujonglee authored Sep 19, 2023
1 parent 5645bd6 commit 136821b
Showing 3 changed files with 42 additions and 30 deletions.
4 changes: 2 additions & 2 deletions docs/getting_started/quickstart.ipynb
@@ -477,8 +477,8 @@
 }
 ],
 "source": [
-"first_result = [row[0] for row in result[\"results\"]]\n",
-"second_result = [row[1] for row in result[\"results\"]]\n",
+"first_result = [row[0] for row in result[\"result\"]]\n",
+"second_result = [row[1] for row in result[\"result\"]]\n",
 "\n",
 "print(first_result.count(None))\n",
 "print(second_result.count(None))"
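For readers following the quickstart change: with `output_feature_multiple` gone, multi-run output now lands in the single `result` column, and the notebook indexes each row by run. A minimal sketch of that access pattern, using a hand-built dataset (the column values here are illustrative, not taken from the quickstart):

```python
# Illustrative only: a dataset shaped like the one the quickstart inspects.
# Assumption: with num=2, each row of the "result" column holds (run_1, run_2).
from datasets import Dataset

result = Dataset.from_dict({
    "sample": [1, 2, 3, 4],
    "result": [[1, 1], [2, 2], [3, 3], [4, None]],  # assumed shape for num=2
})

first_result = [row[0] for row in result["result"]]   # values from the first run
second_result = [row[1] for row in result["result"]]  # values from the second run

print(first_result.count(None))   # -> 0
print(second_result.count(None))  # -> 1 (one evaluation returned None)
```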
36 changes: 20 additions & 16 deletions fastrepl/runner.py
@@ -24,17 +24,26 @@ def __init__(
         evaluator: fastrepl.Evaluator,
         dataset: Dataset,
         output_feature="result",
-        output_feature_multiple="results",
     ) -> None:
         self._evaluator = evaluator
         self._dataset = dataset
 
-        self._output_feature = output_feature
-        self._output_feature_multiple = output_feature_multiple
-
         self._input_features = [
             param for param in inspect.signature(evaluator.run).parameters.keys()
         ]
+        self._output_feature = output_feature
+
+    def _validate(
+        self,
+        evaluator: fastrepl.Evaluator,
+        dataset: Dataset,
+    ) -> None:
+        if any(feature not in dataset.column_names for feature in self._input_features):
+            eval_name = type(evaluator).__name__
+
+            raise ValueError(  # TODO: custom error
+                f"{eval_name} requires {self._input_features}, but the provided dataset has {dataset.column_names}"
+            )
 
     def _run_eval(self, **kwargs) -> Optional[Any]:
         return self._evaluator.run(**kwargs)
@@ -62,24 +71,19 @@ def _run(self, progress: Progress, task_id: TaskID) -> List[Optional[Any]]:
         return results
 
     def run(self, num=1) -> Dataset:
+        self._validate(self._evaluator, self._dataset)
+
         with Progress() as progress:
             msg = "[cyan]Processing..."
             task_id = progress.add_task(msg, total=len(self._dataset) * num)
 
-            if num == 1:
-                result = self._run(progress, task_id)
-
-                return self._dataset.add_column(
-                    self._output_feature,
-                    result,
-                )
-            else:
+            if num > 1:
                 results = [self._run(progress, task_id) for _ in range(num)]
+                column = list(zip(*results))
+                return self._dataset.add_column(self._output_feature, column)
 
-                return self._dataset.add_column(
-                    self._output_feature_multiple,
-                    list(zip(*results)),
-                )
+            column = self._run(progress, task_id)
+            return self._dataset.add_column(self._output_feature, column)
 
 
 class LocalRunnerREPL(LocalRunner):
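A short sketch of how the new `_validate` check surfaces to callers. It assumes (based on the tests below) that `SimpleEvaluator.run()` expects a `sample` argument, so a dataset that only has an `input` column should be rejected:

```python
# Sketch under the assumptions above: the runner derives its required input
# features from the parameter names of evaluator.run(), then checks that every
# one of them exists as a dataset column before running anything.
import fastrepl
from datasets import Dataset

evaluator = fastrepl.SimpleEvaluator(
    node=fastrepl.LLMClassificationHead(context="", labels={})
)

bad_ds = Dataset.from_dict({"input": [1, 2, 3]})  # assumed: no "sample" column

try:
    fastrepl.LocalRunner(evaluator=evaluator, dataset=bad_ds).run()
except ValueError as err:
    # Expected to look roughly like the f-string added in this commit, e.g.
    # "SimpleEvaluator requires ['sample'], but the provided dataset has ['input']"
    print(err)
```

The exact required-feature list depends on the evaluator's `run()` signature; the message above is only an approximation of the error added here.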
32 changes: 20 additions & 12 deletions tests/unit/test_runner.py
@@ -1,5 +1,4 @@
 import pytest
-import warnings
 from datasets import Dataset
 
 import fastrepl
@@ -19,40 +18,49 @@ def mock_run(*args, **kwargs):


 class TestLocalRunner:
-    def test_runner_num_1(self, mock_runs):
+    def test_num_1(self, mock_runs):
         mock_runs([[1]])
 
-        ds = Dataset.from_dict({"input": [1]})
+        ds = Dataset.from_dict({"sample": [1]})
         eval = fastrepl.SimpleEvaluator(
             node=fastrepl.LLMClassificationHead(context="", labels={})
         )
 
         result = fastrepl.LocalRunner(evaluator=eval, dataset=ds).run(num=1)
 
-        assert result.column_names == ["input", "result"]
+        assert result.column_names == ["sample", "result"]
 
-    def test_runner_num_2(self, mock_runs):
+    def test_num_2(self, mock_runs):
         mock_runs([[1, 2, 3, 4], [1, 2, 3, 5]])
 
-        ds = Dataset.from_dict({"input": [1, 2, 3, 4]})
+        ds = Dataset.from_dict({"sample": [1, 2, 3, 4]})
         eval = fastrepl.SimpleEvaluator(
             node=fastrepl.LLMClassificationHead(context="", labels={})
        )
 
         result = fastrepl.LocalRunner(evaluator=eval, dataset=ds).run(num=2)
 
-        assert result.column_names == ["input", "results"]
-        assert result["results"] == [[1, 1], [2, 2], [3, 3], [4, 5]]
+        assert result.column_names == ["sample", "result"]
+        assert result["result"] == [[1, 1], [2, 2], [3, 3], [4, 5]]
 
-    def test_runner_num_2_handle_none(self, mock_runs):
+    def test_num_2_handle_none(self, mock_runs):
         mock_runs([[1, 2, 3, 4], [1, 2, 3, None]])
 
-        ds = Dataset.from_dict({"input": [1, 2, 3, 4]})
+        ds = Dataset.from_dict({"sample": [1, 2, 3, 4]})
         eval = fastrepl.SimpleEvaluator(
             node=fastrepl.LLMClassificationHead(context="", labels={})
         )
 
         result = fastrepl.LocalRunner(evaluator=eval, dataset=ds).run(num=2)
 
-        assert result.column_names == ["input", "results"]
-        assert result["results"] == [[1, 1], [2, 2], [3, 3], [4, None]]
+        assert result.column_names == ["sample", "result"]
+        assert result["result"] == [[1, 1], [2, 2], [3, 3], [4, None]]
 
+    def test_validation(self):
+        ds = Dataset.from_dict({"input": [1, 2, 3, 4]})
+        eval = fastrepl.SimpleEvaluator(
+            node=fastrepl.LLMClassificationHead(context="", labels={})
+        )
+
+        with pytest.raises(ValueError):
+            fastrepl.LocalRunner(evaluator=eval, dataset=ds).run()

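The `mock_runs` fixture these tests rely on is collapsed in this diff (only `def mock_run(*args, **kwargs):` is visible in the hunk header). Purely as a hypothetical reconstruction of its shape — not the actual fixture — it could patch `LocalRunner._run` to hand back pre-canned per-run outputs:

```python
# Hypothetical sketch only: the real fixture lives in the collapsed part of
# tests/unit/test_runner.py and may be implemented differently.
import pytest
import fastrepl


@pytest.fixture
def mock_runs(monkeypatch):
    def _install(runs):
        run_iter = iter(runs)

        def fake_run(self, progress, task_id):
            # Each call returns the next pre-canned list of per-row outputs.
            return list(next(run_iter))

        monkeypatch.setattr(fastrepl.LocalRunner, "_run", fake_run)

    return _install
```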