# Workflow file for the GitHub Actions run
# "Dockerfile.jax: install JAX for TE build #182".
# Viewer notice: this file was flagged as containing bidirectional Unicode text
# that may be interpreted or compiled differently than it appears; open it in an
# editor that reveals hidden Unicode characters to review it.
name: nsys-jax pure-Python CI

# One run per pull-request branch (or per run ID for other events); a newer run
# cancels an in-flight one everywhere except on main.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

on:
  pull_request:
    types:
      - opened
      - reopened
      - ready_for_review
      - synchronize
    paths-ignore:
      - '**.md'
  push:
    branches:
      - main

# Least privilege for the automatically provided GITHUB_TOKEN: the jobs in this
# workflow only use it to check out code. Gist uploads authenticate with a
# separate secret (NVJAX_GIST_TOKEN).
permissions:
  contents: read

env:
  # Newline-separated paths handed to mypy, ruff and find. The shell steps
  # expand this variable unquoted on purpose so that each path becomes its own
  # argument.
  NSYS_JAX_PYTHON_FILES: |
    JAX-Toolbox/.github/container/nsys-jax
    JAX-Toolbox/.github/container/nsys-jax-combine
    JAX-Toolbox/.github/container/jax_nsys
jobs:
  # Static type checking of the nsys-jax scripts, the jax_nsys package and the
  # notebooks (converted to scripts first), using stubs generated from the XLA
  # .proto files.
  mypy:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML does not read the version as the float 3.1
          python-version: '3.10'
      - name: "Create virtual environment"
        run: |
          pip install virtualenv
          virtualenv venv
      - name: "Install google.protobuf and protoc"
        run: |
          ./venv/bin/pip install -r ./JAX-Toolbox/.github/container/requirements-nsys-jax.in
          ./venv/bin/python ./JAX-Toolbox/.github/container/jax_nsys/install-protoc ./venv
      - name: "Install jax_nsys Python package"
        run: ./venv/bin/pip install -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys
      - name: "Install mypy"
        run: ./venv/bin/pip install matplotlib mypy nbconvert types-protobuf
      - name: "Fetch XLA .proto files"
        uses: actions/checkout@v4
        with:
          path: xla
          repository: openxla/xla
          # Non-cone mode is required for a glob pattern rather than a directory
          sparse-checkout: |
            *.proto
          sparse-checkout-cone-mode: false
      - name: "Compile .proto files"
        shell: bash -x -e {0}
        run: |
          mkdir compiled_protos compiled_stubs protos
          mv -v xla/third_party/tsl/tsl protos/
          mv -v xla/xla protos/
          ./venv/bin/python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
          touch compiled_stubs/py.typed
      - name: "Convert .ipynb to .py"
        shell: bash -x -e {0}
        run: |
          # NSYS_JAX_PYTHON_FILES is intentionally unquoted: one find root per path
          for notebook in $(find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb'); do
            ./venv/bin/jupyter nbconvert --to script "${notebook}"
          done
      - name: "Run mypy checks"
        shell: bash -x -e {0}
        run: |
          export MYPYPATH="${PWD}/compiled_stubs"
          # NSYS_JAX_PYTHON_FILES is intentionally unquoted: one argument per path
          ./venv/bin/mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES}
# Test nsys-jax-combine and notebook execution; in future perhaps upload the rendered
# notebook from here too. These input files were generated with something like
#   srun -n 4 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
#     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
#     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-4proc-proc%q{SLURM_PROCID}
#     -- test-pax.sh --steps=5 --fsdp=4 --multiprocess
# with newer nsys-jax components bind-mounted in.
  combine:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Use nsys-jax-combine to merge profiles from multiple nsys processes
        shell: bash -x -e {0}
        # TODO: when nsys-jax-combine supports --nsys-jax-analysis, exercise that here.
        run: .github/container/nsys-jax-combine -o .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc.zip .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip
      - name: Mock up the structure of an extracted .zip file
        # Unpack the combined archive next to the install script and notebook
        run: unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc.zip
      - name: "Setup Python 3.12"
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      - name: Run the install script, but skip launching Jupyter Lab
        shell: bash -x -e {0}
        run: |
          pip install virtualenv
          NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh
      - name: Test the Jupyter Lab installation and execute the notebook
        shell: bash -x -e {0}
        run: |
          pushd .github/container/jax_nsys
          ./nsys_jax_venv/bin/python -m jupyterlab --version
          # Run with ipython for the sake of getting a clear error message
          ./nsys_jax_venv/bin/ipython Analysis.ipynb
