Launch docker run -it --name jaxrocm --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video rocm/jax-build:rocm6.0.0-jax0.4.23-py3.11.0-profiler-fix
Create test.py with the following content:
import jax
from jax import random, numpy as jnp

def main():
    # Print JAX backend and available devices
    print("JAX backend:", jax.lib.xla_bridge.get_backend().platform)
    devices = jax.devices()
    print("Available devices:", devices)

    # Generate random matrices
    key = random.PRNGKey(0)
    x = random.normal(key, (5000, 5000), dtype=jnp.float32)
    y = random.normal(key, (5000, 5000), dtype=jnp.float32)

    # Matrix multiplication
    result = jnp.dot(x, y)

    # Print the shape of the result to confirm computation
    print("Result shape:", result.shape)

if __name__ == "__main__":
    main()
I get the following output when I run this code:
2024-02-18 11:32:29.650365: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
JAX backend: cpu
Available devices: [CpuDevice(id=0)]
Result shape: (5000, 5000)
Why isn't it using the GPU?
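To narrow this down, it may help to check from inside the container whether the device nodes passed with --device are visible and which platform-related environment variables are set. The following is a minimal diagnostic sketch, not an official ROCm or JAX tool. The helper names check_device_nodes, report_env and jax_backend are hypothetical, and the list of environment variables is an assumption about what could be relevant here. It uses only the standard library, plus jax.default_backend() if JAX can be imported.

```python
import os

# Device nodes passed to the container via --device in the docker command.
DEVICE_PATHS = ("/dev/kfd", "/dev/dri")

# Environment variables that can influence platform or device selection.
# Which of them matter for this particular setup is an assumption.
ENV_VARS = (
    "JAX_PLATFORMS",
    "HIP_VISIBLE_DEVICES",
    "ROCR_VISIBLE_DEVICES",
    "HSA_OVERRIDE_GFX_VERSION",
)


def check_device_nodes(paths=DEVICE_PATHS):
    """Return {path: (exists, read_write_ok)} for each device path."""
    status = {}
    for path in paths:
        exists = os.path.exists(path)
        accessible = exists and os.access(path, os.R_OK | os.W_OK)
        status[path] = (exists, accessible)
    return status


def report_env(names=ENV_VARS):
    """Return {name: value or None} for each environment variable."""
    return {name: os.environ.get(name) for name in names}


def jax_backend():
    """Return the default JAX backend name, or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    for path, (exists, ok) in check_device_nodes().items():
        print(f"{path}: exists={exists} read/write={ok}")
    for name, value in report_env().items():
        print(f"{name}={value!r}")
    print("JAX default backend:", jax_backend())
```

If a device node is missing or not accessible to the current user, the problem is probably in the container setup rather than in JAX itself. If both nodes are accessible and the backend is still cpu, the cause more likely lies in the ROCm or JAX build.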
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
ROCk module is loaded
=====================
HSA System Attributes
=====================
Runtime Version: 1.1
System Timestamp Freq.: 1000.000000MHz
Sig. Max Wait Duration: 18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model: LARGE
System Endianness: LITTLE
Mwaitx: DISABLED
DMAbuf Support: YES
==========
HSA Agents
==========
*******
Agent 1
*******
Name: AMD Ryzen 7 PRO 7840U w/ Radeon 780M Graphics
Uuid: CPU-XX
Marketing Name: AMD Ryzen 7 PRO 7840U w/ Radeon 780M Graphics
Vendor Name: CPU
Feature: None specified
Profile: FULL_PROFILE
Float Round Mode: NEAR
Max Queue Number: 0(0x0)
Queue Min Size: 0(0x0)
Queue Max Size: 0(0x0)
Queue Type: MULTI
Node: 0
Device Type: CPU
Cache Info:
L1: 32768(0x8000) KB
Chip ID: 0(0x0)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 5132
BDFID: 0
Internal Node ID: 0
Compute Unit: 16
SIMDs per CU: 0
Shader Engines: 0
Shader Arrs. per Eng.: 0
WatchPts on Addr. Ranges:1
Features: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 31509624(0x1e0cc78) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 2
Segment: GLOBAL; FLAGS: KERNARG, FINE GRAINED
Size: 31509624(0x1e0cc78) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 3
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 31509624(0x1e0cc78) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
ISA Info:
*******
Agent 2
*******
Name: gfx1103
Uuid: GPU-XX
Marketing Name:
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 1
Device Type: GPU
Cache Info:
L1: 32(0x20) KB
L2: 2048(0x800) KB
Chip ID: 5567(0x15bf)
ASIC Revision: 9(0x9)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 2700
BDFID: 25600
Internal Node ID: 1
Compute Unit: 12
SIMDs per CU: 2
Shader Engines: 1
Shader Arrs. per Eng.: 2
WatchPts on Addr. Ranges:4
Coherent Host Access: FALSE
Features: KERNEL_DISPATCH
Fast F16 Operation: TRUE
Wavefront Size: 32(0x20)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 32(0x20)
Max Work-item Per CU: 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 33
SDMA engine uCode:: 15
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 1048576(0x100000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 1048576(0x100000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx1103
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
FBarrier Max Size: 32
*** Done ***
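The rocminfo output above lists a GPU agent (gfx1103) even though JAX only reports CpuDevice. To compare the two programmatically, the agent list can be pulled out of the rocminfo text. The following is a rough sketch, assuming the "Agent N", "Name:" and "Device Type:" layout shown above. parse_agents is a hypothetical helper, and SAMPLE is a shortened excerpt of the output in this report.

```python
import re

# Matches the "Agent N" header lines between the rows of asterisks.
AGENT_HEADER = re.compile(r"^Agent\s+(\d+)$")


def parse_agents(rocminfo_text):
    """Return one {'agent', 'name', 'device_type'} dict per HSA agent."""
    agents = []
    current = None
    for raw in rocminfo_text.splitlines():
        line = raw.strip()
        match = AGENT_HEADER.match(line)
        if match:
            current = {"agent": int(match.group(1)), "name": None, "device_type": None}
            agents.append(current)
            continue
        if current is None or ":" not in line:
            continue
        key, _, value = line.partition(":")
        key, value = key.strip(), value.strip()
        # Keep only the first "Name" per agent; later ones belong to ISA entries.
        if key == "Name" and current["name"] is None:
            current["name"] = value
        elif key == "Device Type" and current["device_type"] is None:
            current["device_type"] = value
    return agents


# Shortened excerpt of the rocminfo output from this report.
SAMPLE = """\
*******
Agent 1
*******
Name: AMD Ryzen 7 PRO 7840U w/ Radeon 780M Graphics
Device Type: CPU
*******
Agent 2
*******
Name: gfx1103
Device Type: GPU
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx1103
"""

if __name__ == "__main__":
    for agent in parse_agents(SAMPLE):
        print(agent)
    gpus = [a["name"] for a in parse_agents(SAMPLE) if a["device_type"] == "GPU"]
    print("GPU agents:", gpus)
```

In practice the text would come from running /opt/rocm/bin/rocminfo inside the container. The GPU agent names can then be compared against what jax.devices() returns.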
Problem Description
JAX is not using the GPU via ROCm.
Operating System
Ubuntu 20.04.6 LTS (Focal Fossa)
CPU
AMD Ryzen 7 PRO 7840U w/ Radeon 780M Graphics
GPU
AMD Radeon VII
ROCm Version
ROCm 6.0.0
ROCm Component
ROCm
Additional Information
Output of vulkaninfo --summary:
Output of lspci -k | grep -EA3 'VGA|3D|Display': No response