-
Notifications
You must be signed in to change notification settings - Fork 47
216 lines (205 loc) · 7.89 KB
/
nsys-jax.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
name: 'nsys-jax pure-Python CI'
# Pull-request runs are grouped by head branch (github.head_ref), so a newer
# push to the same PR supersedes the older run; push runs fall back to the
# unique github.run_id and therefore never share a group. Cancellation is
# enabled for every ref except refs/heads/main.
concurrency:
  group: '${{ github.workflow }}-${{ github.head_ref || github.run_id }}'
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
# Triggers: PR activity (except Markdown-only changes) and pushes to main.
on:  # yamllint disable-line rule:truthy
  pull_request:
    types: [opened, reopened, ready_for_review, synchronize]
    # Quoted: a leading `*` would otherwise be parsed as a YAML alias.
    paths-ignore: ['**.md']
  push:
    branches: [main]
# Space-separated list of Python sources handed (unquoted, so the shell splits
# it) to find/mypy/ruff. Paths are relative to the workspace, where the mypy
# and ruff jobs check the repository out under JAX-Toolbox/.
env:
  # `>-` folds the two lines into one space-separated string with no trailing
  # newline.
  NSYS_JAX_PYTHON_FILES: >-
    JAX-Toolbox/.github/container/nsys-jax
    JAX-Toolbox/.github/container/jax_nsys
# Least-privilege GITHUB_TOKEN: the jobs only need to read repository
# contents; Gist writes use the separate NVJAX_GIST_TOKEN secret.
permissions:
  contents: read
jobs:
  # Type-check the nsys-jax sources (including notebooks converted to .py)
  # against stubs generated from XLA's .proto files.
  mypy:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: "Create virtual environment"
        run: |
          pip install virtualenv
          virtualenv venv
      - name: "Install google.protobuf and protoc"
        run: |
          ./venv/bin/pip install -r ./JAX-Toolbox/.github/container/requirements-nsys-jax.in
          ./venv/bin/python ./JAX-Toolbox/.github/container/jax_nsys/install-protoc ./venv
      - name: "Install jax_nsys Python package"
        run: ./venv/bin/pip install -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys
      - name: "Install mypy"
        run: ./venv/bin/pip install matplotlib mypy nbconvert types-protobuf
      - name: "Fetch XLA .proto files"
        uses: actions/checkout@v4
        with:
          path: xla
          repository: openxla/xla
          # Non-cone mode is required for a file-glob sparse pattern.
          sparse-checkout: |
            *.proto
          sparse-checkout-cone-mode: false
      - name: "Compile .proto files"
        shell: bash -x -e {0}
        run: |
          mkdir compiled_protos compiled_stubs protos
          mv -v xla/third_party/tsl/tsl protos/
          mv -v xla/xla protos/
          ./venv/bin/python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
          # Mark the generated stubs as typed so mypy uses them.
          touch compiled_stubs/py.typed
      - name: "Convert .ipynb to .py"
        shell: bash -x -e {0}
        run: |
          # NSYS_JAX_PYTHON_FILES is intentionally unquoted: it is a
          # space-separated list of paths.
          for notebook in $(find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb'); do
            ./venv/bin/jupyter nbconvert --to script "${notebook}"
          done
      - name: "Run mypy checks"
        shell: bash -x -e {0}
        run: |
          export MYPYPATH="${PWD}/compiled_stubs"
          ./venv/bin/mypy ${NSYS_JAX_PYTHON_FILES}
  # Install the analysis environment, execute Analysis.ipynb against bundled
  # test data, and publish the rendered notebook to a Gist.
  notebook:
    env:
      # TODO: these could/should be saved in the repository settings instead
      RENDERED_NOTEBOOK_GIST_ID: 'e2cd3520201caab6b67385ed36fad3c1'
      MOCK_RENDERED_NOTEBOOK_GIST_ID: '16698d9e9e52320243165d61b5bb3975'
      # JavaScript regular expression selecting which files of the freshly
      # uploaded Gist are copied to the well-known Gist (unanchored match).
      PUBLISH_NOTEBOOK_FILES: '(.*\.ipynb)'
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Mock up the structure of an extracted .zip file
        shell: bash -x -e {0}
        run: |
          # Get the actual test data from a real archive, minus the .nsys-rep file
          unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/maxtext_fsdp4_test_data.zip
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Run the install script, but skip launching Jupyter Lab
        shell: bash -x -e {0}
        run: |
          pip install virtualenv
          NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh
      - name: Test the Jupyter Lab installation and execute the notebook
        shell: bash -x -e {0}
        run: |
          pushd .github/container/jax_nsys
          ./nsys_jax_venv/bin/python -m jupyterlab --version
          # Run with ipython for the sake of getting a clear error message
          ./nsys_jax_venv/bin/ipython Analysis.ipynb
      - name: Render the notebook
        id: render
        shell: bash -x -e {0}
        run: |
          pushd .github/container/jax_nsys
          workdir=$(mktemp -d)
          ./nsys_jax_venv/bin/jupyter nbconvert --execute --to notebook --output-dir="$workdir" Analysis.ipynb
          # Expose the output directory to later steps as steps.render.outputs.WORKDIR.
          echo "WORKDIR=$workdir" >> "$GITHUB_OUTPUT"
      # NOTE(review): secrets are not passed to pull_request runs triggered
      # from forks, so the two Gist steps presumably fail for fork PRs — confirm.
      - name: Upload rendered notebook to Gist
        id: upload
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            const currentDateTime = new Date().toISOString();
            const gistDescription =
              `Rendered IPython notebook from workflow: ${{ github.workflow }}, ` +
              `Run ID: ${{ github.run_id }}, ` +
              `Repository: ${{ github.repository }}, ` +
              `Event: ${{ github.event_name }}, ` +
              `Created: ${currentDateTime}`;
            const fs = require('fs').promises;
            const workdir = '${{ steps.render.outputs.WORKDIR }}';
            const files = await fs.readdir(workdir);
            // Upload every file in the render directory to a new secret Gist.
            const gist = await github.rest.gists.create({
              description: gistDescription,
              public: false,
              files: Object.fromEntries(
                await Promise.all(
                  files.map(
                    async filename => {
                      const content = await fs.readFile(`${workdir}/${filename}`, 'utf8');
                      return [filename, { content }];
                    }
                  )
                )
              )
            });
            console.log(gist);
            // With the default result-encoding (json) this becomes a quoted
            // JSON string in steps.upload.outputs.result.
            return gist.data.id;
      - name: Copy rendered notebook to Gist with well-known ID
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            // The upload step's result is JSON-encoded, so this expands to a
            // quoted string literal.
            const srcId = ${{ steps.upload.outputs.result }};
            const dstId = "${{ github.ref == 'refs/heads/main' && env.RENDERED_NOTEBOOK_GIST_ID || env.MOCK_RENDERED_NOTEBOOK_GIST_ID }}";
            const { PUBLISH_NOTEBOOK_FILES } = process.env;
            // Fetch existing files from destination gist
            const { data: dstData } = await github.rest.gists.get({
              gist_id: dstId
            });
            // Mark existing files in destination gist for deletion
            const filesToUpdate = {};
            for (const filename of Object.keys(dstData.files)) {
              filesToUpdate[filename] = null;
            }
            // Fetch files from source gist
            const { data: srcData } = await github.rest.gists.get({
              gist_id: srcId
            });
            // Add or update files based on the pattern
            const pattern = new RegExp(`${PUBLISH_NOTEBOOK_FILES}`);
            for (const [filename, fileObj] of Object.entries(srcData.files)) {
              if (filename.match(pattern)) {
                let content = fileObj.content;
                // The Gists API returns at most ~1 MB of content per file and
                // sets `truncated`; the full text is then only at raw_url.
                if (fileObj.truncated) {
                  const response = await fetch(fileObj.raw_url);
                  if (!response.ok) {
                    throw new Error(`Failed to fetch ${fileObj.raw_url}: HTTP ${response.status}`);
                  }
                  content = await response.text();
                }
                filesToUpdate[filename] = { content };
              }
            }
            // Update files in destination gist
            await github.rest.gists.update({
              gist_id: dstId,
              files: filesToUpdate
            });
            console.log("Files copied successfully.");
            console.log(Object.keys(filesToUpdate));
  # Lint and format-check the nsys-jax sources with ruff.
  ruff:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: "Install ruff"
        run: pip install ruff
      - name: "Run ruff checks"
        # No -e: run both the linter and the format check so both reports are
        # shown, then fail if either one failed.
        shell: bash -x {0}
        run: |
          ruff check ${NSYS_JAX_PYTHON_FILES}
          check_status=$?
          ruff format --check ${NSYS_JAX_PYTHON_FILES}
          format_status=$?
          if [[ $format_status != 0 || $check_status != 0 ]]; then
            exit 1
          fi