Skip to content

Commit

Permalink
[api] Disable CudaUtils logging in fork mode (#3340)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Jul 22, 2024
1 parent 1c6911c commit b23c848
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions api/src/main/java/ai/djl/util/cuda/CudaUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public final class CudaUtils {
private static final CudaLibrary LIB = loadLibrary();

private static String[] gpuInfo;
private static boolean logging = true;

private CudaUtils() {}

Expand All @@ -64,6 +65,7 @@ public static int getGpuCount() {
try {
return Integer.parseInt(gpuInfo[0]);
} catch (NumberFormatException e) {
logger.warn("Unexpected output: {}", gpuInfo[0], e);
return 0;
}
}
Expand All @@ -77,17 +79,21 @@ public static int getGpuCount() {
case 0:
return count[0];
case CudaLibrary.ERROR_NO_DEVICE:
logger.debug(
"No GPU device found: {} ({})", LIB.cudaGetErrorString(result), result);
if (logging) {
logger.debug(
"No GPU device found: {} ({})", LIB.cudaGetErrorString(result), result);
}
return 0;
case CudaLibrary.INITIALIZATION_ERROR:
case CudaLibrary.INSUFFICIENT_DRIVER:
case CudaLibrary.ERROR_NOT_PERMITTED:
default:
logger.warn(
"Failed to detect GPU count: {} ({})",
LIB.cudaGetErrorString(result),
result);
if (logging) {
logger.warn(
"Failed to detect GPU count: {} ({})",
LIB.cudaGetErrorString(result),
result);
}
return 0;
}
}
Expand Down Expand Up @@ -209,6 +215,7 @@ public static MemoryUsage getGpuMemory(Device device) {
*/
@SuppressWarnings("PMD.SystemPrintln")
public static void main(String[] args) {
logging = false;
int gpuCount = getGpuCount();
if (args.length == 0) {
if (gpuCount <= 0) {
Expand Down Expand Up @@ -262,23 +269,33 @@ private static CudaLibrary loadLibrary() {
if (files != null && files.length > 0) {
String fileName = files[0].getName();
String cudaRt = fileName.substring(0, fileName.length() - 4);
logger.debug("Found cudart: {}", files[0].getAbsolutePath());
if (logging) {
logger.debug("Found cudart: {}", files[0].getAbsolutePath());
}
return Native.load(cudaRt, CudaLibrary.class);
}
}
logger.debug("No cudart library found in path.");
if (logging) {
logger.debug("No cudart library found in path.");
}
return null;
}
return Native.load("cudart", CudaLibrary.class);
} catch (UnsatisfiedLinkError e) {
logger.debug("cudart library not found.");
logger.trace("", e);
if (logging) {
logger.debug("cudart library not found.");
logger.trace("", e);
}
} catch (LinkageError e) {
logger.warn("You have a conflict version of JNA in the classpath.");
logger.debug("", e);
if (logging) {
logger.warn("You have a conflict version of JNA in the classpath.");
logger.debug("", e);
}
} catch (SecurityException e) {
logger.warn("Access denied during loading cudart library.");
logger.trace("", e);
if (logging) {
logger.warn("Access denied during loading cudart library.");
logger.trace("", e);
}
}
return null;
}
Expand Down

0 comments on commit b23c848

Please sign in to comment.