Skip to content

Commit

Permalink
Dockerfile.jax: install JAX for TE build
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Aug 14, 2024
1 parent d7bad4b commit 5ae63ac
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ ARG BUILD_DATE
## Build JAX
###############################################################################

FROM ${BASE_IMAGE} as builder
FROM ${BASE_IMAGE} AS builder
ARG URLREF_JAX
ARG URLREF_TRANSFORMER_ENGINE
ARG URLREF_XLA
ARG SRC_PATH_JAX
ARG SRC_PATH_TRANSFORMER_ENGINE
ARG SRC_PATH_XLA
ARG BAZEL_CACHE
ARG BUILD_PATH_JAXLIB
Expand Down Expand Up @@ -54,14 +56,24 @@ RUN build-jax.sh \
--xla-arm64-patch /opt/xla-arm64-neon.patch \
--clean

## Transformer engine: check out source and build wheel
RUN <<"EOF" bash -ex -o pipefail
pip install ninja && rm -rf ~/.cache/pip
# TransformerEngine now needs JAX at build time
pip install -e ${SRC_PATH_JAX}
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
python setup.py bdist_wheel && rm -rf build
ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
EOF

###############################################################################
## Pack jaxlib wheel and various source dirs into a pre-installation image
###############################################################################

ARG BASE_IMAGE
FROM ${BASE_IMAGE} as mealkit
FROM ${BASE_IMAGE} AS mealkit
ARG URLREF_FLAX
ARG URLREF_TRANSFORMER_ENGINE
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
ARG SRC_PATH_FLAX
Expand Down Expand Up @@ -102,20 +114,18 @@ git-clone.sh ${URLREF_FLAX} ${SRC_PATH_FLAX}
echo "-e file://${SRC_PATH_FLAX}" >> /opt/pip-tools.d/requirements-flax.in
EOF

## Transformer engine: check out source and build wheel
# Copy TransformerEngine wheel from the builder stage
ENV NVTE_FRAMEWORK=jax
ENV SRC_PATH_TRANSFORMER_ENGINE=${SRC_PATH_TRANSFORMER_ENGINE}
RUN <<"EOF" bash -ex -o pipefail
pip install ninja && rm -rf ~/.cache/pip
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
python setup.py bdist_wheel && rm -rf build
echo "transformer-engine @ file://$(ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl)" >> /opt/pip-tools.d/requirements-te.in
COPY --from=builder ${SRC_PATH_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
RUN <<"EOF" bash -ex
ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl
echo "transformer-engine @ file://$(ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl)" > /opt/pip-tools.d/requirements-te.in
EOF

###############################################################################
## Install primary packages and transitive dependencies
###############################################################################

FROM mealkit as final
FROM mealkit AS final
RUN pip-finalize.sh

0 comments on commit 5ae63ac

Please sign in to comment.