Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diagnostic tool to look for potential evaluation errors #1268

Merged
merged 15 commits into from
Nov 15, 2023
Merged
257 changes: 255 additions & 2 deletions idaes/core/util/model_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from operator import itemgetter
import sys
from inspect import signature
import math
from math import log
from typing import List, Sequence

import numpy as np
from scipy.linalg import svd
Expand All @@ -29,6 +31,7 @@

from pyomo.environ import (
Binary,
Integers,
Block,
check_optimal_termination,
ConcreteModel,
Expand All @@ -39,10 +42,21 @@
SolverFactory,
value,
Var,
is_fixed,
)
from pyomo.core.expr.numeric_expr import (
DivisionExpression,
NPV_DivisionExpression,
PowExpression,
NPV_PowExpression,
UnaryFunctionExpression,
NPV_UnaryFunctionExpression,
NumericExpression,
)
from pyomo.core.base.block import _BlockData
from pyomo.core.base.var import _VarData
from pyomo.core.base.var import _GeneralVarData, _VarData
from pyomo.core.base.constraint import _ConstraintData
from pyomo.repn.standard_repn import generate_standard_repn
from pyomo.common.collections import ComponentSet
from pyomo.common.config import (
ConfigDict,
Expand All @@ -52,9 +66,10 @@
)
from pyomo.util.check_units import identify_inconsistent_units
from pyomo.contrib.incidence_analysis import IncidenceGraphInterface
from pyomo.core.expr.visitor import identify_variables
from pyomo.core.expr.visitor import identify_variables, StreamBasedExpressionVisitor
from pyomo.contrib.pynumero.interfaces.pyomo_nlp import PyomoNLP
from pyomo.contrib.pynumero.asl import AmplInterface
from pyomo.contrib.fbbt.fbbt import compute_bounds_on_expr
from pyomo.common.deprecation import deprecation_warning
from pyomo.common.errors import PyomoException

Expand Down Expand Up @@ -1229,6 +1244,10 @@
warnings, next_steps = self._collect_structural_warnings()
cautions = self._collect_structural_cautions()

eval_error_warnings, eval_error_cautions = self._collect_potential_eval_errors()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Is there an easy way this could be moved inside the _collect_structural_warnings and _collect_structural_cautions methods? It looks like you did this for efficiency so you only need to walk the model once. However, there is an assert_no_structural_issues method as well (that only checks for warnings) which could not include this as written (although the solution might be to just add this code there as well). I am also working on adding diagnostics to the convergence tester (parameter sweep) tool, and I am looking at using these methods there too (which would mean they need to become public).
  2. Could you add an entry to next_steps as well pointing users to the associated display method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I have addressed both of these now. I realized I was being silly with the cautions. They were completely unnecessary/wrong.

warnings.extend(eval_error_warnings)
cautions.extend(eval_error_cautions)

_write_report_section(
stream=stream, lines_list=stats, title="Model Statistics", header="="
)
Expand Down Expand Up @@ -1303,6 +1322,55 @@
footer="=",
)

def _collect_potential_eval_errors(self):
res = list()
warnings = list()
cautions = list()
for con in self.model.component_data_objects(
Constraint, active=True, descend_into=True
):
walker = _EvalErrorWalker()
con_warnings, con_cautions = walker.walk_expression(con.body)
for msg in con_warnings:
msg = f"{con.name}: " + msg
warnings.append(msg)
for msg in con_cautions:
msg = f"{con.name}: " + msg
cautions.append(msg)
for obj in self.model.component_data_objects(
Objective, active=True, descend_into=True
):
walker = _EvalErrorWalker()
obj_warnings, obj_cautions = walker.walk_expression(obj.expr)
for msg in obj_warnings:
msg = f"{obj.name}: " + msg
warnings.append(msg)
for msg in obj_cautions:
msg = f"{obj.name}: " + msg
cautions.append(msg)

Check warning on line 1350 in idaes/core/util/model_diagnostics.py

View check run for this annotation

Codecov / codecov/patch

idaes/core/util/model_diagnostics.py#L1349-L1350

Added lines #L1349 - L1350 were not covered by tests

return warnings, cautions

def display_potential_evaluation_errors(self, stream=None):
if stream is None:
stream = sys.stdout

Check warning on line 1356 in idaes/core/util/model_diagnostics.py

View check run for this annotation

Codecov / codecov/patch

idaes/core/util/model_diagnostics.py#L1356

Added line #L1356 was not covered by tests

warnings, cautions = self._collect_potential_eval_errors()
_write_report_section(
stream=stream,
lines_list=warnings,
title=f"{len(warnings)} WARNINGS",
line_if_empty="No warnings found!",
header="=",
)
_write_report_section(
stream=stream,
lines_list=cautions,
title=f"{len(cautions)} Cautions",
line_if_empty="No cautions found!",
footer="=",
)

