Skip to content

Commit

Permalink
Try hermetic CUDA builds (matching system installed CUDA and cuDNN version) (#999)
Browse files Browse the repository at this point in the history

This is only a temporary workaround (WAR) to unblock the JAX and DLFW image builds.
The resulting image sizes do not change compared to the ones we built
using the system CUDA libraries.
  • Loading branch information
gpupuck committed Aug 17, 2024
1 parent d7bad4b commit 3f35bf3
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,13 @@ export TF_CUDNN_PATHS=/usr/lib/$(uname -p)-linux-gnu
export TF_CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3-4)
export TF_CUDA_MAJOR_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3)
export TF_CUBLAS_VERSION=$(ls /usr/local/cuda/lib64/libcublas.so.*.*.* | cut -d . -f 3)
export TF_CUDNN_VERSION=$(echo "${NV_CUDNN_VERSION}" | cut -d . -f 1)
export TF_NCCL_VERSION=$(echo "${NCCL_VERSION}" | cut -d . -f 1)

TF_CUDNN_MAJOR_VERSION=$(grep "#define CUDNN_MAJOR" /usr/include/cudnn_version.h | awk '{print $3}')
TF_CUDNN_MINOR_VERSION=$(grep "#define CUDNN_MINOR" /usr/include/cudnn_version.h | awk '{print $3}')
TF_CUDNN_PATCHLEVEL_VERSION=$(grep "#define CUDNN_PATCHLEVEL" /usr/include/cudnn_version.h | awk '{print $3}')
export TF_CUDNN_VERSION="${TF_CUDNN_MAJOR_VERSION}.${TF_CUDNN_MINOR_VERSION}.${TF_CUDNN_PATCHLEVEL_VERSION}"

case "${CPU_ARCH}" in
"amd64")
export CC_OPT_FLAGS="-march=sandybridge -mtune=broadwell"
Expand Down Expand Up @@ -262,10 +266,8 @@ time python "${SRC_PATH_JAX}/build/build.py" \
--enable_cuda \
--build_gpu_plugin \
--gpu_plugin_cuda_version=$TF_CUDA_MAJOR_VERSION \
--cuda_path=$TF_CUDA_PATHS \
--cudnn_path=$TF_CUDNN_PATHS \
--cuda_version=$TF_CUDA_VERSION \
--cudnn_version=$TF_CUDNN_VERSION \
--bazel_options=--repo_env=HERMETIC_CUDA_VERSION=$CUDA_VERSION \
--bazel_options=--repo_env=HERMETIC_CUDNN_VERSION=$TF_CUDNN_VERSION \
--cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \
--enable_nccl=true \
--bazel_options=--linkopt=-fuse-ld=lld \
Expand All @@ -282,8 +284,7 @@ if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_gpu_pjrt" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_gpu_plugin" >> build/requirements.in
PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))')
#bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
python build/build.py --requirements_update --python_version=${PYTHON_VERSION}
bazel run --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
popd
fi
## Install the built packages
Expand All @@ -295,12 +296,8 @@ else
pip uninstall -y jaxlib jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin
fi

# install jaxlib
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_gpu_pjrt -e ${BUILD_PATH_JAXLIB}/jax_gpu_plugin
# install jax
if [[ "$JAXLIB_ONLY" == "0" ]]; then
pip --disable-pip-version-check install -e "${SRC_PATH_JAX}"
fi
# install jax and jaxlib
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_gpu_pjrt -e ${BUILD_PATH_JAXLIB}/jax_gpu_plugin -e "${SRC_PATH_JAX}"

# after installation (example)
# jax 0.4.32.dev20240808+9c2caedab /opt/jax
Expand Down

0 comments on commit 3f35bf3

Please sign in to comment.