diff --git a/tests/test_serialize_explanation.py b/tests/test_serialize_explanation.py index 4da4187b..2efa9d61 100644 --- a/tests/test_serialize_explanation.py +++ b/tests/test_serialize_explanation.py @@ -90,8 +90,8 @@ def test_old_load_explanation_backcompat(self, iris, tabular_explainer, iris_svm explanation = explainer.explain_global(iris[DatasetConstants.X_TEST], include_local=False) loaded_explanation = load_explanation(os.path.join('.', 'tests', 'backcompat_explanation')) explanation._id = loaded_explanation._id - _assert_explanation_equivalence(explanation, loaded_explanation, rtol=1e-7) - _assert_numpy_explanation_types(explanation, loaded_explanation, rtol=1e-7) + _assert_explanation_equivalence(explanation, loaded_explanation, rtol=0.03, atol=0.002) + _assert_numpy_explanation_types(explanation, loaded_explanation, rtol=0.03, atol=0.002) def _generate_old_explanation(iris, tabular_explainer, iris_svm_model): @@ -104,7 +104,7 @@ def _generate_old_explanation(iris, tabular_explainer, iris_svm_model): save_explanation(explanation, path, exist_ok=False) -def _assert_explanation_equivalence(actual, expected, rtol=None): +def _assert_explanation_equivalence(actual, expected, rtol=None, atol=None): # get the non-null properties in the expected explanation paramkeys = filter(lambda x, expected=expected: hasattr(expected, getattr(ExplainParams, x)), list(ExplainParams.get_serializable())) @@ -122,36 +122,36 @@ def _assert_explanation_equivalence(actual, expected, rtol=None): else: expected_dataset = expected_value.original_dataset if issparse(actual_dataset) and issparse(expected_dataset): - _assert_sparse_data_equivalence(actual_dataset, expected_dataset, rtol=rtol) + _assert_sparse_data_equivalence(actual_dataset, expected_dataset, rtol=rtol, atol=atol) else: - _assert_allclose_or_eq(actual_dataset, expected_dataset, rtol=rtol) + _assert_allclose_or_eq(actual_dataset, expected_dataset, rtol=rtol, atol=atol) elif isinstance(actual_value, (np.ndarray, collections.abc.Sequence)): - _assert_allclose_or_eq(actual_value, expected_value, rtol=rtol) + _assert_allclose_or_eq(actual_value, expected_value, rtol=rtol, atol=atol) elif isinstance(actual_value, pd.DataFrame) and isinstance(expected_value, pd.DataFrame): - _assert_allclose_or_eq(actual_value.values, expected_value.values, rtol=rtol) + _assert_allclose_or_eq(actual_value.values, expected_value.values, rtol=rtol, atol=atol) elif issparse(actual_value) and issparse(expected_value): - _assert_sparse_data_equivalence(actual_value, expected_value, rtol=rtol) + _assert_sparse_data_equivalence(actual_value, expected_value, rtol=rtol, atol=atol) else: assert actual_value == expected_value -def _assert_allclose_or_eq(actual, expected, rtol=None): +def _assert_allclose_or_eq(actual, expected, rtol=None, atol=None): if rtol is not None: try: - return np.testing.assert_allclose(actual, expected, rtol=rtol) + return np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol) except TypeError: print("Caught type error, defaulting to regular compare") np.testing.assert_array_equal(actual, expected) -def _assert_sparse_data_equivalence(actual, expected, rtol=None): - _assert_allclose_or_eq(actual.data, expected.data, rtol=rtol) - _assert_allclose_or_eq(actual.indices, expected.indices, rtol=rtol) - _assert_allclose_or_eq(actual.indptr, expected.indptr, rtol=rtol) - _assert_allclose_or_eq(actual.shape, expected.shape, rtol=rtol) +def _assert_sparse_data_equivalence(actual, expected, rtol=None, atol=None): + _assert_allclose_or_eq(actual.data, expected.data, rtol=rtol, atol=atol) + _assert_allclose_or_eq(actual.indices, expected.indices, rtol=rtol, atol=atol) + _assert_allclose_or_eq(actual.indptr, expected.indptr, rtol=rtol, atol=atol) + _assert_allclose_or_eq(actual.shape, expected.shape, rtol=rtol, atol=atol) -def _assert_numpy_explanation_types(actual, expected, rtol=None): +def _assert_numpy_explanation_types(actual, expected, rtol=None, atol=None): # assert "_" variables equivalence if hasattr(actual, ExplainParams.get_private(ExplainParams.LOCAL_IMPORTANCE_VALUES)): assert(isinstance(actual._local_importance_values, np.ndarray)) @@ -161,14 +161,19 @@ def _assert_numpy_explanation_types(actual, expected, rtol=None): expected._local_importance_values) else: np.testing.assert_allclose(actual._local_importance_values, - expected._local_importance_values, rtol=rtol) + expected._local_importance_values, + rtol=rtol, + atol=atol) if hasattr(actual, ExplainParams.get_private(ExplainParams.EVAL_DATA)): assert(isinstance(actual._eval_data, np.ndarray)) assert(isinstance(expected._eval_data, np.ndarray)) if rtol is None: np.testing.assert_array_equal(actual._eval_data, expected._eval_data) else: - np.testing.assert_allclose(actual._eval_data, expected._eval_data, rtol=rtol) + np.testing.assert_allclose(actual._eval_data, + expected._eval_data, + rtol=rtol, + atol=atol) # performs serialization and de-serialization for any explanation