Skip to content

Commit

Permalink
Allow specifying choices for parameters
Browse files Browse the repository at this point in the history
Closes #16
  • Loading branch information
david-zwicker committed Jan 2, 2024
1 parent 8cca94d commit ac47f80
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
12 changes: 11 additions & 1 deletion modelrunner/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
Any,
Callable,
Dict,
Literal,
Optional,
Sequence,
Type,
TypeVar,
Union,
get_args,
get_origin,
)

from .parameters import (
Expand Down Expand Up @@ -329,14 +332,21 @@ def make_model_class(func: Callable, *, default: bool = False) -> Type[ModelBase
# all remaining parameters are treated as model parameters
if param.annotation is param.empty:
cls = object
choices = None
elif get_origin(param.annotation) is Literal:
cls = object
choices = get_args(param.annotation)
else:
cls = param.annotation
choices = None
if param.default is param.empty:
default_value = NoValue
else:
default_value = param.default

parameter = Parameter(name, default_value=default_value, cls=cls)
parameter = Parameter(
name, default_value=default_value, cls=cls, choices=choices
)
parameters_default.append(parameter)

def __call__(self, *args, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion modelrunner/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def __post_init__(self):
converted_value = self.cls(self.default_value)
except TypeError as err:
raise TypeError(
f"Parameter {self.name} has invalid default: {self.default_value}"
f"Parameter {self.name} of type {self.cls} has invalid default "
f"value: {self.default_value}"
) from err

self._check_value(converted_value)
Expand Down
21 changes: 20 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from pathlib import Path
from typing import List # @UnusedImport
from typing import Literal

import pytest

Expand Down Expand Up @@ -206,6 +206,25 @@ def model_func(a=2):
assert model({"a": 4}).get_result().data == 16


def test_make_model_class_literal_args():
"""test the make_model_class function"""

def model_func(a: Literal["a", 2] = 2):
return a * 2

model = make_model_class(model_func)
assert model.parameters_default[0].choices == ("a", 2)

assert model()() == 4
assert model({"a": "a"})() == "aa"
with pytest.raises(ValueError):
model({"a": 3})

assert model.run_from_command_line(["--a", "a"]).data == "aa"
with pytest.raises(SystemExit):
model.run_from_command_line(["--a", "b"])


def test_argparse_boolean_arguments():
"""test boolean parameters"""

Expand Down

0 comments on commit ac47f80

Please sign in to comment.