Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Install google.protobuf and protoc in containers #910

Merged
merged 3 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions .github/container/Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,13 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/
## Add helper scripts for profiling with Nsight Systems
##
## The scripts saved to /opt/jax_nsys are embedded in the output archives
## written by nsys-jax, while the nsys-jax and nsys-jax-gather-src-files
## utilities are only used inside the container.
## written by nsys-jax, while the nsys-jax wrapper is used inside the container.
###############################################################################

ADD nsys-jax nsys-jax-gather-src-files /usr/local/bin
ADD nsys-jax /usr/local/bin
ADD jax_nsys/ /opt/jax_nsys
ADD requirements-nsys-jax.in /opt/pip-tools.d/
RUN ln -s /opt/jax_nsys/nsys-jax-ensure-protobuf /usr/local/bin/
RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/

###############################################################################
## Copy manifest file to the container
Expand Down
65 changes: 65 additions & 0 deletions .github/container/jax_nsys/install-protoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env python
import argparse
import google.protobuf
import io
import os
import platform
import requests
import zipfile

def build_parser():
    """Create the command-line parser for this script.

    The explanatory text is passed as ``description``; the first positional
    argument of ``argparse.ArgumentParser`` is ``prog``, which would replace
    the program name in usage messages.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Install a version of the protoc compiler that is compatible with "
            "the google.protobuf runtime"
        )
    )
    parser.add_argument(
        "prefix", help="Output prefix under which to install protoc", type=str
    )
    return parser


def protoc_version_for(runtime_version):
    """Return the protoc release version matching a google.protobuf runtime.

    protobuf versioning is complicated, see protocolbuffers/protobuf#11123 for
    more discussion. For older versions, when the versioning scheme was
    aligned, the protoc version is the same as the google.protobuf version.
    For newer versions, given google.protobuf version X.Y.Z the protoc version
    is Y.Z, as described in https://protobuf.dev/support/version-support

    Args:
        runtime_version: sequence of ints, e.g. ``(4, 25, 3)``.

    Returns:
        The dotted protoc version string, e.g. ``"25.3"``.
    """
    runtime_version = tuple(runtime_version)
    if runtime_version < (3, 21):
        # old versioning scheme, protoc version matches the runtime version
        protoc_version = runtime_version
    else:
        # new versioning scheme, runtime minor.patch is the protoc version
        protoc_version = runtime_version[1:]
    return ".".join(map(str, protoc_version))


def candidate_targets(system, machine):
    """Translate platform identifiers into protoc release-archive names.

    Args:
        system: value of ``platform.system()``, e.g. ``"Linux"``.
        machine: value of ``platform.machine()``, e.g. ``"aarch64"``.

    Returns:
        ``(system, machines)`` where ``system`` is the operating system name
        used in release archive names and ``machines`` is the list of
        architecture names to try, in order of preference.
    """
    system = system.lower()
    system = {"darwin": "osx"}.get(system, system)
    machine = {
        "aarch64": "aarch_64",
        "arm64": "aarch_64",
    }.get(machine, machine)
    # Apple Silicon can handle universal and x86_64 if it needs to.
    machines = {
        ("osx", "aarch_64"): ["aarch_64", "universal_binary", "x86_64"],
    }.get((system, machine), [machine])
    return system, machines


def main():
    """Download protoc for the installed google.protobuf and unpack it.

    The ``bin/`` and ``include/`` trees of the first available release archive
    are extracted under the ``prefix`` given on the command line.

    Raises:
        SystemExit: if no release archive exists for this platform.
        requests.HTTPError: if a download fails with a status other than 404.
    """
    args = build_parser().parse_args()

    session = requests.Session()
    session.mount("https://", requests.adapters.HTTPAdapter(max_retries=5))

    runtime_version = tuple(map(int, google.protobuf.__version__.split(".")))
    ver = protoc_version_for(runtime_version)
    system, machines = candidate_targets(platform.system(), platform.machine())

    for machine in machines:
        # requests has no default timeout; without one a stalled connection
        # would hang the installation forever.
        r = session.get(
            f"https://github.com/protocolbuffers/protobuf/releases/download/v{ver}/protoc-{ver}-{system}-{machine}.zip",
            timeout=60,
        )
        if r.status_code == 404:
            # assume this means the architecture is not available
            continue
        r.raise_for_status()

        with zipfile.ZipFile(io.BytesIO(r.content)) as z:
            for name in z.namelist():
                if ".." in name:
                    continue
                if name.startswith(("bin/", "include/")):
                    z.extract(name, path=args.prefix)

        # Make sure the protoc binary is executable
        os.chmod(os.path.join(args.prefix, "bin", "protoc"), 0o755)
        # Stop at the first architecture that installs, so that a less
        # preferred fallback does not overwrite the preferred binary.
        break
    else:
        raise SystemExit(
            f"No protoc {ver} release archive was found for {system} "
            f"(tried: {', '.join(machines)})"
        )


if __name__ == "__main__":
    main()
5 changes: 4 additions & 1 deletion .github/container/jax_nsys/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ if [[ ! -d "${VIRTUALENV}" ]]; then
virtualenv -p 3.12 -p 3.11 -p 3.10 "$@" "${VIRTUALENV}"
. "${VIRTUALENV}/bin/activate"
python -m pip install -U pip
"${SCRIPT_DIR}/nsys-jax-ensure-protobuf"
if ! python -c "import google.protobuf" > /dev/null 2>&1 || ! command -v protoc > /dev/null; then
python -m pip install protobuf requests
"${SCRIPT_DIR}/install-protoc" "${VIRTUALENV}"
fi
# matplotlib is a dependency of Analysis.ipynb but not jax_nsys
python -m pip install jupyterlab matplotlib
python -m pip install -e "${SCRIPT_DIR}/python/jax_nsys"
Expand Down
78 changes: 0 additions & 78 deletions .github/container/jax_nsys/nsys-jax-ensure-protobuf

This file was deleted.

36 changes: 22 additions & 14 deletions .github/container/nsys-jax
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,9 @@ def find_pb_files_in_tmp(tmp_dir):
)


def gather_source_files(proto_dir, pb_file_prefix, pb_file_list, output_queue):
def gather_source_files(
proto_dir, proto_files, pb_file_prefix, pb_file_list, output_queue
):
"""
Given a directory containing the required .proto files (`proto_dir`) and a
prefix (`pb_file_prefix`) and list of relative paths to .pb files
Expand All @@ -364,19 +366,24 @@ def gather_source_files(proto_dir, pb_file_prefix, pb_file_list, output_queue):
hlo_pb_files = [
osp.join(pb_file_prefix, x) for x in pb_file_list if x.endswith(".hlo.pb")
]
# Delegate to another script to extract the list of source files referenced
# by .hlo.pb files. This pattern allows the other script to setup
# google.protobuf and protoc if needed to get a consistent installation.
p = subprocess.Popen(
["nsys-jax-gather-src-files", proto_dir] + hlo_pb_files,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
stdout, stderr = p.communicate()
if p.returncode != 0:
raise Exception("Gathering source code failed: " + stderr)
src_files = stdout.splitlines()
with tempfile.TemporaryDirectory() as tmp_dir:
# Compile the .proto files
subprocess.run(
["protoc", f"-I={proto_dir}", f"--python_out={tmp_dir}"] + proto_files,
check=True,
cwd=proto_dir,
)
# Collect the set of referenced source files
sys.path.insert(0, tmp_dir)
from xla.service import hlo_pb2

hlo = hlo_pb2.HloProto()
src_files = set()
for hlo_pb_file in hlo_pb_files:
with open(hlo_pb_file, "rb") as f:
hlo.ParseFromString(f.read())
src_files |= set(hlo.hlo_module.stack_frame_index.file_names)
sys.path.remove(tmp_dir)
if len(src_files) == 0:
raise Exception("No source files were gathered")
# Copy these files into the output archive.
Expand Down Expand Up @@ -449,6 +456,7 @@ def process_pb_and_proto_files(pb_future, proto_future, output_queue, futures):
executor.submit(
gather_source_files,
proto_dir,
proto_files,
pb_file_prefix,
pb_file_list,
files_to_archive,
Expand Down
37 changes: 0 additions & 37 deletions .github/container/nsys-jax-gather-src-files

This file was deleted.

4 changes: 4 additions & 0 deletions .github/container/pip-finalize.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,7 @@ fi
pip-sync --pip-args '--no-deps --src /opt' requirements.txt

rm -rf ~/.cache/*

# protobuf is always installed, because the base image lists it in
# requirements-nsys-jax.in, but the version that ends up installed is likely to be
# influenced by other packages.
install-protoc /usr/local
4 changes: 4 additions & 0 deletions .github/container/requirements-nsys-jax.in
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
# No version constraint; a compatible version of protoc will be installed later
protobuf
# Used by install-protoc
requests
virtualenv
6 changes: 3 additions & 3 deletions .github/workflows/nsys-jax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ jobs:
run: |
pip install virtualenv
virtualenv venv
- name: "Run nsys-jax-ensure-protobuf"
- name: "Install google.protobuf and protoc"
run: |
. ./venv/bin/activate
./JAX-Toolbox/.github/container/jax_nsys/nsys-jax-ensure-protobuf
./venv/bin/pip install -r ./JAX-Toolbox/.github/container/requirements-nsys-jax.in
./venv/bin/python ./JAX-Toolbox/.github/container/jax_nsys/install-protoc ./venv
- name: "Install jax_nsys Python package"
run: ./venv/bin/pip install -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys
- name: "Install mypy"
Expand Down
Loading