# Temp XLA JAX T5x Perf test workflow #12

name: Temp XLA JAX T5x Perf test workflow

# This workflow is started manually; all of its parameters come from the
# workflow_dispatch form below.
on:
  workflow_dispatch:
    inputs:
      # Target architecture for every build job. The T5X build and test jobs
      # are only run when this is 'amd64'.
      ARCHITECTURE:
        type: string
        required: true
      BUILD_DATE:
        type: string
        description: "Build date in YYYY-MM-DD format"
        required: false
        default: "NOT SPECIFIED"
      PUBLISH:
        type: boolean
        description: "Publish dated images and update the 'latest' tag?"
        required: false
        default: false
      BUMP_MANIFEST:
        type: boolean
        description: "Bump manifest file?"
        required: false
        default: false
      # Optional source overrides; each defaults to an empty string.
      # NOTE(review): none of the *_SRC / *_REF inputs are referenced by the
      # jobs in this workflow -- confirm whether they should be forwarded to
      # the build jobs.
      JAX_SRC:
        type: string
        description: "JAX source url"
        required: false
        default: ""
      JAX_REF:
        type: string
        description: "JAX branch/commit SHA"
        required: false
        default: ""
      XLA_SRC:
        type: string
        description: "XLA source url"
        required: false
        default: ""
      XLA_REF:
        type: string
        description: "XLA branch/commit SHA"
        required: false
        default: ""
      T5X_SRC:
        type: string
        description: "T5X source url"
        required: false
        default: ""
      T5X_REF:
        type: string
        description: "T5X branch/commit SHA"
        required: false
        default: ""
# Runs share a concurrency group keyed on workflow name plus PR head ref.
# github.head_ref is only set for pull_request events, so manual dispatches
# fall back to the (unique) run id. A newer run in the same group cancels an
# in-progress one on every ref except main.
concurrency:
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
# GITHUB_TOKEN scopes granted to the jobs of this workflow.
permissions:
  actions: write   # to cancel previous workflows
  contents: read   # to fetch code
  packages: write  # to upload container
