Skip to content

Commit

Permalink
reduce evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Oct 23, 2024
1 parent 8c21ef4 commit 098efbe
Showing 1 changed file with 17 additions and 65 deletions.
82 changes: 17 additions & 65 deletions examples/graphbolt/rgcn/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,80 +6,32 @@
torch = None


### Evaluator for node property prediction
class IGB_Evaluator:
def __init__(self, name, num_tasks, eval_metric):
def __init__(self, name, num_tasks):
self.name = name
self.num_tasks = num_tasks
self.eval_metric = eval_metric

def _parse_and_check_input(self, input_dict):
if self.eval_metric == "acc":
if not "y_true" in input_dict:
raise RuntimeError("Missing key of y_true")
if not "y_pred" in input_dict:
raise RuntimeError("Missing key of y_pred")
def _parse_input(self, input_dict):
y_true, y_pred = input_dict["y_true"], input_dict["y_pred"]

y_true, y_pred = input_dict["y_true"], input_dict["y_pred"]
if torch and isinstance(y_true, torch.Tensor):
y_true = y_true.cpu().numpy()
if torch and isinstance(y_pred, torch.Tensor):
y_pred = y_pred.cpu().numpy()

"""
y_true: numpy ndarray or torch tensor of shape (num_nodes num_tasks)
y_pred: numpy ndarray or torch tensor of shape (num_nodes num_tasks)
"""
if not isinstance(y_true, np.ndarray) or not isinstance(
y_pred, np.ndarray
):
raise RuntimeError("Arguments must be numpy arrays")

# converting to torch.Tensor to numpy on cpu
if torch is not None and isinstance(y_true, torch.Tensor):
y_true = y_true.detach().cpu().numpy()
if y_true.shape != y_pred.shape or y_true.ndim != 2:
raise RuntimeError("Shape mismatch between y_true and y_pred")

if torch is not None and isinstance(y_pred, torch.Tensor):
y_pred = y_pred.detach().cpu().numpy()

## check type
if not (
isinstance(y_true, np.ndarray)
and isinstance(y_true, np.ndarray)
):
raise RuntimeError(
"Arguments to Evaluator need to be either numpy ndarray or torch tensor"
)

if not y_true.shape == y_pred.shape:
raise RuntimeError(
"Shape of y_true and y_pred must be the same"
)

if not y_true.ndim == 2:
raise RuntimeError(
"y_true and y_pred must to 2-dim arrray, {}-dim array given".format(
y_true.ndim
)
)

if not y_true.shape[1] == self.num_tasks:
raise RuntimeError(
"Number of tasks for {} should be {} but {} given".format(
self.name, self.num_tasks, y_true.shape[1]
)
)

return y_true, y_pred

else:
raise ValueError("Undefined eval metric %s " % (self.eval_metric))
return y_true, y_pred

def _eval_acc(self, y_true, y_pred):
acc_list = []

for i in range(y_true.shape[1]):
is_labeled = y_true[:, i] == y_true[:, i]
correct = y_true[is_labeled, i] == y_pred[is_labeled, i]
acc_list.append(float(np.sum(correct)) / len(correct))

return {"acc": sum(acc_list) / len(acc_list)}
return {"acc": np.mean(np.all(y_true == y_pred, axis=1))}

def eval(self, input_dict):
if self.eval_metric == "acc":
y_true, y_pred = self._parse_and_check_input(input_dict)
return self._eval_acc(y_true, y_pred)
else:
raise ValueError("Undefined eval metric %s " % (self.eval_metric))
y_true, y_pred = self._parse_input(input_dict)
return self._eval_acc(y_true, y_pred)

0 comments on commit 098efbe

Please sign in to comment.