From 622a5f90028550dd1533eb2c66ad9800106b7d00 Mon Sep 17 00:00:00 2001 From: yangwenzhuo08 Date: Thu, 25 Aug 2022 17:07:01 +0800 Subject: [PATCH] Fix a bug in refine.py --- omnixai/explainers/tabular/counterfactual/mace/refine.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/omnixai/explainers/tabular/counterfactual/mace/refine.py b/omnixai/explainers/tabular/counterfactual/mace/refine.py index 6c4e0f8f..298789cb 100644 --- a/omnixai/explainers/tabular/counterfactual/mace/refine.py +++ b/omnixai/explainers/tabular/counterfactual/mace/refine.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # -import numpy as np import pandas as pd from typing import Dict, Callable, Union @@ -45,7 +44,7 @@ def _refine( for col, (a, b) in cont_features.items(): gap, r = b - a, None - while (b - a) / (gap + 1e-3) > 0.1: + while (b - a) / gap > 0.1: z = (a + b) * 0.5 y.iloc[0, column2loc[col]] = z scores = predict_function(Tabular(data=y, categorical_columns=instance.categorical_columns))[0] @@ -83,8 +82,8 @@ def refine( cont_features = {} for col in self.cont_columns: a, b = float(x[col].values[0]), float(y[col].values[0]) - if a != b: - cont_features[col] = (a, b) if a <= b else (b, a) + if abs(a - b) > 1e-6: + cont_features[col] = (a, b) if len(cont_features) == 0: results.append(y) else: