From 8cca94dfa3b37fceead19e9b96291ac1555723a5 Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Tue, 2 Jan 2024 17:43:01 +0100 Subject: [PATCH] Allow specifying choices for parameter values Closes #16 --- modelrunner/parameters.py | 20 ++++++++++++++++++-- tests/test_model.py | 21 +++++++++++++++++++++ tests/test_parameters.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/modelrunner/parameters.py b/modelrunner/parameters.py index 6297733..d992c01 100644 --- a/modelrunner/parameters.py +++ b/modelrunner/parameters.py @@ -21,7 +21,7 @@ import logging import warnings from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union +from typing import Any, Callable, Container, Dict, Iterator, List, Optional, Type, Union import numpy as np @@ -87,8 +87,13 @@ class Parameter: description (str): A string describing the impact of this parameter. This description appears in the parameter help + choices (container): + A list or set of values that the parameter can take. Values (including the + default value) that are not in this list will be rejected. Note that values + are check after they have been converted by `cls`, so specifying `cls` is + particularly important to convert command line parameters from strings. required (bool): - Whther the parameter is required + Whether the parameter is required hidden (bool): Whether the parameter is hidden in the description summary extra (dict): @@ -99,10 +104,16 @@ class Parameter: default_value: Any = None cls: Union[Type, Callable] = object description: str = "" + choices: Optional[Container] = None required: bool = False hidden: bool = False extra: Dict[str, Any] = field(default_factory=dict) + def _check_value(self, value) -> None: + """checks whether the value is acceptable""" + if value is not None and self.choices is not None and value not in self.choices: + raise ValueError(f"Default value `{value}` not in `{self.choices}`") + def __post_init__(self): """check default values and cls""" if self.cls is not object and not any( @@ -116,6 +127,8 @@ def __post_init__(self): f"Parameter {self.name} has invalid default: {self.default_value}" ) from err + self._check_value(converted_value) + if isinstance(converted_value, np.ndarray): # numpy arrays are checked for each individual value valid_default = np.allclose( @@ -143,6 +156,7 @@ def __getstate__(self): "default_value": self.convert(), "cls": self.cls.__module__ + "." + self.cls.__name__, "description": self.description, + "choices": self.choices, "required": self.required, "hidden": self.hidden, "extra": self.extra, @@ -195,6 +209,7 @@ def convert(self, value=NoValue, *, strict: bool = True): ) from err # else: just return the value unchanged + self._check_value(value) return value def _argparser_add(self, parser): @@ -208,6 +223,7 @@ def _argparser_add(self, parser): arg_name = "--" + self.name kwargs = { "required": self.required, + "choices": self.choices, "default": self.default_value, "help": description, } diff --git a/tests/test_model.py b/tests/test_model.py index 5732153..363d1f1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -142,6 +142,27 @@ def __call__(self): A.run_from_command_line([]) +def test_choices_arguments_model_class(): + """test arguments with choices""" + + class A(ModelBase): + parameters_default = [ + Parameter("a", 1, cls=int, choices={1, 2, 3}), + Parameter("b", 2), + ] + + def __call__(self): + return self.parameters["a"] + self.parameters["b"] + + assert A()() == 3 + assert A({"a": 3})() == 5 + with pytest.raises(ValueError): + A({"a": 4}) + assert A.run_from_command_line(["--a", "3"]).data == 5 + with pytest.raises(SystemExit): + A.run_from_command_line(["--a", "4"]) + + def test_make_model(): """test the make_model decorator""" diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 0b0e4f2..6a514f9 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -157,6 +157,36 @@ class TestRequired(Parameterized): TestRequired() +def test_parameter_choices(): + """test parameter with explicit choices""" + + class TestChoices(Parameterized): + parameters_default = [Parameter("a", choices={1, 2, 3})] + + assert TestChoices().parameters["a"] is None + assert TestChoices({"a": 2}).parameters["a"] == 2 + with pytest.raises(ValueError): + TestChoices({"a": 0}) + with pytest.raises(ValueError): + TestChoices({"a": 4}) + + class TestChoicesRequired(Parameterized): + parameters_default = [Parameter("a", required=True, choices={1, 2, 3})] + + assert TestChoicesRequired({"a": 2}).parameters["a"] == 2 + with pytest.raises(ValueError): + TestChoicesRequired({"a": 0}) + with pytest.raises(ValueError): + TestChoicesRequired({"a": 4}) + with pytest.raises(ValueError): + TestChoicesRequired() + + with pytest.raises(ValueError): + + class TestChoicesInconsistent(Parameterized): + parameters_default = [Parameter("a", 4, choices={1, 2, 3})] + + def test_hidden_parameter(): """test how hidden parameters are handled"""