Doc improve #111

Merged
merged 2 commits into from
Jan 24, 2024
26 changes: 26 additions & 0 deletions docs/source/guide/install.md
@@ -215,6 +215,32 @@ Even if you have CUDA 12, your CUDA version might still be lower than the versio
In this case, try to install `jax[cuda11]`.
```

### AMD GPU (ROCM)

Although ROCm support is still considered experimental, setting up an AMD GPU with ROCm is surprisingly straightforward thanks to AMD's open-source drivers. However, only a limited number of GPUs are currently supported, notably the Radeon RX 7900 XTX and Radeon PRO W7900 among consumer-grade cards. Note that Windows is not currently supported.

#### Install GPU driver

Since the AMD driver is open-source, installation is straightforward: install Mesa through your Linux distribution's package manager. In many cases, the driver is already pre-installed.
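
As a rough sketch (package names are an assumption and vary by distribution), on a Debian/Ubuntu-based system this might look like:

```bash
# Install the Mesa userspace drivers; the amdgpu kernel module usually ships with the kernel itself
sudo apt update
sudo apt install mesa-vulkan-drivers libgl1-mesa-dri
```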

To verify that the driver is installed, run the following command:

```bash
lsmod | grep amdgpu
```

You should see `amdgpu` in the output.

#### Install ROCm

The latest version of ROCm (v5.7.1 or later) may not be available in your Linux distribution's package manager. Therefore, using a containerized environment is the easiest way to get started.

```bash
docker run -it --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined rocm/jax:latest
```

Please visit [Docker Hub](https://hub.docker.com/r/rocm/jax) for further instructions.
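
Once inside the container, you can quickly confirm that JAX sees the GPU (a minimal sketch; the exact device listing depends on your GPU and JAX version):

```bash
# Should list at least one GPU device rather than only the CPU
python -c "import jax; print(jax.devices())"
```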

## Verify your installation

Open a Python terminal, and run the following:
8 changes: 8 additions & 0 deletions docs/source/miscellaneous/high_vram_usage.md
@@ -0,0 +1,8 @@
# High VRAM usage

By default, JAX preallocates 75% of the total GPU memory, regardless of the program you run.
This preallocation avoids memory fragmentation and improves performance.

To disable this behavior, set the environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false`.
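
For example (a minimal sketch; `my_program.py` is a placeholder for your own script):

```bash
# Disable preallocation so JAX allocates GPU memory on demand
XLA_PYTHON_CLIENT_PREALLOCATE=false python my_program.py
```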

For more information, please refer to the [JAX documentation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html).
8 changes: 8 additions & 0 deletions docs/source/miscellaneous/index.md
@@ -0,0 +1,8 @@
# Miscellaneous

```{toctree}
:maxdepth: 1

selecting_gpu
high_vram_usage
```
19 changes: 19 additions & 0 deletions docs/source/miscellaneous/selecting_gpu.md
@@ -0,0 +1,19 @@
# Selecting a GPU

To run your program on a specific GPU, use the `CUDA_VISIBLE_DEVICES` environment variable. For example, to run your program on the second GPU:

```bash
CUDA_VISIBLE_DEVICES=1 python my_program.py
```

To run your program on multiple GPUs, you can use:

```bash
CUDA_VISIBLE_DEVICES=0,1 python my_program.py
```

To disable GPU usage, you can use:

```bash
CUDA_VISIBLE_DEVICES="" python my_program.py
```
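
To confirm which devices JAX actually sees under a given setting, a quick check like the following can help (a minimal sketch; the printed device names depend on your hardware):

```bash
# With CUDA_VISIBLE_DEVICES=1, JAX should report a single GPU, enumerated as device 0 inside the process
CUDA_VISIBLE_DEVICES=1 python -c "import jax; print(jax.devices())"
```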