From d0f32121f573a69084fdcbc8055d230b2cbca417 Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Tue, 2 Jan 2024 16:25:08 +0100 Subject: [PATCH] Parameters can now be marked as required Closes #6 --- modelrunner/parameters.py | 16 +++++++++++++--- tests/test_model.py | 17 ++++++++++++++++- tests/test_parameters.py | 11 +++++++++++ 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/modelrunner/parameters.py b/modelrunner/parameters.py index 1d94c60..6297733 100644 --- a/modelrunner/parameters.py +++ b/modelrunner/parameters.py @@ -81,12 +81,14 @@ class Parameter: name (str): The name of the parameter default_value: - The default value + The default value of the parameter cls: The type of the parameter, which is used for conversion description (str): A string describing the impact of this parameter. This description appears in the parameter help + required (bool): + Whther the parameter is required hidden (bool): Whether the parameter is hidden in the description summary extra (dict): @@ -97,6 +99,7 @@ class Parameter: default_value: Any = None cls: Union[Type, Callable] = object description: str = "" + required: bool = False hidden: bool = False extra: Dict[str, Any] = field(default_factory=dict) @@ -140,6 +143,7 @@ def __getstate__(self): "default_value": self.convert(), "cls": self.cls.__module__ + "." + self.cls.__name__, "description": self.description, + "required": self.required, "hidden": self.hidden, "extra": self.extra, } @@ -202,7 +206,11 @@ def _argparser_add(self, parser): description = f"Parameter `{self.name}`" arg_name = "--" + self.name - kwargs = {"default": self.default_value, "help": description} + kwargs = { + "required": self.required, + "default": self.default_value, + "help": description, + } if self.cls is bool: # parameter is a boolean that we want to adjust @@ -438,7 +446,7 @@ def _parse_parameters( allow_hidden: bool = True, include_deprecated: bool = False, ) -> Dict[str, Any]: - """parse parameters + """parse parameters from a given dictionary Args: parameters (dict): @@ -470,6 +478,8 @@ def _parse_parameters( for name, param_obj in param_objs.items(): if not allow_hidden and param_obj.hidden: continue # skip hidden parameters + if param_obj.required and name not in parameters: + raise ValueError(f"Require parameter `{name}`") # take value from parameters or set default value value = parameters.pop(name, NoValue) # convert parameter to correct type diff --git a/tests/test_model.py b/tests/test_model.py index 2ff9f65..5732153 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -101,7 +101,7 @@ def req_args_3(a=None): assert req_args_3() is None -def test_required_arguments_model_class(): +def test_required_arguments_model_class_decorator(): """test required arguments""" @make_model_class @@ -127,6 +127,21 @@ def required_arg_3(a=None): assert required_arg_3()() is None +def test_required_arguments_model_class(): + """test required arguments""" + + class A(ModelBase): + parameters_default = [Parameter("a", required=True), Parameter("b", 2)] + + def __call__(self): + return self.parameters["a"] + self.parameters["b"] + + assert A({"a": 3})() == 5 + assert A.run_from_command_line(["--a", "3"]).data == 5 + with pytest.raises(SystemExit): + A.run_from_command_line([]) + + def test_make_model(): """test the make_model decorator""" diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 047d849..0b0e4f2 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -146,6 +146,17 @@ class TestHelp2(TestHelp1): assert e1 == e2 == "" +def test_parameter_required(): + """test required parameter""" + + class TestRequired(Parameterized): + parameters_default = [Parameter("a", required=True)] + + assert TestRequired({"a": 2}).parameters["a"] == 2 + with pytest.raises(ValueError): + TestRequired() + + def test_hidden_parameter(): """test how hidden parameters are handled"""