From 2f0d575538aa39da1d0e749a8bdabcc1b5d6296b Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Mon, 18 Oct 2021 11:31:21 -0400 Subject: [PATCH] fix nightly build breaking due to new scikit-learn package which breaks model function serialization (#455) --- .../common/model_wrapper.py | 53 +++++++++++++++---- test/test_serialize_explainer.py | 4 +- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/python/interpret_community/common/model_wrapper.py b/python/interpret_community/common/model_wrapper.py index e3ec528b..f6c812c1 100644 --- a/python/interpret_community/common/model_wrapper.py +++ b/python/interpret_community/common/model_wrapper.py @@ -191,13 +191,49 @@ def predict_proba(self, dataset): return self.predict(dataset) -class WrappedClassificationModel(object): - """A class for wrapping a classification model.""" +class BaseWrappedModel(object): + """A base class for WrappedClassificationModel and WrappedRegressionModel.""" - def __init__(self, model, eval_function): + def __init__(self, model, eval_function, examples, model_task): """Initialize the WrappedClassificationModel with the model and evaluation function.""" self._eval_function = eval_function self._model = model + self._examples = examples + self._model_task = model_task + + def __getstate__(self): + """Influence how BaseWrappedModel is pickled. + + Removes _eval_function which may not be serializable. + + :return state: The state to be pickled, with _eval_function removed. + :rtype: dict + """ + odict = self.__dict__.copy() + if self._examples is not None: + del odict['_eval_function'] + return odict + + def __setstate__(self, state): + """Influence how BaseWrappedModel is unpickled. + + Re-adds _eval_function which may not be serializable. + + :param dict: A dictionary of deserialized state. + :type dict: dict + """ + self.__dict__.update(state) + if self._examples is not None: + eval_function, _ = _eval_model(self._model, self._examples, self._model_task) + self._eval_function = eval_function + + +class WrappedClassificationModel(BaseWrappedModel): + """A class for wrapping a classification model.""" + + def __init__(self, model, eval_function, examples=None): + """Initialize the WrappedClassificationModel with the model and evaluation function.""" + super(WrappedClassificationModel, self).__init__(model, eval_function, examples, ModelTask.Classification) def predict(self, dataset): """Predict the output using the wrapped classification model. @@ -235,13 +271,12 @@ def predict_proba(self, dataset): return proba_preds -class WrappedRegressionModel(object): +class WrappedRegressionModel(BaseWrappedModel): """A class for wrapping a regression model.""" - def __init__(self, model, eval_function): + def __init__(self, model, eval_function, examples=None): """Initialize the WrappedRegressionModel with the model and evaluation function.""" - self._eval_function = eval_function - self._model = model + super(WrappedRegressionModel, self).__init__(model, eval_function, examples, ModelTask.Regression) def predict(self, dataset): """Predict the output using the wrapped regression model. @@ -342,9 +377,9 @@ def _wrap_model(model, examples, model_task, is_function): model = WrappedClassificationWithoutProbaModel(model) eval_function, eval_ml_domain = _eval_model(model, examples, model_task) if eval_ml_domain == ModelTask.Classification: - return WrappedClassificationModel(model, eval_function), eval_ml_domain + return WrappedClassificationModel(model, eval_function, examples), eval_ml_domain else: - return WrappedRegressionModel(model, eval_function), eval_ml_domain + return WrappedRegressionModel(model, eval_function, examples), eval_ml_domain def _classifier_without_proba(model): diff --git a/test/test_serialize_explainer.py b/test/test_serialize_explainer.py index e25bbdc2..70e9f6bc 100644 --- a/test/test_serialize_explainer.py +++ b/test/test_serialize_explainer.py @@ -47,7 +47,7 @@ def test_serialize_kernel(self): with open(explainer_name, 'wb') as stream: dump(explainer.explainer.explainer, stream) with open(model_name, 'wb') as stream: - dump(explainer.model.predict_proba, stream) + dump(explainer.model, stream) assert path.exists(model_name) assert path.exists(explainer_name) @@ -69,7 +69,7 @@ def test_serialize_mimic_lightgbm(self): surrogate_name = 'surrogate_model.joblib' tree_explainer_name = 'tree_explainer_model.joblib' with open(model_name, 'wb') as stream: - dump(explainer.model.predict_proba, stream) + dump(explainer.model, stream) with open(surrogate_name, 'wb') as stream: dump(explainer.surrogate_model.model, stream) with open(tree_explainer_name, 'wb') as stream: