diff --git a/README.md b/README.md index 298bedc..3d8d43c 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,7 @@ docker build \ Note: If you want to use ROCm 6.x, you need to switch to AMD version of pytorch docker as a base layer to build: ```bash docker build \ - -t opensplat:ubuntu-22.04-libtorch-torch-2.1.2-rocm-6.0.2 \ + -t opensplat:ubuntu-22.04-libtorch-2.1.2-rocm-6.0.2 \ -f Dockerfile.rocm6 . ``` @@ -116,10 +116,23 @@ To run on your own data, choose the path to an existing [COLMAP](https://colmap. There's several parameters you can tune. To view the full list: + ```bash ./opensplat --help ``` +To train a model with AMD GPU using docker container, you can use the following command as a reference: +1. Launch the docker container with the following command: +```bash +docker run -it -v ~/data:/data --device=/dev/kfd --device=/dev/dri opensplat:ubuntu-22.04-libtorch-2.1.2-rocm-6.0.2 bash +``` +2. Inside the docker container, run the following command to train the model: +```bash +export HIP_VISIBLE_DEVICES=0 +export HSA_OVERRIDE_GFX_VERSION=10.3.0 # AMD RX 6700 XT workaround +cd /code/build +./opensplat /data/banana -n 2000 +``` ## Project Goals We recently released OpenSplat, so there's lots of work to do. diff --git a/vendor/gsplat/reduce.cuh b/vendor/gsplat/reduce.cuh index 6782fc7..d4c27bb 100644 --- a/vendor/gsplat/reduce.cuh +++ b/vendor/gsplat/reduce.cuh @@ -2,13 +2,13 @@ #include #define MAX_INIT 0.0 -#define WARP_SIZE 32 +#define WARP_SIZE 64 namespace cg = cooperative_groups; __inline__ __device__ float warp_reduce_sum(float val, const int tile) { for ( int offset = tile / 2; offset > 0; offset /= 2 ) - val += __shfl_down(0xffffffff, val, offset); + val += __shfl_down(val, offset); return val; } @@ -39,7 +39,7 @@ __inline__ __device__ float block_reduce_sum(float val, const int tile) { __inline__ __device__ float warp_reduce_max(float val, const int tile) { for (int offset = tile / 2; offset > 0; offset /= 2) - val = max(val, __shfl_xor(0xffffffff, val, offset)); + val = max(val, __shfl_xor(val, offset)); return val; }