# This input file was generated with something like
#   srun -n 1 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
#     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
#     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-1proc -- test-pax.sh
#     --steps=5 --fsdp=4
  notebook:
    env:
      # TODO: these could/should be saved in the repository settings instead
      RENDERED_NOTEBOOK_GIST_ID: e2cd3520201caab6b67385ed36fad3c1
      MOCK_RENDERED_NOTEBOOK_GIST_ID: 16698d9e9e52320243165d61b5bb3975
      # Regular expression (compiled with JavaScript's RegExp in the copy step)
      # selecting which rendered files are published to the well-known Gist
      PUBLISH_NOTEBOOK_FILES: '(.*\.ipynb|.*\.svg)'
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Mock up the structure of an extracted .zip file
        shell: bash -x -e {0}
        run: |
          # Get the actual test data from a real archive, minus the .nsys-rep file
          unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/test_data/pax_fsdp4_1proc.zip
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Run the install script, but skip launching Jupyter Lab
        shell: bash -x -e {0}
        run: |
          pip install virtualenv
          NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh
      - name: Test the Jupyter Lab installation and execute the notebook
        shell: bash -x -e {0}
        run: |
          pushd .github/container/jax_nsys
          ./nsys_jax_venv/bin/python -m jupyterlab --version
          # Run with ipython for the sake of getting a clear error message
          ./nsys_jax_venv/bin/ipython Analysis.ipynb
      - name: Render the notebook
        id: render
        shell: bash -x -e {0}
        run: |
          pushd .github/container/jax_nsys
          workdir=$(mktemp -d)
          ./nsys_jax_venv/bin/jupyter nbconvert --execute --inplace Analysis.ipynb
          cp *.ipynb *.svg "${workdir}"
          # Expose the staging directory to the upload step
          echo "WORKDIR=${workdir}" >> "$GITHUB_OUTPUT"
      # NOTE(review): secrets are normally not exposed to pull requests from
      # forks, so the two Gist steps presumably fail there - confirm whether
      # that is intended.
      - name: Upload rendered notebook to Gist
        id: upload
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            const currentDateTime = new Date().toISOString();
            const gistDescription =
              `Rendered IPython notebook from workflow: ${{ github.workflow }}, ` +
              `Run ID: ${{ github.run_id }}, ` +
              `Repository: ${{ github.repository }}, ` +
              `Event: ${{ github.event_name }}, ` +
              `Created: ${currentDateTime}`;
            const fs = require('fs').promises;
            const workdir = '${{ steps.render.outputs.WORKDIR }}'
            const files = await fs.readdir(workdir);
            // Create a fresh secret gist holding every file staged by the render step
            const gist = await github.rest.gists.create({
              description: gistDescription,
              public: false,
              files: Object.fromEntries(
                await Promise.all(
                  files.map(
                    async filename => {
                      const content = await fs.readFile(`${workdir}/${filename}`, 'utf8');
                      return [filename, { content }];
                    }
                  )
                )
              )
            });
            console.log(gist)
            // Returned value becomes steps.upload.outputs.result
            return gist.data.id;
      - name: Copy rendered notebook to Gist with well-known ID
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            // The upload step's result is JSON-encoded by github-script's default
            // result encoding, so this expands to a quoted string literal.
            const srcId = ${{ steps.upload.outputs.result }};
            // Runs on main publish to the real gist; everything else goes to the mock one
            const dstId = "${{ github.ref == 'refs/heads/main' && env.RENDERED_NOTEBOOK_GIST_ID || env.MOCK_RENDERED_NOTEBOOK_GIST_ID }}";
            const { PUBLISH_NOTEBOOK_FILES } = process.env;
            // Fetch existing files from destination gist
            const { data: dstData } = await github.rest.gists.get({
              gist_id: dstId
            });
            // Mark existing files in destination gist for deletion
            let filesToUpdate = {};
            for (const filename of Object.keys(dstData.files)) {
              filesToUpdate[filename] = null;
            }
            // Fetch files from source gist
            const { data: srcData } = await github.rest.gists.get({
              gist_id: srcId
            });
            // Add or update files based on the pattern
            const pattern = new RegExp(`${PUBLISH_NOTEBOOK_FILES}`);
            for (const [filename, fileObj] of Object.entries(srcData.files)) {
              if (filename.match(pattern)) {
                // If the total gist size is too large, not all the content will have
                // been returned and we need some extra requests.
                if (fileObj.truncated) {
                  const { data } = await github.request(fileObj.raw_url)
                  filesToUpdate[filename] = {
                    content: new TextDecoder().decode(data)
                  };
                } else {
                  filesToUpdate[filename] = {
                    content: fileObj.content
                  };
                }
              }
            }
            // Update files in destination gist
            await github.rest.gists.update({
              gist_id: dstId,
              files: filesToUpdate
            });
            console.log("Files copied successfully.");
            console.log(Object.keys(filesToUpdate));
# Lint (ruff check) and formatting (ruff format --check) of the nsys-jax Python sources.
  ruff:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: "Install ruff"
        run: pip install ruff
      - name: "Run ruff checks"
        # Deliberately no -e: both checks must run so that lint and formatting
        # problems are reported together; the statuses are combined below.
        shell: bash -x {0}
        run: |
          ruff check ${NSYS_JAX_PYTHON_FILES}
          check_status=$?
          ruff format --check ${NSYS_JAX_PYTHON_FILES}
          format_status=$?
          if [[ $format_status != 0 || $check_status != 0 ]]; then
            exit 1
          fi