Skip to content

Commit

Permalink
Allow specifying choices for parameter values
Browse files Browse the repository at this point in the history
Closes #16
  • Loading branch information
david-zwicker committed Jan 2, 2024
1 parent 32ffae4 commit 8cca94d
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 2 deletions.
20 changes: 18 additions & 2 deletions modelrunner/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union
from typing import Any, Callable, Container, Dict, Iterator, List, Optional, Type, Union

import numpy as np

Expand Down Expand Up @@ -87,8 +87,13 @@ class Parameter:
description (str):
A string describing the impact of this parameter. This
description appears in the parameter help
choices (container):
A list or set of values that the parameter can take. Values (including the
default value) that are not in this list will be rejected. Note that values
are check after they have been converted by `cls`, so specifying `cls` is
particularly important to convert command line parameters from strings.
required (bool):
Whther the parameter is required
Whether the parameter is required
hidden (bool):
Whether the parameter is hidden in the description summary
extra (dict):
Expand All @@ -99,10 +104,16 @@ class Parameter:
default_value: Any = None
cls: Union[Type, Callable] = object
description: str = ""
choices: Optional[Container] = None
required: bool = False
hidden: bool = False
extra: Dict[str, Any] = field(default_factory=dict)

def _check_value(self, value) -> None:
"""checks whether the value is acceptable"""
if value is not None and self.choices is not None and value not in self.choices:
raise ValueError(f"Default value `{value}` not in `{self.choices}`")

def __post_init__(self):
"""check default values and cls"""
if self.cls is not object and not any(
Expand All @@ -116,6 +127,8 @@ def __post_init__(self):
f"Parameter {self.name} has invalid default: {self.default_value}"
) from err

self._check_value(converted_value)

if isinstance(converted_value, np.ndarray):
# numpy arrays are checked for each individual value
valid_default = np.allclose(
Expand Down Expand Up @@ -143,6 +156,7 @@ def __getstate__(self):
"default_value": self.convert(),
"cls": self.cls.__module__ + "." + self.cls.__name__,
"description": self.description,
"choices": self.choices,
"required": self.required,
"hidden": self.hidden,
"extra": self.extra,
Expand Down Expand Up @@ -195,6 +209,7 @@ def convert(self, value=NoValue, *, strict: bool = True):
) from err
# else: just return the value unchanged

self._check_value(value)
return value

def _argparser_add(self, parser):
Expand All @@ -208,6 +223,7 @@ def _argparser_add(self, parser):
arg_name = "--" + self.name
kwargs = {
"required": self.required,
"choices": self.choices,
"default": self.default_value,
"help": description,
}
Expand Down
21 changes: 21 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,27 @@ def __call__(self):
A.run_from_command_line([])


def test_choices_arguments_model_class():
"""test arguments with choices"""

class A(ModelBase):
parameters_default = [
Parameter("a", 1, cls=int, choices={1, 2, 3}),
Parameter("b", 2),
]

def __call__(self):
return self.parameters["a"] + self.parameters["b"]

assert A()() == 3
assert A({"a": 3})() == 5
with pytest.raises(ValueError):
A({"a": 4})
assert A.run_from_command_line(["--a", "3"]).data == 5
with pytest.raises(SystemExit):
A.run_from_command_line(["--a", "4"])


def test_make_model():
"""test the make_model decorator"""

Expand Down
30 changes: 30 additions & 0 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,36 @@ class TestRequired(Parameterized):
TestRequired()


def test_parameter_choices():
"""test parameter with explicit choices"""

class TestChoices(Parameterized):
parameters_default = [Parameter("a", choices={1, 2, 3})]

assert TestChoices().parameters["a"] is None
assert TestChoices({"a": 2}).parameters["a"] == 2
with pytest.raises(ValueError):
TestChoices({"a": 0})
with pytest.raises(ValueError):
TestChoices({"a": 4})

class TestChoicesRequired(Parameterized):
parameters_default = [Parameter("a", required=True, choices={1, 2, 3})]

assert TestChoicesRequired({"a": 2}).parameters["a"] == 2
with pytest.raises(ValueError):
TestChoicesRequired({"a": 0})
with pytest.raises(ValueError):
TestChoicesRequired({"a": 4})
with pytest.raises(ValueError):
TestChoicesRequired()

with pytest.raises(ValueError):

class TestChoicesInconsistent(Parameterized):
parameters_default = [Parameter("a", 4, choices={1, 2, 3})]


def test_hidden_parameter():
"""test how hidden parameters are handled"""

Expand Down

0 comments on commit 8cca94d

Please sign in to comment.