Skip to content

Commit

Permalink
Merge pull request #54 from zwicker-group/choices
Browse files Browse the repository at this point in the history
Allow specifying choices for parameter values
  • Loading branch information
david-zwicker authored Jan 2, 2024
2 parents 32ffae4 + ac47f80 commit 131ba47
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 5 deletions.
12 changes: 11 additions & 1 deletion modelrunner/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
Any,
Callable,
Dict,
Literal,
Optional,
Sequence,
Type,
TypeVar,
Union,
get_args,
get_origin,
)

from .parameters import (
Expand Down Expand Up @@ -329,14 +332,21 @@ def make_model_class(func: Callable, *, default: bool = False) -> Type[ModelBase
# all remaining parameters are treated as model parameters
if param.annotation is param.empty:
cls = object
choices = None
elif get_origin(param.annotation) is Literal:
cls = object
choices = get_args(param.annotation)
else:
cls = param.annotation
choices = None
if param.default is param.empty:
default_value = NoValue
else:
default_value = param.default

parameter = Parameter(name, default_value=default_value, cls=cls)
parameter = Parameter(
name, default_value=default_value, cls=cls, choices=choices
)
parameters_default.append(parameter)

def __call__(self, *args, **kwargs):
Expand Down
23 changes: 20 additions & 3 deletions modelrunner/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union
from typing import Any, Callable, Container, Dict, Iterator, List, Optional, Type, Union

import numpy as np

Expand Down Expand Up @@ -87,8 +87,13 @@ class Parameter:
description (str):
A string describing the impact of this parameter. This
description appears in the parameter help
choices (container):
A list or set of values that the parameter can take. Values (including the
default value) that are not in this list will be rejected. Note that values
are check after they have been converted by `cls`, so specifying `cls` is
particularly important to convert command line parameters from strings.
required (bool):
Whther the parameter is required
Whether the parameter is required
hidden (bool):
Whether the parameter is hidden in the description summary
extra (dict):
Expand All @@ -99,10 +104,16 @@ class Parameter:
default_value: Any = None
cls: Union[Type, Callable] = object
description: str = ""
choices: Optional[Container] = None
required: bool = False
hidden: bool = False
extra: Dict[str, Any] = field(default_factory=dict)

def _check_value(self, value) -> None:
"""checks whether the value is acceptable"""
if value is not None and self.choices is not None and value not in self.choices:
raise ValueError(f"Default value `{value}` not in `{self.choices}`")

def __post_init__(self):
"""check default values and cls"""
if self.cls is not object and not any(
Expand All @@ -113,9 +124,12 @@ def __post_init__(self):
converted_value = self.cls(self.default_value)
except TypeError as err:
raise TypeError(
f"Parameter {self.name} has invalid default: {self.default_value}"
f"Parameter {self.name} of type {self.cls} has invalid default "
f"value: {self.default_value}"
) from err

self._check_value(converted_value)

if isinstance(converted_value, np.ndarray):
# numpy arrays are checked for each individual value
valid_default = np.allclose(
Expand Down Expand Up @@ -143,6 +157,7 @@ def __getstate__(self):
"default_value": self.convert(),
"cls": self.cls.__module__ + "." + self.cls.__name__,
"description": self.description,
"choices": self.choices,
"required": self.required,
"hidden": self.hidden,
"extra": self.extra,
Expand Down Expand Up @@ -195,6 +210,7 @@ def convert(self, value=NoValue, *, strict: bool = True):
) from err
# else: just return the value unchanged

self._check_value(value)
return value

def _argparser_add(self, parser):
Expand All @@ -208,6 +224,7 @@ def _argparser_add(self, parser):
arg_name = "--" + self.name
kwargs = {
"required": self.required,
"choices": self.choices,
"default": self.default_value,
"help": description,
}
Expand Down
42 changes: 41 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from pathlib import Path
from typing import List # @UnusedImport
from typing import Literal

import pytest

Expand Down Expand Up @@ -142,6 +142,27 @@ def __call__(self):
A.run_from_command_line([])


def test_choices_arguments_model_class():
"""test arguments with choices"""

class A(ModelBase):
parameters_default = [
Parameter("a", 1, cls=int, choices={1, 2, 3}),
Parameter("b", 2),
]

def __call__(self):
return self.parameters["a"] + self.parameters["b"]

assert A()() == 3
assert A({"a": 3})() == 5
with pytest.raises(ValueError):
A({"a": 4})
assert A.run_from_command_line(["--a", "3"]).data == 5
with pytest.raises(SystemExit):
A.run_from_command_line(["--a", "4"])


def test_make_model():
"""test the make_model decorator"""

Expand Down Expand Up @@ -185,6 +206,25 @@ def model_func(a=2):
assert model({"a": 4}).get_result().data == 16


def test_make_model_class_literal_args():
"""test the make_model_class function"""

def model_func(a: Literal["a", 2] = 2):
return a * 2

model = make_model_class(model_func)
assert model.parameters_default[0].choices == ("a", 2)

assert model()() == 4
assert model({"a": "a"})() == "aa"
with pytest.raises(ValueError):
model({"a": 3})

assert model.run_from_command_line(["--a", "a"]).data == "aa"
with pytest.raises(SystemExit):
model.run_from_command_line(["--a", "b"])


def test_argparse_boolean_arguments():
"""test boolean parameters"""

Expand Down
30 changes: 30 additions & 0 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,36 @@ class TestRequired(Parameterized):
TestRequired()


def test_parameter_choices():
"""test parameter with explicit choices"""

class TestChoices(Parameterized):
parameters_default = [Parameter("a", choices={1, 2, 3})]

assert TestChoices().parameters["a"] is None
assert TestChoices({"a": 2}).parameters["a"] == 2
with pytest.raises(ValueError):
TestChoices({"a": 0})
with pytest.raises(ValueError):
TestChoices({"a": 4})

class TestChoicesRequired(Parameterized):
parameters_default = [Parameter("a", required=True, choices={1, 2, 3})]

assert TestChoicesRequired({"a": 2}).parameters["a"] == 2
with pytest.raises(ValueError):
TestChoicesRequired({"a": 0})
with pytest.raises(ValueError):
TestChoicesRequired({"a": 4})
with pytest.raises(ValueError):
TestChoicesRequired()

with pytest.raises(ValueError):

class TestChoicesInconsistent(Parameterized):
parameters_default = [Parameter("a", 4, choices={1, 2, 3})]


def test_hidden_parameter():
"""test how hidden parameters are handled"""

Expand Down

0 comments on commit 131ba47

Please sign in to comment.