jobs:
  # Resolves run-wide values that downstream jobs read through
  # needs.metadata.outputs.*
  metadata:
    runs-on: ubuntu-22.04
    outputs:
      BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }}
      PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }}
      BUMP_MANIFEST: ${{ steps.if-bump-manifest.outputs.BUMP_MANIFEST }}
    steps:
      - name: Set build date
        id: date
        shell: bash -x -e {0}
        env:
          # Passed through env (not ${{ }} interpolation) so the free-form
          # dispatch input is never spliced into the script text.
          REQUESTED_BUILD_DATE: ${{ inputs.BUILD_DATE }}
        run: |
          # Use the dispatch input when one was supplied; otherwise today's
          # date in US Pacific time. 'America/Los_Angeles' is the tz database
          # zone name -- 'US/Los_Angeles' does not exist and makes `date`
          # silently fall back to UTC.
          if [[ -n "${REQUESTED_BUILD_DATE}" && "${REQUESTED_BUILD_DATE}" != 'NOT SPECIFIED' ]]; then
            BUILD_DATE="${REQUESTED_BUILD_DATE}"
          else
            BUILD_DATE=$(TZ='America/Los_Angeles' date '+%Y-%m-%d')
          fi
          echo "BUILD_DATE=${BUILD_DATE}" >> "$GITHUB_OUTPUT"
      # This workflow only has a workflow_dispatch trigger, so the 'schedule'
      # branch of the two checks below is currently never taken.
      - name: Determine whether results will be 'published'
        id: if-publish
        shell: bash -x -e {0}
        run: |
          echo "PUBLISH=${{ github.event_name == 'schedule' || inputs.PUBLISH }}" >> "$GITHUB_OUTPUT"
      - name: Determine whether need to bump manifest
        id: if-bump-manifest
        shell: bash -x -e {0}
        run: |
          echo "BUMP_MANIFEST=${{ github.event_name == 'schedule' || inputs.BUMP_MANIFEST }}" >> "$GITHUB_OUTPUT"

  # All build jobs take their date from `metadata` so the images of one run
  # share a single BUILD_DATE.
  build-base:
    needs: metadata
    uses: ./.github/workflows/_build_base.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BUMP_MANIFEST: ${{ inputs.BUMP_MANIFEST }}
    secrets: inherit

  build-jax:
    needs: [metadata, build-base]
    uses: ./.github/workflows/_build_jax.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-base.outputs.DOCKER_TAG }}
    secrets: inherit

  build-t5x:
    # `metadata` must be listed: a job can only read outputs of jobs in its
    # own `needs`, otherwise BUILD_DATE below evaluates to an empty string.
    needs: [metadata, build-jax]
    if: inputs.ARCHITECTURE == 'amd64'  # T5X arm64 build is wip in PR 252
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: "artifact-t5x-build"
      BADGE_FILENAME: "badge-t5x-build"
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: upstream-t5x
      DOCKERFILE: .github/container/Dockerfile.t5x
    secrets: inherit

  # Runs each rosetta distribution integration script as its own matrix leg.
  test-distribution:
    runs-on: ubuntu-22.04
    strategy:
      matrix:
        TEST_SCRIPT:
          - extra-only-distribution.sh
          - mirror-only-distribution.sh
          - upstream-only-distribution.sh
      fail-fast: false
    steps:
      - name: Print environment variables
        run: env
      # NOTE(review): the e-mail below looks like web-page obfuscation
      # residue; confirm the intended CI address.
      - name: Set git login for tests
        run: |
          git config --global user.email "[email protected]"
          git config --global user.name "JAX-Toolbox CI"
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Run integration test ${{ matrix.TEST_SCRIPT }}
        run: bash rosetta/tests/${{ matrix.TEST_SCRIPT }}

  test-upstream-t5x:
    needs: build-t5x
    if: inputs.ARCHITECTURE == 'amd64'  # arm64 runners n/a
    uses: ./.github/workflows/_test_t5x.yaml
    with:
      T5X_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAG_FINAL }}
    secrets: inherit

  # Emits TARGET_TAGS: a JSON object mapping "<image>-<flavor>-container-tag"
  # and "<image>-<flavor>-tag-dated" to the tag names `publish` uses.
  publish-target-tags:
    runs-on: ubuntu-22.04
    outputs:
      TARGET_TAGS: ${{ steps.tags.outputs.TARGET_TAGS }}
    steps:
      - id: tags
        shell: bash -x -e {0}
        run: |
          # Bash array elements are whitespace-separated; a comma would become
          # part of the element (e.g. "jax,").
          TARGET_IMAGES=("jax" "test-upstream-t5x")
          FLAVORS=("mealkit" "final")
          ENTRIES=()
          for target in "${TARGET_IMAGES[@]}"; do
            for flavor in "${FLAVORS[@]}"; do
              # 'final' is tagged latest / nightly-<date>; every other
              # flavor reuses its own name for both tags.
              if [[ "${flavor}" == "final" ]]; then
                CONTAINER_TAG=latest
                TAG_DATED=nightly
              else
                CONTAINER_TAG=${flavor}
                TAG_DATED=${flavor}
              fi
              ENTRIES+=("\"${target}-${flavor}-container-tag\":\"${CONTAINER_TAG}\"")
              ENTRIES+=("\"${target}-${flavor}-tag-dated\":\"${TAG_DATED}\"")
            done
          done
          # Join the entries with commas inside a subshell-local IFS.
          JSON="{$(IFS=,; echo "${ENTRIES[*]}")}"
          echo "TARGET_TAGS=${JSON}" | tee -a "$GITHUB_OUTPUT"

  publish:
    needs: [metadata, test-upstream-t5x, publish-target-tags]
    if: false  # TODO: enable this after new image renaming proposal is approved
    # Job outputs are strings, and the string 'false' is truthy in
    # expressions, so compare explicitly when re-enabling:
    # if: ${{ !cancelled() && needs.metadata.outputs.PUBLISH == 'true' }}
    # NOTE(review): SOURCE_IMAGE reads needs.amd64 / needs.arm64, but no jobs
    # with those ids exist in this workflow (nor are they in `needs`); rewire
    # it to the build jobs' outputs before removing `if: false`.
    strategy:
      fail-fast: false
      matrix:
        TARGET_IMAGE: [jax, test-upstream-t5x]
        FLAVOR: [mealkit, final]
    uses: ./.github/workflows/_publish_container.yaml
    with:
      SOURCE_IMAGE: |
        ${{ fromJson(needs.amd64.outputs.CONTAINER_TAGS)[format('tag-{0}-{1}', matrix.TARGET_IMAGE, matrix.FLAVOR)] }}
        ${{ fromJson(needs.arm64.outputs.CONTAINER_TAGS)[format('tag-{0}-{1}', matrix.TARGET_IMAGE, matrix.FLAVOR)] }}
      TARGET_IMAGE: ${{ matrix.TARGET_IMAGE }}
      TARGET_TAGS: |
        type=raw,value=${{ fromJson(needs.publish-target-tags.outputs.TARGET_TAGS)[format('{0}-{1}-container-tag', matrix.TARGET_IMAGE, matrix.FLAVOR)] }},priority=500
        type=raw,value=${{ fromJson(needs.publish-target-tags.outputs.TARGET_TAGS)[format('{0}-{1}-tag-dated', matrix.TARGET_IMAGE, matrix.FLAVOR)] }}-${{ needs.metadata.outputs.BUILD_DATE }},priority=500

  finalize:
    needs: [metadata, test-upstream-t5x, publish-target-tags]
    if: "!cancelled()"
    uses: ./.github/workflows/_finalize.yaml
    with:
      PUBLISH_BADGE: false
    secrets: inherit