feat: Add rosetta-maxtext #738

Open
wants to merge 11 commits into
base: main
3 changes: 3 additions & 0 deletions .github/container/manifest.yaml
@@ -113,9 +113,12 @@ jax-triton:
mode: git-clone
maxtext:
url: https://github.com/google/maxtext.git
mirror_url: https://github.com/nvjax-svc-0/maxtext.git
tracking_ref: main
latest_verified_commit: 78daad198544def8274dbd656d122fbe6a0e1129
mode: git-clone
patches:
mirror/patch/test_rosetta_maxtext: file://patches/maxtext/mirror-patch-rosetta-maxtext.patch
Contributor

Just leaving a reminder that this can be cleaned up if not needed

Contributor

I could not see the patch file in the repo. Are we using it?
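
One quick way to answer this before merging is to check that every `file://` patch listed in the manifest exists on disk. The following is a minimal sketch, not part of the repository: `missing_patches` is a hypothetical helper, and it assumes that `file://` paths in `manifest.yaml` are resolved relative to the manifest's own directory.

```shell
# Sketch: print every file:// patch path referenced in a manifest that has
# no matching file on disk.
# Assumption: file:// paths are relative to the directory holding the manifest.
# missing_patches is a hypothetical helper name, not an existing script.
missing_patches() {
  manifest="$1"
  base_dir="$(dirname "$manifest")"
  # Extract each file://... token, strip the scheme, then test for the file.
  grep -o 'file://[^[:space:]]*' "$manifest" | sed 's@^file://@@' |
    while IFS= read -r rel_path; do
      [ -f "${base_dir}/${rel_path}" ] || echo "$rel_path"
    done
}

# Usage (from the repository root):
#   missing_patches .github/container/manifest.yaml
```

An empty result would mean every referenced patch file is present; any printed path is a manifest entry without a file behind it.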

levanter:
url: https://github.com/stanford-crfm/levanter.git
tracking_ref: main
437 changes: 225 additions & 212 deletions .github/workflows/_ci.yaml
Contributor

We have not run the tests for rosetta-maxtext yet, right? We should check the validity of the rosetta build.

Large diffs are not rendered by default.

@@ -11,13 +11,13 @@ on:
EXTRA_TEST_ARGS:
type: string
description: Extra command line args to pass to test-maxtext.sh
default: ""
default: ''
required: false
BADGE_FILENAME:
type: string
description: 'Name of the endpoint JSON file for shields.io badge'
required: false
default: 'badge-maxtext-test.json'
default: 'badge-upstream-maxtext-test.json'
ARTIFACT_NAME:
type: string
description: 'Name of the artifact zip file'
@@ -34,12 +34,11 @@ on:
value: ${{ jobs.sitrep.outputs.STATUS }}

jobs:

single-process-multi-device:
strategy:
matrix:
PARALLEL_CONFIG:
- [1, 1, 2, 4]
- [1, 1, 2, 4]
# - [1, 1, 1, 8] # PP, DP, FSDP, TP
fail-fast: false

@@ -183,12 +182,12 @@ jobs:
strategy:
matrix:
PARALLEL_CONFIG:
- [1, 1, 1, 1]
- [1, 1, 8, 1]
- [1, 1, 1, 8]
- [1, 1, 4, 2]
- [1, 2, 2, 2]
- [1, 4, 2, 2]
- [1, 1, 1, 1]
- [1, 1, 8, 1]
- [1, 1, 1, 8]
- [1, 1, 4, 2]
- [1, 2, 2, 2]
- [1, 4, 2, 2]
fail-fast: false

runs-on: ubuntu-22.04
@@ -366,7 +365,7 @@ jobs:

sitrep:
needs: [single-process-multi-device, maxtext-multinode, metrics]
if: "!cancelled()"
if: '!cancelled()'
uses: ./.github/workflows/_sitrep_mgmn.yaml
secrets: inherit
with:
@@ -377,7 +376,7 @@ jobs:
summary:
runs-on: ubuntu-22.04
needs: [single-process-multi-device, maxtext-multinode]
if: "!cancelled()"
if: '!cancelled()'
steps:
- name: Generate TensorBoard query URL
run: |
@@ -394,7 +393,7 @@ jobs:
outcome:
needs: sitrep
runs-on: ubuntu-22.04
if: "!cancelled()"
if: '!cancelled()'
steps:
- name: Sets workflow status based on test outputs
run: |
39 changes: 19 additions & 20 deletions .github/workflows/ci.yaml
@@ -2,7 +2,7 @@ name: CI

on:
schedule:
- cron: '30 9 * * *' # Pacific Time 01:30 AM in UTC
- cron: '30 9 * * *' # Pacific Time 01:30 AM in UTC
pull_request:
types:
- opened
@@ -25,7 +25,7 @@ on:
required: false
MERGE_BUMPED_MANIFEST:
type: boolean
description: "(used if BUMP_MANIFEST=true) If true: attempt to PR/merge manifest branch"
description: '(used if BUMP_MANIFEST=true) If true: attempt to PR/merge manifest branch'
default: false
required: false

@@ -34,16 +34,15 @@ concurrency:
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

permissions:
contents: write # to fetch code and push branch
actions: write # to cancel previous workflows
packages: write # to upload container
pull-requests: write # to make pull request for manifest bump
contents: write # to fetch code and push branch
actions: write # to cancel previous workflows
packages: write # to upload container
pull-requests: write # to make pull request for manifest bump

env:
DEFAULT_MANIFEST_ARTIFACT_NAME: bumped-manifest

jobs:

metadata:
runs-on: ubuntu-22.04
outputs:
@@ -81,7 +80,7 @@ jobs:
id: manifest-branch
shell: bash -x -e {0}
run: |
BUMP_MANIFEST=${{ github.event_name == 'schedule' || inputs.BUMP_MANIFEST || 'false' }}
BUMP_MANIFEST=${{ github.event_name == 'schedule' || inputs.BUMP_MANIFEST || 'true' }}
MERGE_BUMPED_MANIFEST=${{ github.event_name == 'schedule' || inputs.MERGE_BUMPED_MANIFEST || 'false' }}
# Prepend nightly manifest branch with "z" to make it appear at the end
if [[ "$BUMP_MANIFEST" == "true" ]]; then
@@ -115,7 +114,7 @@ jobs:
shell: bash -x -e {0}
run: |
bash bump.sh --input-manifest manifest.yaml --output-manifest manifest.yaml.new --base-patch-dir ./patches-new

- name: Maybe replace current manifest/patches with the new one and show diff
working-directory: .github/container
shell: bash -x -e {0}
@@ -168,12 +167,11 @@ jobs:
steps:
- name: "Tests Succeeded: ${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}"
id: test_result
run:
echo "SUCCEEDED=${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}" | tee -a $GITHUB_OUTPUT
run: echo "SUCCEEDED=${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}" | tee -a $GITHUB_OUTPUT

- name: Check out the repository under ${GITHUB_WORKSPACE}
uses: actions/checkout@v4

- name: Delete checked-out manifest and patches
run: |
rm .github/container/manifest.yaml
@@ -185,7 +183,7 @@ jobs:
name: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
path: .github/container/

- name: "Create local manifest branch: ${{ needs.metadata.outputs.MANIFEST_BRANCH }}"
- name: 'Create local manifest branch: ${{ needs.metadata.outputs.MANIFEST_BRANCH }}'
id: local_branch
shell: bash -x -e {0}
run: |
@@ -213,7 +211,7 @@ jobs:
git merge --ff-only ${{ needs.metadata.outputs.MANIFEST_BRANCH }}
# Push the new change
git push origin ${{ github.ref_name }}

# We will create a Draft PR & remote branch if:
# 1. The tests failed
# 2. The merge failed
@@ -244,12 +242,12 @@ jobs:
draft: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: "Log created PR: #${{ fromJson(steps.create_pr.outputs.data).number }}"

- name: 'Log created PR: #${{ fromJson(steps.create_pr.outputs.data).number }}'
if: steps.create_pr.outcome == 'success'
run: |
echo "https://github.com/NVIDIA/JAX-Toolbox/pull/${{ fromJson(steps.create_pr.outputs.data).number }}" | tee -a $GITHUB_STEP_SUMMARY

# Guard delete in simple check to protect other branches
- name: Check that the branch matches znightly- prefix
run: |
@@ -271,7 +269,7 @@ jobs:

make-publish-configs:
runs-on: ubuntu-22.04
if: ${{ !cancelled() }}
if: ${{ !cancelled() }}
env:
MEALKIT_IMAGE_REPO: ${{ needs.metadata.outputs.PUBLISH == 'true' && 'jax-mealkit' || 'mock-jax-mealkit' }}
FINAL_IMAGE_REPO: ${{ needs.metadata.outputs.PUBLISH == 'true' && 'jax' || 'mock-jax' }}
@@ -294,6 +292,7 @@ jobs:
levanter
upstream-t5x
upstream-pax
upstream-maxtext
t5x
pax
grok
@@ -365,7 +364,7 @@ jobs:
needs:
- metadata
- make-publish-configs
if: ${{ !cancelled() && needs.make-publish-configs.outputs.PUBLISH_CONFIGS.config != '{"config":[]}' }}
if: ${{ !cancelled() && needs.make-publish-configs.outputs.PUBLISH_CONFIGS.config != '{"config":[]}' }}
strategy:
fail-fast: false
matrix: ${{ fromJson(needs.make-publish-configs.outputs.PUBLISH_CONFIGS) }}
@@ -381,7 +380,7 @@ jobs:

finalize:
needs: [metadata, amd64, arm64, publish-containers]
if: "!cancelled()"
if: '!cancelled()'
uses: ./.github/workflows/_finalize.yaml
with:
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
local/
Contributor

Is this a debug line? Does it need to be committed?

Contributor Author

We can remove it. I found it convenient to have a local testing directory that isn't checked into git.

27 changes: 23 additions & 4 deletions README.md
@@ -215,19 +215,38 @@
<tr>
<td>
<picture>
<img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=MaxText%3D%7Bcore%2CMaxText%7D">
<img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Upstream MaxText%3D%7Bcore%2CMaxText%7D">
</picture>
</td>
<td>
<code>ghcr.io/nvidia/jax:upstream-maxtext</code>
</td>
<td>
<a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-upstream-maxtext-md"><img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-upstream-maxtext-build-amd64.json&logo=docker&label=amd64"></a>
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-upstream-maxtext-build-arm64.json&logo=docker&label=arm64">
</td>
<td>
<picture>
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-upstream-maxtext-test.json&logo=nvidia&label=A100%20distributed">
</picture>
</td>
</tr>
<tr>
<td>
<picture>
<img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Rosetta MaxText%3D%7Bcore%2CMaxText%7D">
</picture>
</td>
<td>
<code>ghcr.io/nvidia/jax:maxtext</code>
</td>
<td>
<a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-maxtext-md"><img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-maxtext-build-amd64.json&logo=docker&label=amd64"></a>
<!-- <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-maxtext-build-arm64.json&logo=docker&label=arm64"> -->
<a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-md"><img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-amd64.json&logo=docker&label=amd64"></a>
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-arm64.json&logo=docker&label=arm64">
</td>
<td>
<picture>
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-maxtext-test.json&logo=nvidia&label=A100%20distributed">
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-maxtext-test.json&logo=nvidia&label=A100%20distributed">
</picture>
</td>
</tr>
75 changes: 75 additions & 0 deletions rosetta/Dockerfile.maxtext
Contributor

Why are we building TE again here? The base image should be jax-mealkit:jax, right? And then we apply the patch and build the final

@@ -0,0 +1,75 @@
# syntax=docker/dockerfile:1-labs
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:upstream-maxtext
ARG GIT_USER_EMAIL=[email protected]
ARG GIT_USER_NAME=NVIDIA
# If set to "true", then will pull new local patches, the manifest.yaml and create-distribution.sh (in case it was updated).
# This is useful for development if you run `./bump.sh -i manifest.yaml` manually and do not want to trigger a full rebuild all
# the way up to the jax build.
ARG UPDATE_PATCHES=false
# It is common for TE developers to test a different TE against the LLM application. This is a knob to override what's in the manifest
# Accepts git-ref's from NVIDIA/TransformerEngine or pull requests (pull/$number/head)
ARG UPDATED_TE_REF=""

# Rosetta and optionally patches are pulled from this
FROM scratch AS jax-toolbox

###############################################################################
### Download source and add auxiliary scripts
################################################################################

FROM ${BASE_IMAGE} AS mealkit
ARG GIT_USER_EMAIL
ARG GIT_USER_NAME
ARG UPDATE_PATCHES
ARG UPDATED_TE_REF

ENV ENABLE_TE=1

RUN --mount=target=/mnt/jax-toolbox,from=jax-toolbox <<"EOF" bash -exu
MANIFEST_DIR=$(dirname ${MANIFEST_FILE})
if [[ "${UPDATE_PATCHES}" != "true" && "${UPDATE_PATCHES}" != "false" ]]; then
echo "UPDATE_PATCHES can only be true or false"
exit 1
fi
if [[ "${UPDATE_PATCHES}" == "true" ]]; then
cp -r /mnt/jax-toolbox/.github/container/patches ${MANIFEST_DIR}/
cp /mnt/jax-toolbox/.github/container/manifest.yaml ${MANIFEST_DIR}/manifest.yaml
cp /mnt/jax-toolbox/.github/container/create-distribution.sh ${MANIFEST_DIR}/create-distribution.sh
fi
cp -r /mnt/jax-toolbox/rosetta /opt/rosetta

if [[ -n "${UPDATED_TE_REF}" ]]; then
TE_INSTALL_DIR=/opt/transformer-engine
yq e ".transformer-engine.latest_verified_commit = \"${UPDATED_TE_REF}\"" -i $MANIFEST_FILE
# Install from source instead of pre-built wheel
sed -i -E 's@( file:///opt/transformer-engine)/dist/[^ ]*@\1@' /opt/pip-tools.d/requirements-te.in
git -C $TE_INSTALL_DIR fetch -a
if [[ "${UPDATED_TE_REF}" =~ ^pull/ ]]; then
PR_ID=$(cut -d/ -f2 <<<"${UPDATED_TE_REF}")
git -C $TE_INSTALL_DIR fetch origin ${UPDATED_TE_REF}:PR-${PR_ID}
git -C $TE_INSTALL_DIR checkout PR-${PR_ID}
else
git -C $TE_INSTALL_DIR checkout ${UPDATED_TE_REF}
fi
fi

# Setting the username/email is required to author commits from patches
git config --global user.email "${GIT_USER_EMAIL}"
git config --global user.name "${GIT_USER_NAME}"

bash ${MANIFEST_DIR}/create-distribution.sh \
--manifest ${MANIFEST_FILE} \
--package maxtext
# Remove .gitconfig to avoid end-user authoring commits as the "build user"
rm -f ~/.gitconfig
EOF

WORKDIR /opt/rosetta

###############################################################################
### Install accumulated packages from the base image and the previous stage
################################################################################

FROM mealkit as final

RUN pip-finalize.sh
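
The `UPDATED_TE_REF` branch selection in the Dockerfile above can be exercised in isolation. This is a minimal sketch of the same rule (a `pull/<number>/head` ref is checked out as a local branch `PR-<number>`, any other ref is checked out as given). `te_checkout_target` is a hypothetical helper name, and the sketch uses a POSIX `case` and a pipe where the Dockerfile uses `[[ ... =~ ^pull/ ]]` and a here-string.

```shell
# Sketch: compute the name that would be checked out for a given
# TransformerEngine ref, following the pull/<number>/head convention above.
# te_checkout_target is a hypothetical helper, not part of the Dockerfile.
te_checkout_target() {
  ref="$1"
  case "$ref" in
    pull/*)
      # Second "/"-separated field is the pull request number.
      pr_id="$(printf '%s\n' "$ref" | cut -d/ -f2)"
      echo "PR-${pr_id}"
      ;;
    *)
      # Branches, tags, and commit hashes are used unchanged.
      echo "$ref"
      ;;
  esac
}

# Usage:
#   te_checkout_target pull/123/head   # PR ref
#   te_checkout_target main            # ordinary git ref
```

Keeping this logic in a small function makes it possible to check the ref handling without running a full image build.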