@document_kwargs_from_configdict(SVDCONFIG)
def prepare_svd_toolbox(self, **kwargs):
"""
Expand Down Expand Up @@ -1623,6 +1691,191 @@
)


def _get_bounds_with_inf(node: NumericExpression):
lb, ub = compute_bounds_on_expr(node)
if lb is None:
lb = -math.inf
if ub is None:
ub = math.inf
return lb, ub


def _caution_expression_argument(
node: NumericExpression,
args_to_check: Sequence[NumericExpression],
caution_list: List[str],
):
should_caution = False
for arg in args_to_check:
if is_fixed(arg):
continue
if isinstance(arg, _GeneralVarData):
continue
should_caution = True
break
if should_caution:
msg = f"Potential evaluation error in {node}; "
msg += "arguments are expressions with bounds that are not strictly "
msg += "enforced;"
caution_list.append(msg)


def _check_eval_error_division(
node: NumericExpression, warn_list: List[str], caution_list: List[str]
):
lb, ub = _get_bounds_with_inf(node.args[1])
if lb <= 0 <= ub:
msg = f"Potential division by 0 in {node}; Denominator bounds are ({lb}, {ub})"
warn_list.append(msg)
else:
_caution_expression_argument(node, [node.args[1]], caution_list)


def _check_eval_error_pow(
node: NumericExpression, warn_list: List[str], caution_list: List[str]
):
arg1, arg2 = node.args
lb1, ub1 = _get_bounds_with_inf(arg1)
lb2, ub2 = _get_bounds_with_inf(arg2)

integer_domains = ComponentSet([Binary, Integers])

integer_exponent = False
# if the exponent is an integer, there should not be any evaluation errors
if isinstance(arg2, _GeneralVarData) and arg2.domain in integer_domains:
# The exponent is an integer variable
# check if the base can be zero
integer_exponent = True
if lb2 == ub2 and lb2 == round(lb2):
# The exponent is fixed to an integer
integer_exponent = True
repn = generate_standard_repn(arg2, quadratic=True)
if (
repn.nonlinear_expr is None
and repn.constant == round(repn.constant)
and all(i.domain in integer_domains for i in repn.linear_vars)
and all(i[0].domain in integer_domains for i in repn.quadratic_vars)
and all(i[1].domain in integer_domains for i in repn.quadratic_vars)
and all(i == round(i) for i in repn.linear_coefs)
and all(i == round(i) for i in repn.quadratic_coefs)
):
# The exponent is a linear or quadratic expression containing
# only integer variables with integer coefficients
integer_exponent = True

if integer_exponent and (lb1 > 0 or ub1 < 0):
# life is good; the exponent is an integer and the base is nonzero
return None
elif integer_exponent and lb2 >= 0:
# life is good; the exponent is a nonnegative integer
return None

# if the base is positive, there should not be any evaluation errors
if lb1 > 0:
_caution_expression_argument(node, node.args, caution_list)
return None
if lb1 >= 0 and lb2 >= 0:
_caution_expression_argument(node, node.args, caution_list)
return None

msg = f"Potential evaluation error in {node}; "
msg += f"base bounds are ({lb1}, {ub1}); "
msg += f"exponent bounds are ({lb2}, {ub2})"
warn_list.append(msg)


def _check_eval_error_log(
node: NumericExpression, warn_list: List[str], caution_list: List[str]
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb <= 0:
msg = f"Potential log of a non-positive number in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)
else:
_caution_expression_argument(node, node.args, caution_list)


def _check_eval_error_tan(
node: NumericExpression, warn_list: List[str], caution_list: List[str]
):
lb, ub = _get_bounds_with_inf(node)
if not (math.isfinite(lb) and math.isfinite(ub)):
msg = f"{node} may evaluate to -inf or inf; Argument bounds are {_get_bounds_with_inf(node.args[0])}"
warn_list.append(msg)
else:
_caution_expression_argument(node, node.args, caution_list)


def _check_eval_error_asin(
node: NumericExpression, warn_list: List[str], caution_list: List[str]
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f"Potential evaluation of asin outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)
else:
_caution_expression_argument(node, node.args, caution_list)


def _check_eval_error_acos(
node: NumericExpression, warn_list: List[str], caution_list: List[str]
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f"Potential evaluation of acos outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)
else:
_caution_expression_argument(node, node.args, caution_list)


def _check_eval_error_sqrt(
node: NumericExpression, warn_list: List[str], caution_list: List[str]
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < 0:
msg = f"Potential square root of a negative number in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)
else:
_caution_expression_argument(node, node.args, caution_list)


_unary_eval_err_handler = dict()
_unary_eval_err_handler["log"] = _check_eval_error_log
_unary_eval_err_handler["log10"] = _check_eval_error_log
_unary_eval_err_handler["tan"] = _check_eval_error_tan
_unary_eval_err_handler["asin"] = _check_eval_error_asin
_unary_eval_err_handler["acos"] = _check_eval_error_acos
_unary_eval_err_handler["sqrt"] = _check_eval_error_sqrt


def _check_eval_error_unary(
node: NumericExpression, warn_list: List[str], caution_list: List[str]
):
if node.getname() in _unary_eval_err_handler:
_unary_eval_err_handler[node.getname()](node, warn_list, caution_list)


_eval_err_handler = dict()
_eval_err_handler[DivisionExpression] = _check_eval_error_division
_eval_err_handler[NPV_DivisionExpression] = _check_eval_error_division
_eval_err_handler[PowExpression] = _check_eval_error_pow
_eval_err_handler[NPV_PowExpression] = _check_eval_error_pow
_eval_err_handler[UnaryFunctionExpression] = _check_eval_error_unary
_eval_err_handler[NPV_UnaryFunctionExpression] = _check_eval_error_unary


class _EvalErrorWalker(StreamBasedExpressionVisitor):
def __init__(self):
super().__init__()
self._warn_list = list()
self._caution_list = list()

def exitNode(self, node, data):
if type(node) in _eval_err_handler:
_eval_err_handler[type(node)](node, self._warn_list, self._caution_list)
return self._warn_list, self._caution_list


# TODO: Rename and redirect once old DegeneracyHunter is removed
@document_kwargs_from_configdict(DHCONFIG)
class DegeneracyHunter2:
Expand Down
Loading
Loading