From b3d35ec9c6eb0a27640982b4f1c7c9e968907d2d Mon Sep 17 00:00:00 2001 From: "Yaxun (Sam) Liu" Date: Sun, 25 Feb 2024 23:06:03 -0500 Subject: [PATCH] Add HIP test host-min-max --- External/HIP/CMakeLists.txt | 1 + External/HIP/host-min-max.hip | 69 +++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 External/HIP/host-min-max.hip diff --git a/External/HIP/CMakeLists.txt b/External/HIP/CMakeLists.txt index dda0ad4f8d..ce53c99cf0 100644 --- a/External/HIP/CMakeLists.txt +++ b/External/HIP/CMakeLists.txt @@ -12,6 +12,7 @@ macro(create_local_hip_tests VariantSuffix) list(APPEND HIP_LOCAL_TESTS empty) list(APPEND HIP_LOCAL_TESTS with-fopenmp) list(APPEND HIP_LOCAL_TESTS saxpy) + list(APPEND HIP_LOCAL_TESTS host-min-max) list(APPEND HIP_LOCAL_TESTS InOneWeekend) list(APPEND HIP_LOCAL_TESTS TheNextWeek) diff --git a/External/HIP/host-min-max.hip b/External/HIP/host-min-max.hip new file mode 100644 index 0000000000..f25ed1ea09 --- /dev/null +++ b/External/HIP/host-min-max.hip @@ -0,0 +1,69 @@ +#include +#include +#include +#include +#include // For std::abort +#include // For typeid + +std::string demangle(const char* name) { + if (std::string(name) == "i") return "int"; + else if (std::string(name) == "f") return "float"; + else if (std::string(name) == "d") return "double"; + else if (std::string(name) == "j") return "unsigned int"; + else if (std::string(name) == "l") return "long"; + else if (std::string(name) == "m") return "unsigned long"; + else if (std::string(name) == "x") return "long long"; + else if (std::string(name) == "y") return "unsigned long long"; + else return std::string(name); +} + +void checkHipCall(hipError_t status, const char* msg) { + if (status != hipSuccess) { + std::cerr << "HIP Error: " << msg << " - " << hipGetErrorString(status) << std::endl; + std::abort(); + } +} + +template +void compareResults(T1 hipResult, T2 stdResult, const std::string& testName) { + using CommonType = typename std::common_type::type; + if (static_cast(hipResult) != static_cast(stdResult)) { + std::cerr << testName << " mismatch: HIP result " << hipResult << " (" << demangle(typeid(hipResult).name()) << "), std result " << stdResult << " (" << demangle(typeid(stdResult).name()) << ")" << std::endl; + std::abort(); + } +} + +template +void runTest(T1 a, T2 b) { + std::cout << "\nTesting with values: " << a << " (" << demangle(typeid(a).name()) << ") and " << b << " (" << demangle(typeid(b).name()) << ")" << std::endl; + + // Using std::min and std::max explicitly for host code to ensure clarity and correctness + using CommonType = typename std::common_type::type; + CommonType stdMinResult = std::min(a, b); + CommonType stdMaxResult = std::max(a, b); + std::cout << "Host std::min result: " << stdMinResult << " (Type: " << demangle(typeid(stdMinResult).name()) << ")" << std::endl; + std::cout << "Host std::max result: " << stdMaxResult << " (Type: " << demangle(typeid(stdMaxResult).name()) << ")" << std::endl; + + // Using HIP's global min/max functions + CommonType hipMinResult = min(a, b); // Note: This directly uses HIP's min, assuming it's correctly overloaded for host code + CommonType hipMaxResult = max(a, b); // Note: This directly uses HIP's max, assuming it's correctly overloaded for host code + std::cout << "Host HIP min result: " << hipMinResult << " (Type: " << demangle(typeid(hipMinResult).name()) << ")" << std::endl; + std::cout << "Host HIP max result: " << hipMaxResult << " (Type: " << demangle(typeid(hipMaxResult).name()) << ")" << std::endl; + + // Ensure the host HIP and std results match + compareResults(hipMinResult, stdMinResult, "HIP vs std min"); + compareResults(hipMaxResult, stdMaxResult, "HIP vs std max"); +} + +int main() { + checkHipCall(hipSetDevice(0), "hipSetDevice failed"); + + runTest(10uLL, -5LL); // Testing with unsigned int and long long + runTest(-15, 20u); // Testing with int and unsigned int + runTest(2147483647, 2147483648u); // Testing with int and unsigned int + runTest(-922337203685477580LL, 922337203685477580uLL); // Testing with long long and unsigned long long + runTest(2.5f, 3.14159); // Testing with float and double + + std::cout << "\nPass\n"; // Output "Pass" at the end if all tests pass without aborting + return 0; +}