diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index 1aaccbe69..d174c098b 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -166,9 +166,13 @@ export TF_CUDNN_PATHS=/usr/lib/$(uname -p)-linux-gnu export TF_CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3-4) export TF_CUDA_MAJOR_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3) export TF_CUBLAS_VERSION=$(ls /usr/local/cuda/lib64/libcublas.so.*.*.* | cut -d . -f 3) -export TF_CUDNN_VERSION=$(echo "${NV_CUDNN_VERSION}" | cut -d . -f 1) export TF_NCCL_VERSION=$(echo "${NCCL_VERSION}" | cut -d . -f 1) +TF_CUDNN_MAJOR_VERSION=$(grep "#define CUDNN_MAJOR" /usr/include/cudnn_version.h | awk '{print $3}') +TF_CUDNN_MINOR_VERSION=$(grep "#define CUDNN_MINOR" /usr/include/cudnn_version.h | awk '{print $3}') +TF_CUDNN_PATCHLEVEL_VERSION=$(grep "#define CUDNN_PATCHLEVEL" /usr/include/cudnn_version.h | awk '{print $3}') +export TF_CUDNN_VERSION="${TF_CUDNN_MAJOR_VERSION}.${TF_CUDNN_MINOR_VERSION}.${TF_CUDNN_PATCHLEVEL_VERSION}" + case "${CPU_ARCH}" in "amd64") export CC_OPT_FLAGS="-march=sandybridge -mtune=broadwell" @@ -262,10 +266,8 @@ time python "${SRC_PATH_JAX}/build/build.py" \ --enable_cuda \ --build_gpu_plugin \ --gpu_plugin_cuda_version=$TF_CUDA_MAJOR_VERSION \ - --cuda_path=$TF_CUDA_PATHS \ - --cudnn_path=$TF_CUDNN_PATHS \ - --cuda_version=$TF_CUDA_VERSION \ - --cudnn_version=$TF_CUDNN_VERSION \ + --bazel_options=--repo_env=HERMETIC_CUDA_VERSION=$CUDA_VERSION \ + --bazel_options=--repo_env=HERMETIC_CUDNN_VERSION=$TF_CUDNN_VERSION \ --cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \ --enable_nccl=true \ --bazel_options=--linkopt=-fuse-ld=lld \ @@ -282,8 +284,7 @@ if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_gpu_pjrt" >> build/requirements.in echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_gpu_plugin" >> build/requirements.in PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))') - #bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}" - python build/build.py --requirements_update --python_version=${PYTHON_VERSION} + bazel run --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}" popd fi ## Install the built packages @@ -295,12 +296,8 @@ else pip uninstall -y jaxlib jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin fi -# install jaxlib -pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_gpu_pjrt -e ${BUILD_PATH_JAXLIB}/jax_gpu_plugin -# install jax -if [[ "$JAXLIB_ONLY" == "0" ]]; then - pip --disable-pip-version-check install -e "${SRC_PATH_JAX}" -fi +# install jax and jaxlib +pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_gpu_pjrt -e ${BUILD_PATH_JAXLIB}/jax_gpu_plugin -e "${SRC_PATH_JAX}" # after installation (example) # jax 0.4.32.dev20240808+9c2caedab /opt/jax