Skip to content

Commit

Permalink
Fix a bug in refine.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yangwenzhuo08 committed Aug 25, 2022
1 parent ef432ed commit 622a5f9
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions omnixai/explainers/tabular/counterfactual/mace/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
import numpy as np
import pandas as pd
from typing import Dict, Callable, Union

Expand Down Expand Up @@ -45,7 +44,7 @@ def _refine(

for col, (a, b) in cont_features.items():
gap, r = b - a, None
while (b - a) / (gap + 1e-3) > 0.1:
while (b - a) / gap > 0.1:
z = (a + b) * 0.5
y.iloc[0, column2loc[col]] = z
scores = predict_function(Tabular(data=y, categorical_columns=instance.categorical_columns))[0]
Expand Down Expand Up @@ -83,8 +82,8 @@ def refine(
cont_features = {}
for col in self.cont_columns:
a, b = float(x[col].values[0]), float(y[col].values[0])
if a != b:
cont_features[col] = (a, b) if a <= b else (b, a)
if abs(a - b) > 1e-6:
cont_features[col] = (a, b)
if len(cont_features) == 0:
results.append(y)
else:
Expand Down

0 comments on commit 622a5f9

Please sign in to comment.