diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index ec396fbcc6117..d5c3af748e528 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -79,6 +79,9 @@ function(AddTest) if (onnxruntime_USE_NCCL) target_include_directories(${_UT_TARGET} PRIVATE ${NCCL_INCLUDE_DIRS}) endif() + if(onnxruntime_CUDA_MINIMAL) + target_compile_definitions(${_UT_TARGET} PRIVATE -DUSE_CUDA_MINIMAL) + endif() endif() if (onnxruntime_USE_TENSORRT) # used for instantiating placeholder TRT builder to mitigate TRT library load/unload overhead diff --git a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc index f700e31003012..027d4b3fff1b0 100644 --- a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. // BiasDropout kernel is only implemented for CUDA/ROCM -#if defined(USE_CUDA) || defined(USE_ROCM) +#if (defined(USE_CUDA) && !defined(USE_CUDA_MINIMAL)) || defined(USE_ROCM) #ifdef _MSC_VER #pragma warning(disable : 4389) diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 4e9e80b180e9c..43d3782be3280 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -2078,6 +2078,7 @@ TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape"; int gather_count = 0; + ASSERT_GT(plan->execution_plan.size(), 1) << "Number of execution plans should be greater than 1"; for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) { if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) { const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex());