From 8fe2e5c11169d3cbb25e9e996a63408acb417028 Mon Sep 17 00:00:00 2001
From: Ilya Matiach
Date: Mon, 13 Apr 2020 10:34:03 -0400
Subject: [PATCH] fixed timeseries error where index column not removed for surrogate model when using reset_teacher on evaluation (#205)

---
 .../mimic/mimic_explainer.py |  3 +++
 test/test_mimic_explainer.py | 19 ++++++++++++-------
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/python/interpret_community/mimic/mimic_explainer.py b/python/interpret_community/mimic/mimic_explainer.py
index fed63a2a..92e5171d 100644
--- a/python/interpret_community/mimic/mimic_explainer.py
+++ b/python/interpret_community/mimic/mimic_explainer.py
@@ -445,6 +445,9 @@ def _get_explain_local_kwargs(self, evaluation_examples):
         if self._shap_values_output == ShapValuesOutput.TEACHER_PROBABILITY:
             # Outputting shap values in terms of the probabilities of the teacher model
             probabilities = self.function(original_evaluation_examples)
+        # if index column should not be set on surrogate model, remove it
+        if self.reset_index == ResetIndex.ResetTeacher:
+            evaluation_examples.set_index()
         if self._timestamp_featurizer:
             evaluation_examples.apply_timestamp_featurizer(self._timestamp_featurizer)
         if self._column_indexer:
diff --git a/test/test_mimic_explainer.py b/test/test_mimic_explainer.py
index 999a5980..4fd49c07 100644
--- a/test/test_mimic_explainer.py
+++ b/test/test_mimic_explainer.py
@@ -286,33 +286,38 @@ def _timeseries_generated_data(self):
         return X_train, X_test, y_train, y_test, time_column_name
 
     def test_datetime_features(self, mimic_explainer):
-        X_train, _, _, _, _ = self._timeseries_generated_data()
+        X_train, x_test, _, _, _ = self._timeseries_generated_data()
         kwargs = {'reset_index': 'reset'}
         model = DataFrameTestModel(X_train.copy())
         features = list(X_train.columns.values) + list(X_train.index.names)
         mimic_explainer(model, X_train, LGBMExplainableModel, features=features, **kwargs)
+        # Note: need to fix column names after featurization as more columns are added to surrogate model
 
     def test_datetime_features_ignore(self, mimic_explainer):
         # Validate we throw when reset_index is set to ignore
-        X_train, _, _, _, _ = self._timeseries_generated_data()
+        X_train, x_test, _, _, _ = self._timeseries_generated_data()
         kwargs = {'reset_index': 'ignore'}
         model = DataFrameTestModel(X_train.copy())
-        features = list(X_train.columns.values) + list(X_train.index.names)
+        features = list(X_train.columns.values)
         # Validate we hit the assertion error on the DataFrameTestModel for checking the presence of index column
         with pytest.raises(AssertionError):
             mimic_explainer(model, X_train, LGBMExplainableModel, features=features, **kwargs)
         # Validate we don't hit error if we disable the index column asserts
         model = DataFrameTestModel(X_train.copy(), assert_index_present=False)
-        mimic_explainer(model, X_train, LGBMExplainableModel, features=features, **kwargs)
+        explainer = mimic_explainer(model, X_train, LGBMExplainableModel, features=features, **kwargs)
+        explanation = explainer.explain_global(x_test)
+        assert explanation.method == LIGHTGBM_METHOD
 
     def test_datetime_features_already_featurized(self, mimic_explainer):
         # Validate we still passthrough underlying index to teacher model
         # even if we don't use it for surrogate model
-        X_train, _, _, _, _ = self._timeseries_generated_data()
+        X_train, x_test, _, _, _ = self._timeseries_generated_data()
         kwargs = {'reset_index': 'reset_teacher'}
         model = DataFrameTestModel(X_train.copy())
-        features = list(X_train.columns.values) + list(X_train.index.names)
-        mimic_explainer(model, X_train, LGBMExplainableModel, features=features, **kwargs)
+        features = list(X_train.columns.values)
+        explainer = mimic_explainer(model, X_train, LGBMExplainableModel, features=features, **kwargs)
+        explanation = explainer.explain_global(x_test)
+        assert explanation.method == LIGHTGBM_METHOD
 
     def test_explain_model_imbalanced_classes(self, mimic_explainer):
         model = retrieve_model('unbalanced_model.pkl')