Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
model: scikit: Use auto args and config
Browse files Browse the repository at this point in the history
Fixes: #285

Signed-off-by: John Andersen <[email protected]>
  • Loading branch information
pdxjohnny committed Jan 3, 2020
1 parent 2c606bf commit 6415473
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 98 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Function to create a config class dynamically, analogous to `make_dataclass`
### Changed
- CLI tests and integration tests derive from `AsyncExitStackTestCase`
- SciKit models now use the auto args and config methods.
### Fixed
- Correctly identify when functions decorated with `op` use `self` to reference
the `OperationImplementationContext`.
Expand Down
232 changes: 134 additions & 98 deletions model/scikit/dffml_model_scikit/scikit_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import sys
import ast
import inspect
import dataclasses
from collections import namedtuple
from typing import Dict
from typing import Dict, Optional, Tuple, Type, Any

from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
Expand Down Expand Up @@ -43,6 +44,7 @@
Ridge,
)

from dffml.base import make_config, field
from dffml.util.cli.arg import Arg
from dffml.util.entrypoint import entry_point
from dffml_model_scikit.scikit_base import Scikit, ScikitContext
Expand Down Expand Up @@ -70,6 +72,115 @@ class NoDefaultValue:
pass


class ParameterNotInDocString(Exception):
    """
    Raised when a scikit class has a parameter in its ``__init__`` which was not
    present in its docstring. Therefore we have no typing information for it.
    """


def scikit_get_default(type_str):
    """
    Extract the default value mentioned in a scikit docstring type description.

    Everything from the first occurrence of the word ``default`` onwards is
    kept; the word itself, closing parentheses, equals signs and quote
    characters are removed and the remainder is stripped of whitespace.

    Returns ``dataclasses.MISSING`` when the description never mentions
    ``default``, ``None`` when the remaining text is exactly ``None``, and
    the remaining text (always a string) otherwise.
    """
    marker = "default"
    position = type_str.find(marker)
    if position < 0:
        return dataclasses.MISSING
    remainder = type_str[position:].replace(marker, "")
    # One pass that drops ")", "=", and both kinds of quote characters
    cleaned = remainder.translate(str.maketrans("", "", ")=\"'")).strip()
    return None if cleaned == "None" else cleaned


# Type names as they are written in scikit docstrings, mapped to the Python
# type to use for the corresponding config property. scikit_doc_to_field()
# checks each key against the lower-cased words of a parameter's type
# description; it iterates in insertion order and the last matching key wins.
SCIKIT_DOCS_TYPE_MAP = {
    "int": int,
    "integer": int,
    "str": str,
    "string": str,
    "float": float,
    "dict": dict,
    "bool": bool,
}


def scikit_doc_to_field(type_str, param):
    """
    Convert one documented scikit parameter into a ``(type, field)`` pair.

    ``type_str`` is the type description taken from the scikit docstring and
    ``param`` is the matching ``inspect.Parameter`` from the class signature.

    The default comes from the signature when it has one, otherwise from
    ``scikit_get_default(type_str)``. The type is guessed from the
    description:

    - ``{'a', 'b'}`` style choices are treated as ``str``
    - other ``{...}`` choices are ``int``, or ``float`` if the description
      contains a ``.``
    - otherwise the words of the description are looked up in
      ``SCIKIT_DOCS_TYPE_MAP`` (the last matching entry wins)

    If no type could be determined, the type of the default value is used,
    provided there is a real default to inspect.

    Returns a tuple of the type and ``field(type_str, default=default)``.
    """
    default = param.default
    if default is inspect.Parameter.empty:
        default = scikit_get_default(type_str)

    type_cls = Any

    # Set of choices
    if "{'" in type_str and "'}" in type_str:
        type_cls = str
    elif "{" in type_str and "}" in type_str:
        type_cls = float if "." in type_str else int
    else:
        type_split = [word.lower() for word in type_str.replace(",", "").split()]
        for scikit_type_name, python_type in SCIKIT_DOCS_TYPE_MAP.items():
            if scikit_type_name in type_split:
                type_cls = python_type

    # Fall back to the type of the default value. None and the MISSING
    # sentinel carry no usable type information: type(MISSING) would make
    # the property's type the private dataclasses sentinel class.
    if (
        type_cls is Any
        and default is not None
        and default is not dataclasses.MISSING
    ):
        type_cls = type(default)

    return type_cls, field(type_str, default=default)


