diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index bcae62890..5e028998c 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -19,10 +19,12 @@ ARG BUILD_DATE ## Build JAX ############################################################################### -FROM ${BASE_IMAGE} as builder +FROM ${BASE_IMAGE} AS builder ARG URLREF_JAX +ARG URLREF_TRANSFORMER_ENGINE ARG URLREF_XLA ARG SRC_PATH_JAX +ARG SRC_PATH_TRANSFORMER_ENGINE ARG SRC_PATH_XLA ARG BAZEL_CACHE ARG BUILD_PATH_JAXLIB @@ -54,14 +56,24 @@ RUN build-jax.sh \ --xla-arm64-patch /opt/xla-arm64-neon.patch \ --clean +## Transformer engine: check out source and build wheel +RUN <<"EOF" bash -ex -o pipefail +pip install ninja && rm -rf ~/.cache/pip +# TransformerEngine now needs JAX at build time +pip install -e ${SRC_PATH_JAX} +git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE} +pushd ${SRC_PATH_TRANSFORMER_ENGINE} +python setup.py bdist_wheel && rm -rf build +ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist" +EOF + ############################################################################### ## Pack jaxlib wheel and various source dirs into a pre-installation image ############################################################################### ARG BASE_IMAGE -FROM ${BASE_IMAGE} as mealkit +FROM ${BASE_IMAGE} AS mealkit ARG URLREF_FLAX -ARG URLREF_TRANSFORMER_ENGINE ARG SRC_PATH_JAX ARG SRC_PATH_XLA ARG SRC_PATH_FLAX @@ -102,20 +114,18 @@ git-clone.sh ${URLREF_FLAX} ${SRC_PATH_FLAX} echo "-e file://${SRC_PATH_FLAX}" >> /opt/pip-tools.d/requirements-flax.in EOF -## Transformer engine: check out source and build wheel +# Copy TransformerEngine wheel from the builder stage ENV NVTE_FRAMEWORK=jax ENV SRC_PATH_TRANSFORMER_ENGINE=${SRC_PATH_TRANSFORMER_ENGINE} -RUN <<"EOF" bash -ex -o pipefail -pip install ninja && rm -rf ~/.cache/pip -git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE} -pushd ${SRC_PATH_TRANSFORMER_ENGINE} -python setup.py bdist_wheel && rm -rf build -echo "transformer-engine @ file://$(ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl)" >> /opt/pip-tools.d/requirements-te.in +COPY --from=builder ${SRC_PATH_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE} +RUN <<"EOF" bash -ex +ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl +echo "transformer-engine @ file://$(ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl)" > /opt/pip-tools.d/requirements-te.in EOF ############################################################################### ## Install primary packages and transitive dependencies ############################################################################### -FROM mealkit as final +FROM mealkit AS final RUN pip-finalize.sh