Skip to content

Commit

Permalink
fix nightly build breaking due to new scikit-learn package which brea…
Browse files Browse the repository at this point in the history
…ks model function serialization (#455)
  • Loading branch information
imatiach-msft authored Oct 18, 2021
1 parent c8ce475 commit 2f0d575
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
53 changes: 44 additions & 9 deletions python/interpret_community/common/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,49 @@ def predict_proba(self, dataset):
return self.predict(dataset)


class WrappedClassificationModel(object):
"""A class for wrapping a classification model."""
class BaseWrappedModel(object):
"""A base class for WrappedClassificationModel and WrappedRegressionModel."""

def __init__(self, model, eval_function):
def __init__(self, model, eval_function, examples, model_task):
"""Initialize the WrappedClassificationModel with the model and evaluation function."""
self._eval_function = eval_function
self._model = model
self._examples = examples
self._model_task = model_task

def __getstate__(self):
"""Influence how BaseWrappedModel is pickled.
Removes _eval_function which may not be serializable.
:return state: The state to be pickled, with _eval_function removed.
:rtype: dict
"""
odict = self.__dict__.copy()
if self._examples is not None:
del odict['_eval_function']
return odict

def __setstate__(self, state):
"""Influence how BaseWrappedModel is unpickled.
Re-adds _eval_function which may not be serializable.
:param dict: A dictionary of deserialized state.
:type dict: dict
"""
self.__dict__.update(state)
if self._examples is not None:
eval_function, _ = _eval_model(self._model, self._examples, self._model_task)
self._eval_function = eval_function


class WrappedClassificationModel(BaseWrappedModel):
"""A class for wrapping a classification model."""

def __init__(self, model, eval_function, examples=None):
"""Initialize the WrappedClassificationModel with the model and evaluation function."""
super(WrappedClassificationModel, self).__init__(model, eval_function, examples, ModelTask.Classification)

def predict(self, dataset):
"""Predict the output using the wrapped classification model.
Expand Down Expand Up @@ -235,13 +271,12 @@ def predict_proba(self, dataset):
return proba_preds


class WrappedRegressionModel(object):
class WrappedRegressionModel(BaseWrappedModel):
"""A class for wrapping a regression model."""

def __init__(self, model, eval_function):
def __init__(self, model, eval_function, examples=None):
"""Initialize the WrappedRegressionModel with the model and evaluation function."""
self._eval_function = eval_function
self._model = model
super(WrappedRegressionModel, self).__init__(model, eval_function, examples, ModelTask.Regression)

def predict(self, dataset):
"""Predict the output using the wrapped regression model.
Expand Down Expand Up @@ -342,9 +377,9 @@ def _wrap_model(model, examples, model_task, is_function):
model = WrappedClassificationWithoutProbaModel(model)
eval_function, eval_ml_domain = _eval_model(model, examples, model_task)
if eval_ml_domain == ModelTask.Classification:
return WrappedClassificationModel(model, eval_function), eval_ml_domain
return WrappedClassificationModel(model, eval_function, examples), eval_ml_domain
else:
return WrappedRegressionModel(model, eval_function), eval_ml_domain
return WrappedRegressionModel(model, eval_function, examples), eval_ml_domain


def _classifier_without_proba(model):
Expand Down
4 changes: 2 additions & 2 deletions test/test_serialize_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_serialize_kernel(self):
with open(explainer_name, 'wb') as stream:
dump(explainer.explainer.explainer, stream)
with open(model_name, 'wb') as stream:
dump(explainer.model.predict_proba, stream)
dump(explainer.model, stream)
assert path.exists(model_name)
assert path.exists(explainer_name)

Expand All @@ -69,7 +69,7 @@ def test_serialize_mimic_lightgbm(self):
surrogate_name = 'surrogate_model.joblib'
tree_explainer_name = 'tree_explainer_model.joblib'
with open(model_name, 'wb') as stream:
dump(explainer.model.predict_proba, stream)
dump(explainer.model, stream)
with open(surrogate_name, 'wb') as stream:
dump(explainer.surrogate_model.model, stream)
with open(tree_explainer_name, 'wb') as stream:
Expand Down

0 comments on commit 2f0d575

Please sign in to comment.