diff --git a/modelrunner/model.py b/modelrunner/model.py index 69e7ca5..21226f9 100644 --- a/modelrunner/model.py +++ b/modelrunner/model.py @@ -18,11 +18,14 @@ Any, Callable, Dict, + Literal, Optional, Sequence, Type, TypeVar, Union, + get_args, + get_origin, ) from .parameters import ( @@ -329,14 +332,21 @@ def make_model_class(func: Callable, *, default: bool = False) -> Type[ModelBase # all remaining parameters are treated as model parameters if param.annotation is param.empty: cls = object + choices = None + elif get_origin(param.annotation) is Literal: + cls = object + choices = get_args(param.annotation) else: cls = param.annotation + choices = None if param.default is param.empty: default_value = NoValue else: default_value = param.default - parameter = Parameter(name, default_value=default_value, cls=cls) + parameter = Parameter( + name, default_value=default_value, cls=cls, choices=choices + ) parameters_default.append(parameter) def __call__(self, *args, **kwargs): diff --git a/modelrunner/parameters.py b/modelrunner/parameters.py index d992c01..7c44781 100644 --- a/modelrunner/parameters.py +++ b/modelrunner/parameters.py @@ -124,7 +124,8 @@ def __post_init__(self): converted_value = self.cls(self.default_value) except TypeError as err: raise TypeError( - f"Parameter {self.name} has invalid default: {self.default_value}" + f"Parameter {self.name} of type {self.cls} has invalid default " + f"value: {self.default_value}" ) from err self._check_value(converted_value) diff --git a/tests/test_model.py b/tests/test_model.py index 363d1f1..e22133d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,7 @@ """ from pathlib import Path -from typing import List # @UnusedImport +from typing import Literal import pytest @@ -206,6 +206,25 @@ def model_func(a=2): assert model({"a": 4}).get_result().data == 16 +def test_make_model_class_literal_args(): + """test the make_model_class function""" + + def model_func(a: Literal["a", 2] = 2): + return a * 2 + + model = make_model_class(model_func) + assert model.parameters_default[0].choices == ("a", 2) + + assert model()() == 4 + assert model({"a": "a"})() == "aa" + with pytest.raises(ValueError): + model({"a": 3}) + + assert model.run_from_command_line(["--a", "a"]).data == "aa" + with pytest.raises(SystemExit): + model.run_from_command_line(["--a", "b"]) + + def test_argparse_boolean_arguments(): """test boolean parameters"""