diff --git a/docs/getting_started/quickstart.ipynb b/docs/getting_started/quickstart.ipynb
index d9a9de3..daf9cfe 100644
--- a/docs/getting_started/quickstart.ipynb
+++ b/docs/getting_started/quickstart.ipynb
@@ -477,8 +477,8 @@
     }
    ],
    "source": [
-    "first_result = [row[0] for row in result[\"results\"]]\n",
-    "second_result = [row[1] for row in result[\"results\"]]\n",
+    "first_result = [row[0] for row in result[\"result\"]]\n",
+    "second_result = [row[1] for row in result[\"result\"]]\n",
     "\n",
     "print(first_result.count(None))\n",
     "print(second_result.count(None))"
diff --git a/fastrepl/runner.py b/fastrepl/runner.py
index b35d713..e4965de 100644
--- a/fastrepl/runner.py
+++ b/fastrepl/runner.py
@@ -24,17 +24,26 @@ def __init__(
         self,
         evaluator: fastrepl.Evaluator,
         dataset: Dataset,
         output_feature="result",
-        output_feature_multiple="results",
     ) -> None:
         self._evaluator = evaluator
         self._dataset = dataset
-        self._output_feature = output_feature
-        self._output_feature_multiple = output_feature_multiple
-
         self._input_features = [
             param for param in inspect.signature(evaluator.run).parameters.keys()
         ]
+        self._output_feature = output_feature
+
+    def _validate(
+        self,
+        evaluator: fastrepl.Evaluator,
+        dataset: Dataset,
+    ) -> None:
+        if any(feature not in dataset.column_names for feature in self._input_features):
+            eval_name = type(evaluator).__name__
+
+            raise ValueError(  # TODO: custom error
+                f"{eval_name} requires {self._input_features}, but the provided dataset has {dataset.column_names}"
+            )
 
     def _run_eval(self, **kwargs) -> Optional[Any]:
         return self._evaluator.run(**kwargs)
@@ -62,24 +71,19 @@ def _run(self, progress: Progress, task_id: TaskID) -> List[Optional[Any]]:
         return results
 
     def run(self, num=1) -> Dataset:
+        self._validate(self._evaluator, self._dataset)
+
         with Progress() as progress:
             msg = "[cyan]Processing..."
             task_id = progress.add_task(msg, total=len(self._dataset) * num)
 
-            if num == 1:
-                result = self._run(progress, task_id)
-
-                return self._dataset.add_column(
-                    self._output_feature,
-                    result,
-                )
-            else:
+            if num > 1:
                 results = [self._run(progress, task_id) for _ in range(num)]
+                column = list(zip(*results))
+                return self._dataset.add_column(self._output_feature, column)
 
-                return self._dataset.add_column(
-                    self._output_feature_multiple,
-                    list(zip(*results)),
-                )
+            column = self._run(progress, task_id)
+            return self._dataset.add_column(self._output_feature, column)
 
 
 class LocalRunnerREPL(LocalRunner):
diff --git a/tests/unit/test_runner.py b/tests/unit/test_runner.py
index 33faf80..81c0ea8 100644
--- a/tests/unit/test_runner.py
+++ b/tests/unit/test_runner.py
@@ -1,5 +1,4 @@
 import pytest
-import warnings
 
 from datasets import Dataset
 import fastrepl
@@ -19,40 +18,49 @@ def mock_run(*args, **kwargs):
 
 
 class TestLocalRunner:
-    def test_runner_num_1(self, mock_runs):
+    def test_num_1(self, mock_runs):
         mock_runs([[1]])
 
-        ds = Dataset.from_dict({"input": [1]})
+        ds = Dataset.from_dict({"sample": [1]})
         eval = fastrepl.SimpleEvaluator(
             node=fastrepl.LLMClassificationHead(context="", labels={})
         )
 
         result = fastrepl.LocalRunner(evaluator=eval, dataset=ds).run(num=1)
-        assert result.column_names == ["input", "result"]
+        assert result.column_names == ["sample", "result"]
 
-    def test_runner_num_2(self, mock_runs):
+    def test_num_2(self, mock_runs):
         mock_runs([[1, 2, 3, 4], [1, 2, 3, 5]])
 
-        ds = Dataset.from_dict({"input": [1, 2, 3, 4]})
+        ds = Dataset.from_dict({"sample": [1, 2, 3, 4]})
         eval = fastrepl.SimpleEvaluator(
             node=fastrepl.LLMClassificationHead(context="", labels={})
         )
 
         result = fastrepl.LocalRunner(evaluator=eval, dataset=ds).run(num=2)
-        assert result.column_names == ["input", "results"]
-        assert result["results"] == [[1, 1], [2, 2], [3, 3], [4, 5]]
+        assert result.column_names == ["sample", "result"]
+        assert result["result"] == [[1, 1], [2, 2], [3, 3], [4, 5]]
 
-    def test_runner_num_2_handle_none(self, mock_runs):
+    def test_num_2_handle_none(self, mock_runs):
         mock_runs([[1, 2, 3, 4], [1, 2, 3, None]])
 
-        ds = Dataset.from_dict({"input": [1, 2, 3, 4]})
+        ds = Dataset.from_dict({"sample": [1, 2, 3, 4]})
         eval = fastrepl.SimpleEvaluator(
             node=fastrepl.LLMClassificationHead(context="", labels={})
         )
 
         result = fastrepl.LocalRunner(evaluator=eval, dataset=ds).run(num=2)
-        assert result.column_names == ["input", "results"]
-        assert result["results"] == [[1, 1], [2, 2], [3, 3], [4, None]]
+        assert result.column_names == ["sample", "result"]
+        assert result["result"] == [[1, 1], [2, 2], [3, 3], [4, None]]
+
+    def test_validation(self):
+        ds = Dataset.from_dict({"input": [1, 2, 3, 4]})
+        eval = fastrepl.SimpleEvaluator(
+            node=fastrepl.LLMClassificationHead(context="", labels={})
+        )
+
+        with pytest.raises(ValueError):
+            fastrepl.LocalRunner(evaluator=eval, dataset=ds).run()
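
Note: a minimal usage sketch of the runner after this change, mirroring the unit tests above. The dataset contents are placeholders, and the "sample" column name is assumed to match SimpleEvaluator's input feature as the tests suggest; with num > 1 the predictions land in the single "result" column as one tuple per row, rather than in a separate "results" column.

    from datasets import Dataset
    import fastrepl

    # Placeholder dataset; the column name must satisfy _validate, i.e. it must
    # match the evaluator's input feature ("sample" here, as in the tests).
    ds = Dataset.from_dict({"sample": ["first text", "second text"]})

    evaluator = fastrepl.SimpleEvaluator(
        node=fastrepl.LLMClassificationHead(context="", labels={})
    )

    result = fastrepl.LocalRunner(evaluator=evaluator, dataset=ds).run(num=2)
    print(result.column_names)  # ["sample", "result"]
    print(result["result"])     # one tuple of 2 predictions per input row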