Skip to content

Commit

Permalink
Roll back CI changes for now
Browse files: browse the repository at this point in the history
  • Loading branch information
olupton committed Jul 8, 2024
1 parent e189133 commit eced285
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 32 deletions.
13 changes: 0 additions & 13 deletions .github/workflows/_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -482,19 +482,6 @@ jobs:
PAX_IMAGE: ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }}
secrets: inherit

test-nsys-jax:
needs: build-upstream-pax
if: inputs.ARCHITECTURE == 'amd64' # no arm64 gpu runners
uses: ./.github/workflows/_test_upstream_pax.yaml
with:
ARTIFACT_NAME: artifact-nsys-jax-test
BADGE_FILENAME: badge-nsys-jax-test.json
COMMAND_PREFIX: 'nsys-jax --nsys-jax-analysis summary -o ${PROFILE_OUTPUT} --'
FW_NAME: nsys-jax
PAX_IMAGE: ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }}
STEPS: "20"
secrets: inherit

test-rosetta-pax:
needs: build-rosetta-pax
if: inputs.ARCHITECTURE == 'amd64' # no images for arm64
Expand Down
25 changes: 6 additions & 19 deletions .github/workflows/_test_upstream_pax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,6 @@ on:
description: 'Name of the framework being used'
required: false
default: 'upstream-pax'
COMMAND_PREFIX:
type: string
description: 'Prefix for the test command; can be used to enable profiling'
required: false
default: ''
STEPS:
type: string
description: 'Number of execution steps to test'
required: false
default: "500"
outputs:
TEST_STATUS:
description: 'Summary pass/fail value indicating if results from tests are acceptable'
Expand Down Expand Up @@ -113,18 +103,17 @@ jobs:
true
# run job with tasks on each node sharing one container
PROFILE_OUTPUT="/output/nsys-jax-${{ steps.meta.outputs.TEST_CASE_NAME }}-rank${SLURM_PROCID}"
time srun \
--ntasks=1 \
--ntasks-per-node=1 \
--container-name=runtime \
--container-mounts=${{ steps.meta.outputs.MODEL_PATH }}:/output \
--container-entrypoint \
${{ inputs.COMMAND_PREFIX }} test-pax.sh \
test-pax.sh \
--output /output/${{ steps.meta.outputs.TEST_CASE_NAME }} \
--dtype bfloat16 \
--batch-per-gpu 4 \
--steps ${{ inputs.STEPS }} \
--steps 500 \
--pipeline-parallel ${{ matrix.PARALLEL_CONFIG[0] }} \
--data-parallel ${{ matrix.PARALLEL_CONFIG[1] }} \
--fsdp ${{ matrix.PARALLEL_CONFIG[2] }} \
Expand Down Expand Up @@ -290,18 +279,17 @@ jobs:
true
# run job with tasks on each node sharing one container
PROFILE_OUTPUT="/output/nsys-jax-${{ steps.meta.outputs.TEST_CASE_NAME }}-rank${SLURM_PROCID}"
time srun \
--tasks=${{ steps.meta.outputs.TOTAL_TASKS }} \
--tasks-per-node=${{ steps.meta.outputs.GPUS_PER_NODE }} \
--container-name=runtime \
--container-mounts=${{ steps.meta.outputs.MODEL_PATH }}:/output \
--container-entrypoint \
${{ inputs.COMMAND_PREFIX }} test-pax.sh \
test-pax.sh \
--output /output/${{ steps.meta.outputs.TEST_CASE_NAME }} \
--dtype bfloat16 \
--batch-per-gpu 4 \
--steps ${{ inputs.STEPS }} \
--steps 500 \
--pipeline-parallel ${{ matrix.PARALLEL_CONFIG[0] }} \
--data-parallel ${{ matrix.PARALLEL_CONFIG[1] }} \
--fsdp ${{ matrix.PARALLEL_CONFIG[2] }} \
Expand Down Expand Up @@ -439,18 +427,17 @@ jobs:
true
# run job with tasks on each node sharing one container
PROFILE_OUTPUT="/output/nsys-jax-${{ steps.meta.outputs.TEST_CASE_NAME }}-rank${SLURM_PROCID}"
time srun \
--ntasks=${{ steps.meta.outputs.TOTAL_TASKS }} \
--ntasks-per-node=1 \
--container-name=runtime \
--container-mounts=${{ steps.meta.outputs.MODEL_PATH }}:/output \
--container-entrypoint \
${{ inputs.COMMAND_PREFIX }} test-pax.sh \
test-pax.sh \
--output /output/${{ steps.meta.outputs.TEST_CASE_NAME }} \
--dtype bfloat16 \
--batch-per-gpu 4 \
--steps ${{ inputs.STEPS }} \
--steps 500 \
--evaluate \
--pipeline-parallel ${{ matrix.PARALLEL_CONFIG[0] }} \
--data-parallel ${{ matrix.PARALLEL_CONFIG[1] }} \
Expand Down

0 comments on commit eced285

Please sign in to comment.