def mkscikit_config_cls(
    name: str,
    cls: Type,
    properties: Optional[Dict[str, Tuple[Type, field]]] = None,
):
    """
    Given a scikit class, read its docstring and ``__init__`` parameters to
    generate a config class with properties containing the correct types,
    and default values.

    ``name`` is the name of the config class to create. ``properties`` holds
    extra ``name -> (type, field)`` entries to include ahead of the ones
    derived from ``cls``; the dict passed in is not modified.

    Raises ``ParameterNotInDocString`` if a parameter of ``cls`` (other than
    ``args`` / ``kwargs``) has no ``name : type`` line in the docstring.

    Returns the result of ``make_config`` for the collected properties.
    """
    # Work on a copy so the caller's dict is never mutated
    properties = {} if properties is None else dict(properties)

    parameters = inspect.signature(cls).parameters
    # getdoc() returns None for a class without a docstring; treat that as
    # empty so missing parameters are reported via ParameterNotInDocString
    docstring = inspect.getdoc(cls) or ""

    # Parse parameters and their datatypes from docstring
    docparams = {}
    for line in docstring.split("\n"):
        if ":" not in line:
            continue
        param_name, dtypes = line.split(":", maxsplit=1)
        param_name = param_name.strip()
        # Only the first "name : type" line seen for a parameter is used
        if param_name not in parameters or param_name in docparams:
            continue
        docparams[param_name] = dtypes.strip()

    # Ensure all required parameters are present in docstring
    for param_name, param in parameters.items():
        if param_name in ("args", "kwargs"):
            continue
        if param_name not in docparams:
            raise ParameterNotInDocString(
                f"{param_name} for {cls.__qualname__}"
            )
        properties[param_name] = scikit_doc_to_field(
            docparams[param_name], param
        )

    return make_config(
        name, [(key, *value) for key, value in properties.items()]
    )


for entry_point_name, name, cls, applicable_features_function in [
(
"scikitknn",
Expand Down Expand Up @@ -129,15 +240,10 @@ class NoDefaultValue:
ExtraTreesClassifier,
applicable_features,
),
(
"scikitbgc",
"BaggingClassifier",
BaggingClassifier,
applicable_features,
),
("scikiteln", "ElasticNet", ElasticNet, applicable_features,),
("scikitbyr", "BayesianRidge", BayesianRidge, applicable_features,),
("scikitlas", "Lasso", Lasso, applicable_features,),
("scikitbgc", "BaggingClassifier", BaggingClassifier, applicable_features),
("scikiteln", "ElasticNet", ElasticNet, applicable_features),
("scikitbyr", "BayesianRidge", BayesianRidge, applicable_features),
("scikitlas", "Lasso", Lasso, applicable_features),
("scikitard", "ARDRegression", ARDRegression, applicable_features),
("scikitrsc", "RANSACRegressor", RANSACRegressor, applicable_features),
("scikitbnb", "BernoulliNB", BernoulliNB, applicable_features),
Expand Down Expand Up @@ -170,95 +276,26 @@ class NoDefaultValue:
("scikitlars", "Lars", Lars, applicable_features),
]:

parameters = inspect.signature(cls).parameters
defaults = [
os.path.join(
os.path.expanduser("~"),
".cache",
"dffml",
f"scikit-{entry_point_name}",
),
NoDefaultValue,
] + [
param.default
for name, param in parameters.items()
if param.default != inspect._empty
]
dffml_config = namedtuple(
dffml_config = mkscikit_config_cls(
name + "ModelConfig",
["directory", "predict", "features"]
+ [
param.name
for _, param in parameters.items()
if param.default != inspect._empty
],
defaults=defaults,
)

setattr(sys.modules[__name__], dffml_config.__qualname__, dffml_config)

@classmethod
def args(cls, args, *above) -> Dict[str, Arg]:
cls.config_set(
args,
above,
"directory",
Arg(
default=os.path.join(
os.path.expanduser("~"),
".cache",
"dffml",
f"scikit-{entry_point_name}",
cls,
properties={
"directory": (
str,
field(
"Directory where state should be saved",
default=os.path.join(
os.path.expanduser("~"),
".cache",
"dffml",
f"scikit-{entry_point_name}",
),
),
help="Directory where state should be saved",
),
)
cls.config_set(
args,
above,
"predict",
Arg(type=str, help="Label or the value to be predicted"),
)

cls.config_set(
args,
above,
"features",
Arg(
nargs="+",
required=True,
type=Feature.load,
action=list_action(Features),
help="Features to train on",
),
)

for param in inspect.signature(cls.SCIKIT_MODEL).parameters.values():
# TODO if param.default is an array then Args needs to get a
# nargs="+"
cls.config_set(
args,
above,
param.name,
Arg(
type=cls.type_for(param),
default=NoDefaultValue
if param.default == inspect._empty
else param.default,
),
)
return args

@classmethod
def config(cls, config, *above):
params = dict(
directory=cls.config_get(config, above, "directory"),
predict=cls.config_get(config, above, "predict"),
features=cls.config_get(config, above, "features"),
)
for name in inspect.signature(cls.SCIKIT_MODEL).parameters.keys():
params[name] = cls.config_get(config, above, name)
return cls.CONFIG(**params)
"predict": (str, field("Label or the value to be predicted")),
"features": (Features, field("Features to train on")),
},
)

dffml_cls_ctx = type(
name + "ModelContext",
Expand All @@ -273,12 +310,11 @@ def config(cls, config, *above):
"CONFIG": dffml_config,
"CONTEXT": dffml_cls_ctx,
"SCIKIT_MODEL": cls,
"args": args,
"config": config,
},
)
# Add the ENTRY_POINT_ORIG_LABEL
dffml_cls = entry_point(entry_point_name)(dffml_cls)

setattr(sys.modules[__name__], dffml_config.__qualname__, dffml_config)
setattr(sys.modules[__name__], dffml_cls_ctx.__qualname__, dffml_cls_ctx)
setattr(sys.modules[__name__], dffml_cls.__qualname__, dffml_cls)

0 comments on commit 6415473

Please sign in to comment.