diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 67bde37a33..1a1f84c884 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -13,6 +13,7 @@ # limitations under the License. """Utility functions, classes and constants for ONNX Runtime.""" +import importlib import os import re from enum import Enum @@ -31,7 +32,6 @@ import onnxruntime as ort from ..exporters.onnx import OnnxConfig, OnnxConfigWithLoss -from ..utils.import_utils import _is_package_available if TYPE_CHECKING: @@ -91,9 +91,11 @@ def is_onnxruntime_training_available(): def is_cupy_available(): """ - Checks if onnxruntime-training is available. + Checks if CuPy is available. """ - return _is_package_available("cupy") + # Don't use _is_package_available as it doesn't work with CuPy installed + # with `cupy-cuda*` and `cupy-rocm-*` package name (prebuilt wheels). + return importlib.util.find_spec("cupy") is not None class ORTConfigManager: