Skip to content

Commit

Permalink
Merge pull request #53 from zwicker-group/required
Browse files Browse the repository at this point in the history
Parameters can now be marked as required

Closes #6
  • Loading branch information
david-zwicker authored Jan 2, 2024
2 parents e69243e + d0f3212 commit 32ffae4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 4 deletions.
16 changes: 13 additions & 3 deletions modelrunner/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ class Parameter:
name (str):
The name of the parameter
default_value:
The default value
The default value of the parameter
cls:
The type of the parameter, which is used for conversion
description (str):
A string describing the impact of this parameter. This
description appears in the parameter help
required (bool):
Whther the parameter is required
hidden (bool):
Whether the parameter is hidden in the description summary
extra (dict):
Expand All @@ -97,6 +99,7 @@ class Parameter:
default_value: Any = None
cls: Union[Type, Callable] = object
description: str = ""
required: bool = False
hidden: bool = False
extra: Dict[str, Any] = field(default_factory=dict)

Expand Down Expand Up @@ -140,6 +143,7 @@ def __getstate__(self):
"default_value": self.convert(),
"cls": self.cls.__module__ + "." + self.cls.__name__,
"description": self.description,
"required": self.required,
"hidden": self.hidden,
"extra": self.extra,
}
Expand Down Expand Up @@ -202,7 +206,11 @@ def _argparser_add(self, parser):
description = f"Parameter `{self.name}`"

arg_name = "--" + self.name
kwargs = {"default": self.default_value, "help": description}
kwargs = {
"required": self.required,
"default": self.default_value,
"help": description,
}

if self.cls is bool:
# parameter is a boolean that we want to adjust
Expand Down Expand Up @@ -438,7 +446,7 @@ def _parse_parameters(
allow_hidden: bool = True,
include_deprecated: bool = False,
) -> Dict[str, Any]:
"""parse parameters
"""parse parameters from a given dictionary
Args:
parameters (dict):
Expand Down Expand Up @@ -470,6 +478,8 @@ def _parse_parameters(
for name, param_obj in param_objs.items():
if not allow_hidden and param_obj.hidden:
continue # skip hidden parameters
if param_obj.required and name not in parameters:
raise ValueError(f"Require parameter `{name}`")
# take value from parameters or set default value
value = parameters.pop(name, NoValue)
# convert parameter to correct type
Expand Down
17 changes: 16 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def req_args_3(a=None):
assert req_args_3() is None


def test_required_arguments_model_class():
def test_required_arguments_model_class_decorator():
"""test required arguments"""

@make_model_class
Expand All @@ -127,6 +127,21 @@ def required_arg_3(a=None):
assert required_arg_3()() is None


def test_required_arguments_model_class():
"""test required arguments"""

class A(ModelBase):
parameters_default = [Parameter("a", required=True), Parameter("b", 2)]

def __call__(self):
return self.parameters["a"] + self.parameters["b"]

assert A({"a": 3})() == 5
assert A.run_from_command_line(["--a", "3"]).data == 5
with pytest.raises(SystemExit):
A.run_from_command_line([])


def test_make_model():
"""test the make_model decorator"""

Expand Down
11 changes: 11 additions & 0 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ class TestHelp2(TestHelp1):
assert e1 == e2 == ""


def test_parameter_required():
"""test required parameter"""

class TestRequired(Parameterized):
parameters_default = [Parameter("a", required=True)]

assert TestRequired({"a": 2}).parameters["a"] == 2
with pytest.raises(ValueError):
TestRequired()


def test_hidden_parameter():
"""test how hidden parameters are handled"""

Expand Down

0 comments on commit 32ffae4

Please sign in to comment.