From ff5b714f27922e6aa0a9954a1dea5c5ed0f0d80b Mon Sep 17 00:00:00 2001 From: Adarsh Yoga Date: Thu, 11 Apr 2024 13:01:29 -0500 Subject: [PATCH] fixing dpnp failure caused by addition of nin, nout and ntypes --- numba_dpex/dpnp_iface/dpnp_ufunc_db.py | 48 +++++++++++++++++++++----- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/numba_dpex/dpnp_iface/dpnp_ufunc_db.py b/numba_dpex/dpnp_iface/dpnp_ufunc_db.py index def71adb17..fae113ee16 100644 --- a/numba_dpex/dpnp_iface/dpnp_ufunc_db.py +++ b/numba_dpex/dpnp_iface/dpnp_ufunc_db.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import copy +import logging import dpnp import numpy as np @@ -56,6 +57,7 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db): # variable is passed by value from numba.np.ufunc_db import _ufunc_db + failed_dpnpop_types_lst = [] for ufuncop in dpnpdecl.supported_ufuncs: if ufuncop == "erf": op = getattr(dpnp, "erf") @@ -72,20 +74,48 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db): "d->d": mathimpl.lower_ocl_impl[("erf", (_unary_d_d))], } else: - op = getattr(dpnp, ufuncop) + dpnpop = getattr(dpnp, ufuncop) npop = getattr(np, ufuncop) - op.nin = npop.nin - op.nout = npop.nout - op.nargs = npop.nargs - op.types = npop.types - op.is_dpnp_ufunc = True + if not hasattr(dpnpop, "nin"): + dpnpop.nin = npop.nin + if not hasattr(dpnpop, "nout"): + dpnpop.nout = npop.nout + if not hasattr(dpnpop, "nargs"): + dpnpop.nargs = dpnpop.nin + dpnpop.nout + + # Check for `types` attribute for dpnp op. + # AttributeError: + # If the `types` attribute is not present for dpnp op, + # use the `types` attribute from corresponding numpy op. + # ValueError: + # Store all dpnp ops that failed when `types` attribute + # is present but failure occurs when read. + # Log all failing dpnp outside this loop. + try: + dpnpop.types + except ValueError: + failed_dpnpop_types_lst.append(ufuncop) + except AttributeError: + dpnpop.types = npop.types + + dpnpop.is_dpnp_ufunc = True cp = copy.copy(_ufunc_db[npop]) - ufunc_db.update({op: cp}) - for key in list(ufunc_db[op].keys()): + ufunc_db.update({dpnpop: cp}) + for key in list(ufunc_db[dpnpop].keys()): if ( "FF->" in key or "DD->" in key or "F->" in key or "D->" in key ): - ufunc_db[op].pop(key) + ufunc_db[dpnpop].pop(key) + + if failed_dpnpop_types_lst: + try: + getattr(dpnp, failed_dpnpop_types_lst[0]).types + except ValueError: + ops = " ".join(failed_dpnpop_types_lst) + logging.exception( + "The types attribute for the following dpnp ops could not be " + f"determined: {ops}" + )