From 70f67a8bf1c012b488a5cf250d8c3619f6efa8f0 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Tue, 15 Oct 2024 11:36:15 +0200 Subject: [PATCH] Add automatic triage utility (#793) This is a small Python program that implements a two-stage bisection, identifying a commit in JAX or XLA that caused a test case to start failing. --- .github/triage/jax_toolbox_triage/__init__.py | 3 + .github/triage/jax_toolbox_triage/args.py | 112 +++++ .github/triage/jax_toolbox_triage/docker.py | 79 ++++ .github/triage/jax_toolbox_triage/logic.py | 297 ++++++++++++++ .github/triage/jax_toolbox_triage/main.py | 241 +++++++++++ .github/triage/jax_toolbox_triage/utils.py | 76 ++++ .github/triage/pyproject.toml | 8 + .github/triage/tests/test_triage_logic.py | 387 ++++++++++++++++++ .github/workflows/triage-ci.yaml | 84 ++++ docs/triage.md | 6 + 10 files changed, 1293 insertions(+) create mode 100644 .github/triage/jax_toolbox_triage/__init__.py create mode 100644 .github/triage/jax_toolbox_triage/args.py create mode 100644 .github/triage/jax_toolbox_triage/docker.py create mode 100644 .github/triage/jax_toolbox_triage/logic.py create mode 100755 .github/triage/jax_toolbox_triage/main.py create mode 100644 .github/triage/jax_toolbox_triage/utils.py create mode 100644 .github/triage/pyproject.toml create mode 100644 .github/triage/tests/test_triage_logic.py create mode 100644 .github/workflows/triage-ci.yaml diff --git a/.github/triage/jax_toolbox_triage/__init__.py b/.github/triage/jax_toolbox_triage/__init__.py new file mode 100644 index 000000000..21db616e0 --- /dev/null +++ b/.github/triage/jax_toolbox_triage/__init__.py @@ -0,0 +1,3 @@ +from .main import main + +__all__ = ["main"] diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py new file mode 100644 index 000000000..d092e7200 --- /dev/null +++ b/.github/triage/jax_toolbox_triage/args.py @@ -0,0 +1,112 @@ +import argparse +import datetime +import getpass +import os +import pathlib +import tempfile + + +def parse_args(): + parser = argparse.ArgumentParser( + description=""" + Triage failures in JAX/XLA-related tests. The expectation is that the given + test command is failing in recent versions, but that it passed in the past. The + script first triages the regression with a search of the nightly containers, + and then refines the search to a particular commit of JAX or XLA.""", + ) + + container_search_args = parser.add_argument_group( + title="Container-level search", + description=""" + First, it is verified that the test command fails on the given end date, unless + both --end-date and --skip-precondition-checks were passed. Then, the program + searches backwards to find a container when the given test did pass. The + --start-date option can be used to speed up this search, if you already know a + date on which the test was passing. The earliest failure is located to within + --threshold-days days.""", + ) + commit_search_args = parser.add_argument_group( + title="Commit-level search", + description=""" + Second, the failure is localised to a commit of JAX or XLA by re-building and + re-testing inside the earliest container that demonstrates the failure. At each + point, the oldest JAX commit that is newer than XLA is used.""", + ) + parser.add_argument( + "--container", + help=""" + Container to use. Example: jax, pax, triton. Used to construct the URLs of + nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""", + required=True, + ) + parser.add_argument( + "--output-prefix", + default=datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S"), + help=""" + Prefix for output log and JSON files. Default: triage-YYYY-MM-DD-HH-MM-SS. + An INFO-and-above log is written as PREFIX.log, a DEBUG-and-above log is + written as PREFIX-debug.log, and a JSON summary is written as + PREFIX-summary.json""", + type=pathlib.Path, + ) + parser.add_argument( + "--skip-precondition-checks", + action="store_true", + help=""" + Skip checks that should pass by construction. This saves time, but may yield + incorrect results if you are not careful. Specifically this means that the test + is assumed to fail on --end-date (if specified), pass on --start-date (if + specified), and fail after recompilation in the earliest-known-failure + container. Careful use of this option, along with --start-date, --end-date and + --threshold-days, allows the container-level search to be skipped.""", + ) + parser.add_argument( + "test_command", + nargs="+", + help=""" + Command to execute inside the container. This should be as targeted as + possible.""", + ) + container_search_args.add_argument( + "--end-date", + help=""" + Initial estimate of the earliest nightly container date where the test case + fails. Defaults to the newest available nightly container date. If this and + --skip-precondition-checks are both set then it will not be verified that the + test case fails on this date.""", + type=lambda s: datetime.date.fromisoformat(s), + ) + container_search_args.add_argument( + "--start-date", + help=""" + Initial estimate of the latest nightly container date where the test case + passes. Defaults to the day before --end-date, but setting this to a date + further in the past may lead to faster convergence of the initial backwards + search for a date when the test case passed. If this and + --skip-precondition-checks are both set then the test case *must* pass on + this date, which will *not* be verified.""", + type=lambda s: datetime.date.fromisoformat(s), + ) + container_search_args.add_argument( + "--threshold-days", + default=1, + help=""" + Convergence threshold. Ideally, the container-level search will continue while + the number of days separating the last known success and first known failure is + smaller than this value. The minimum, and default, value is 1. Note that in + case of nightly build failures the search may finish without reaching this + threshold.""", + type=int, + ) + commit_search_args.add_argument( + "--bazel-cache", + default=os.path.join( + tempfile.gettempdir(), f"{getpass.getuser()}-bazel-triage-cache" + ), + help=""" + Bazel cache to use when [re-]building JAX/XLA during the fine search. This can + be a remote cache server or a local directory. Using a persistent cache can + significantly speed up the commit-level search. By default, uses a temporary + directory including the name of the current user.""", + ) + return parser.parse_args() diff --git a/.github/triage/jax_toolbox_triage/docker.py b/.github/triage/jax_toolbox_triage/docker.py new file mode 100644 index 000000000..85b21723b --- /dev/null +++ b/.github/triage/jax_toolbox_triage/docker.py @@ -0,0 +1,79 @@ +import logging +import pathlib +import subprocess +import typing + + +class DockerContainer: + def __init__( + self, + url: str, + *, + logger: logging.Logger, + mounts: typing.List[typing.Tuple[pathlib.Path, pathlib.Path]], + ): + self._logger = logger + self._mount_args = [] + for src, dst in mounts: + self._mount_args += ["-v", f"{src}:{dst}"] + self._url = url + + def __enter__(self): + result = subprocess.run( + [ + "docker", + "run", + "--detach", + # Otherwise bazel shutdown hangs. + "--init", + "--gpus=all", + "--shm-size=1g", + ] + + self._mount_args + + [ + self._url, + "sleep", + "infinity", + ], + check=True, + encoding="utf-8", + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + self._id = result.stdout.strip() + return self + + def __exit__(self, *exc_info): + subprocess.run( + ["docker", "stop", self._id], + check=True, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + + def exec( + self, command: typing.List[str], workdir=None + ) -> subprocess.CompletedProcess: + """ + Run a command inside a persistent container. + """ + workdir = [] if workdir is None else ["--workdir", workdir] + return subprocess.run( + ["docker", "exec"] + workdir + [self._id] + command, + encoding="utf-8", + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + + def check_exec( + self, cmd: typing.List[str], **kwargs + ) -> subprocess.CompletedProcess: + result = self.exec(cmd, **kwargs) + if result.returncode != 0: + self._logger.fatal( + f"{' '.join(cmd)} exited with return code {result.returncode}" + ) + self._logger.fatal(result.stdout) + self._logger.fatal(result.stderr) + result.check_returncode() + return result diff --git a/.github/triage/jax_toolbox_triage/logic.py b/.github/triage/jax_toolbox_triage/logic.py new file mode 100644 index 000000000..1e88f57c5 --- /dev/null +++ b/.github/triage/jax_toolbox_triage/logic.py @@ -0,0 +1,297 @@ +import datetime +import functools +import logging +import typing + + +def as_datetime(date: datetime.date) -> datetime.datetime: + return datetime.datetime.combine(date, datetime.time()) + + +def adjust_date( + date: datetime.datetime, + logger: logging.Logger, + container_exists: typing.Callable[[datetime.date], bool], + before: typing.Optional[datetime.date] = None, + after: typing.Optional[datetime.date] = None, + max_steps: int = 100, +) -> typing.Optional[datetime.date]: + """ + Given a datetime that may have non-zero hour/minute/second/... parts, and where + container_url(date.date()) might be a container that does not exist due to job + failure, return a similar date where container_url(new_date) does exist, or None if + no such container can be found. + + Arguments: + date: date to adjust + before: the returned date will be before this [optional] + after: the returned date will be after this [optional] + max_steps: maximum number of days away from the start date to venture + """ + round_up = date.time() > datetime.time(12) + down, up = (date.date(), -1), (date.date() + datetime.timedelta(days=1), +1) + options = [up, down] if round_up else [down, up] + n = 0 + while n < max_steps: + plausible_directions = 0 + for start, direction in options: + candidate = start + n * direction * datetime.timedelta(days=1) + if (before is None or candidate < before) and ( + after is None or candidate > after + ): + plausible_directions += 1 + if container_exists(candidate): + if date.date() != candidate: + logger.debug(f"Adjusted {date} to {candidate}") + return candidate + else: + logger.debug(f"{candidate} does not exist") + n += 1 + if plausible_directions == 0: + logger.info( + f"Could not adjust {date} given before={before} and after={after}" + ) + return None + logger.info(f"Could not find an adjusted {date} within {max_steps} steps") + return None + + +def container_search( + *, + container_exists: typing.Callable[[datetime.date], bool], + container_passes: typing.Callable[[datetime.date], bool], + start_date: typing.Optional[datetime.date], + end_date: typing.Optional[datetime.date], + logger: logging.Logger, + skip_precondition_checks: bool, + threshold_days: int, +): + adjust = functools.partial( + adjust_date, logger=logger, container_exists=container_exists + ) + # Figure out the end date of the search + if end_date is not None: + # --end-date was passed + if not container_exists(end_date): + raise Exception(f"--end-date={end_date} is not a valid container") + skip_end_date_check = skip_precondition_checks + else: + # Default to the most recent container + now = datetime.datetime.now() + end_date = adjust(now) + if end_date is None: + raise Exception(f"Could not find a valid container from {now}") + skip_end_date_check = False + + # Check preconditions; the test is supposed to fail on the end date. + if skip_end_date_check: + logger.info(f"Skipping check for end-of-range failure in {end_date}") + else: + logger.info(f"Checking end-of-range failure in {end_date}") + if container_passes(end_date): + raise Exception(f"Could not reproduce failure in {end_date}") + + # Start the coarse, container-level, search for a starting point to the bisection range + earliest_failure = end_date + if start_date is None: + # Start from the day before the end date. + search_date = adjust( + as_datetime(end_date) - datetime.timedelta(days=1), before=end_date + ) + if search_date is None: + raise Exception(f"Could not find a valid nightly before {end_date}") + logger.info( + f"Starting coarse search with {search_date} based on end_date={end_date}" + ) + # We just found a starting value, we need to actually check if the test passes or + # fails on it. + skip_first_phase = False + else: + # If a start value seed was given, use it. + if start_date >= end_date: + raise Exception(f"{start_date} must be before {end_date}") + if not container_exists(start_date): + raise Exception(f"--start-date={start_date} is not a valid container") + search_date = start_date + assert search_date is not None # for mypy + # If --skip-precondition-checks and --start-date are both passed, we assume that + # the test passed on the given --start-date and the first phase of the search can + # be skipped + skip_first_phase = skip_precondition_checks + if not skip_first_phase: + logger.info( + f"Starting coarse search with {search_date} based on --start-date" + ) + + if skip_first_phase: + logger.info(f"Skipping check that the test passes on start_date={start_date}") + else: + # While condition prints an info message + while not container_passes(search_date): + # Test failed on `search_date`, go further into the past + earliest_failure = search_date + new_search_date = adjust( + as_datetime(end_date) - 2 * (end_date - search_date), + before=search_date, + ) + if new_search_date is None: + raise Exception( + f"Could not find a passing nightly before {search_date}" + ) + search_date = new_search_date + + # Continue the container-level search, refining the range until it meets the criterion + # set by args.threshold_days. The test passed at range_start and not at range_end. + range_start, range_end = search_date, earliest_failure + logger.info( + f"Coarse container-level search yielded [{range_start}, {range_end}]..." + ) + while range_end - range_start > datetime.timedelta(days=threshold_days): + range_mid = adjust( + as_datetime(range_start) + 0.5 * (range_end - range_start), + before=range_end, + after=range_start, + ) + if range_mid is None: + # It wasn't possible to refine further. + break + result = container_passes(range_mid) + if result: + range_start = range_mid + else: + range_end = range_mid + logger.info(f"Refined container-level range to [{range_start}, {range_end}]") + return range_start, range_end + + +class BuildAndTest(typing.Protocol): + def __call__( + self, *, jax_commit: str, xla_commit: str + ) -> typing.Tuple[bool, str, str]: ... + + +def commit_search( + *, + jax_commits: typing.Sequence[typing.Tuple[str, datetime.datetime]], + xla_commits: typing.Sequence[typing.Tuple[str, datetime.datetime]], + build_and_test: BuildAndTest, + logger: logging.Logger, + skip_precondition_checks: bool, +): + """ + build_and_test: test the given commits in the container that originally shipped with end_{jax,xla}_commit. + """ + if ( + len(jax_commits) == 0 + or len(xla_commits) == 0 + or len(jax_commits) + len(xla_commits) < 3 + ): + raise Exception("Not enough commits passed") + start_jax_commit = jax_commits[0][0] + start_xla_commit = xla_commits[0][0] + end_jax_commit = jax_commits[-1][0] + end_xla_commit = xla_commits[-1][0] + if skip_precondition_checks: + logger.info("Skipping check that vanilla rebuild + test reproduces failure") + else: + # Verify we can build successfully and that the test fails as expected. These + # commits are the ones already checked out in the container, but specifying + # them explicitly is good for the summary JSON. + logger.info("Building in the range-ending container...") + range_end_result, stdout, stderr = build_and_test( + jax_commit=end_jax_commit, xla_commit=end_xla_commit + ) + if not range_end_result: + logger.info("Verified test failure after vanilla rebuild") + else: + logger.fatal("Vanilla rebuild did not reproduce test failure") + logger.fatal(stdout) + logger.fatal(stderr) + raise Exception("Could not reproduce") + + # Verify that we can build the commit at the start of the range and reproduce the + # test success there in the end-of-range container. + range_start_result, stdout, stderr = build_and_test( + jax_commit=start_jax_commit, xla_commit=start_xla_commit + ) + if range_start_result: + logger.info( + "Test passed after rebuilding commits from start container in end container" + ) + else: + logger.fatal( + "Test failed after rebuilding commits from start container in end container" + ) + logger.fatal(stdout) + logger.fatal(stderr) + raise Exception("Could not reproduce") + + # Finally, start bisecting. This is XLA-centric; JAX is moved too but is secondary. + while len(xla_commits) > 2: + middle = len(xla_commits) // 2 + xla_hash, xla_date = xla_commits[middle] + # Find the oldest JAX commit that is newer than this + for jax_index, (jax_hash, jax_date) in enumerate(jax_commits): + if jax_date >= xla_date: + break + bisect_result, _, _ = build_and_test(jax_commit=jax_hash, xla_commit=xla_hash) + if bisect_result: + # Test passed, continue searching in the second half + xla_commits = xla_commits[middle:] + jax_commits = jax_commits[jax_index:] + else: + # Test failed, continue searching in the first half + xla_commits = xla_commits[: middle + 1] + jax_commits = jax_commits[: jax_index + 1] + + # XLA bisection converged. xla_commits has two entries. jax_commits may be a little + # longer, if it was more active than XLA at the relevant time. For example, here + # xla_commits is {oX, nX} and jax_commits is {oJ, mJ, nJ}, and the test passes with + # {oX, oJ} and fails with {nX, nJ}. Naming: o=old, m=medium, n=new, X=XLA, J=JAX. + # pass fail + # XLA: oX -------- nX + # JAX: oJ -- mJ -- nJ + # + # To figure out whether to blame XLA or JAX, we now test {oX, nJ}. + old_xla_hash = xla_commits[0][0] + new_jax_hash = jax_commits[-1][0] + blame_result, _, _ = build_and_test( + jax_commit=new_jax_hash, xla_commit=old_xla_hash + ) + if blame_result: + # Test passed with {oX, nJ} but was known to fail with {nX, nJ}. Therefore, XLA + # commit nX is responsible and JAX is innocent. + results = (old_xla_hash, xla_commits[1][0]) + logger.info( + "Bisected failure to XLA {}..{} with JAX {}".format(*results, new_jax_hash) + ) + return { + "jax_ref": new_jax_hash, + "xla_bad": xla_commits[1][0], + "xla_good": old_xla_hash, + } + else: + # Test failed with {oX, nJ} but was known to pass with {oX, oJ}, so JAX is + # responsible and we should bisect between oJ (pass) and nJ (fail). This yields + # a single JAX commit to blame, either mJ or nJ in the example above. + while len(jax_commits) > 2: + middle = len(jax_commits) // 2 + jax_hash, _ = jax_commits[middle] + bisect_result, _, _ = build_and_test( + jax_commit=jax_hash, xla_commit=old_xla_hash + ) + if bisect_result: + # Test passsed, continue searching in second half + jax_commits = jax_commits[middle:] + else: + # Test failed, continue searching in the first half + jax_commits = jax_commits[: middle + 1] + results = (jax_commits[0][0], jax_commits[1][0]) + logger.info( + "Bisected failure to JAX {}..{} with XLA {}".format(*results, old_xla_hash) + ) + return { + "jax_bad": jax_commits[1][0], + "jax_good": jax_commits[0][0], + "xla_ref": old_xla_hash, + } diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py new file mode 100755 index 000000000..55af1d727 --- /dev/null +++ b/.github/triage/jax_toolbox_triage/main.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +import datetime +import functools +import json +import logging +import time +import typing + +from .args import parse_args +from .docker import DockerContainer +from .logic import commit_search, container_search +from .utils import ( + container_exists as container_exists_base, + container_url as container_url_base, + get_logger, + prepare_bazel_cache_mounts, +) + + +def main(): + args = parse_args() + bazel_cache_mounts = prepare_bazel_cache_mounts(args.bazel_cache) + logger = get_logger(args.output_prefix) + container_url = functools.partial(container_url_base, container=args.container) + container_exists = functools.partial( + container_exists_base, container=args.container, logger=logger + ) + Container = functools.partial( + DockerContainer, logger=logger, mounts=bazel_cache_mounts + ) + bazel_cache_mount_args = [] + for src, dst in bazel_cache_mounts: + bazel_cache_mount_args += ["-v", f"{src}:{dst}"] + + def add_summary_record(section, record, scalar=False): + """ + Add a record to the output JSON file. This is intended to provide a useful record + even in case of a fatal error. + """ + summary_filename = args.output_prefix / "summary.json" + try: + with open(summary_filename, "r") as ifile: + data = json.load(ifile) + except FileNotFoundError: + data = {} + if scalar: + if section in data: + logging.warning(f"Overwriting summary data in section {section}") + data[section] = record + else: + if section not in data: + data[section] = [] + data[section].append(record) + with open(summary_filename, "w") as ofile: + json.dump(data, ofile) + + def get_commit(container: DockerContainer, repo: str) -> typing.Tuple[str, str]: + """ + Get the commit of the given repository that was used in the given nightly container + + Arguments: + date: nightly container date + repo: repository, must be jax or xla + """ + assert repo in {"jax", "xla"} + # Older containers used /opt/jax-source etc. + for suffix in ["", "-source"]: + dirname = f"/opt/{repo}{suffix}" + result = container.exec(["git", "rev-parse", "HEAD"], workdir=dirname) + if result.returncode == 0: + commit = result.stdout.strip() + if len(commit) == 40: + return commit, dirname + raise Exception( + f"Could not extract commit of {repo} from {args.container} container {container}" + ) + + def check_container(date: datetime.date) -> bool: + """ + See if the test passes in the given container. + """ + before = time.monotonic() + with Container(container_url(date)) as worker: + result = worker.exec(args.test_command) + test_time = time.monotonic() - before + jax_commit = get_commit(worker, "jax") + xla_commit = get_commit(worker, "xla") + + logger.debug(result.stdout) + logger.info(f"Ran test case in {date} in {test_time:.1f}s") + test_pass = result.returncode == 0 + add_summary_record( + "container", + { + "container": container_url(date), + "jax": jax_commit, + "result": test_pass, + "test_time": test_time, + "xla": xla_commit, + }, + ) + return test_pass + + # Search through the published containers, narrowing down to a pair of dates with + # the property that the test passed on `range_start` and fails on `range_end`. + range_start, range_end = container_search( + container_exists=container_exists, + container_passes=check_container, + start_date=args.start_date, + end_date=args.end_date, + logger=logger, + skip_precondition_checks=args.skip_precondition_checks, + threshold_days=args.threshold_days, + ) + + # Container-level search is now complete. Triage proceeds inside the `range_end`` + # container. First, we check that rewinding JAX and XLA inside the `range_end`` + # container to the commits used in the `range_start` container passes, whereas + # using the `range_end` commits reproduces the failure. + + with Container(container_url(range_start)) as worker: + start_jax_commit, _ = get_commit(worker, "jax") + start_xla_commit, _ = get_commit(worker, "xla") + + # Fire up the container that will be used for the fine search. + with Container(container_url(range_end)) as worker: + end_jax_commit, jax_dir = get_commit(worker, "jax") + end_xla_commit, xla_dir = get_commit(worker, "xla") + logger.info( + ( + f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and " + f"XLA [{start_xla_commit}, {end_xla_commit}] using {container_url(range_end)}" + ) + ) + + # Get the full lists of JAX/XLA commits and dates + def commits(start, end, dir): + result = worker.check_exec( + [ + "git", + "log", + "--first-parent", + "--reverse", + "--format=%H %cI", + f"{start}^..{end}", + ], + workdir=dir, + ) + data = [] + for line in result.stdout.splitlines(): + commit, date = line.split() + date = datetime.datetime.fromisoformat(date).astimezone( + datetime.timezone.utc + ) + data.append((commit, date)) + return data + + # Get lists of (commit_hash, commit_date) pairs + jax_commits = commits(start_jax_commit, end_jax_commit, jax_dir) + xla_commits = commits(start_xla_commit, end_xla_commit, xla_dir) + # Confirm they're sorted by commit date + assert all(b[1] >= a[1] for a, b in zip(jax_commits, jax_commits[1:])) + assert all(b[1] >= a[1] for a, b in zip(xla_commits, xla_commits[1:])) + # Confirm the end values are included as expected + assert start_jax_commit == jax_commits[0][0] + assert start_xla_commit == xla_commits[0][0] + assert end_jax_commit == jax_commits[-1][0] + assert end_xla_commit == xla_commits[-1][0] + + def build_and_test( + jax_commit: str, xla_commit: str + ) -> typing.Tuple[bool, str, str]: + """ + The main body of the bisection loop. Update the JAX/XLA commits, build XLA and + jaxlib, and run the test command. Throws on error when checking out or + building, and returns the status of the test command. + """ + worker.check_exec(["git", "stash"], workdir=xla_dir) + worker.check_exec(["git", "stash"], workdir=jax_dir) + worker.check_exec(["git", "checkout", xla_commit], workdir=xla_dir) + worker.check_exec(["git", "checkout", jax_commit], workdir=jax_dir) + logger.info(f"Checking out XLA {xla_commit} JAX {jax_commit}") + # Build JAX + before = time.monotonic() + # Next two are workarounds for bugs in old containers + worker.check_exec(["sh", "-c", f"rm -vf {jax_dir}/dist/jaxlib-*.whl"]) + # This will error out on newer containers, but that should be harmless + worker.exec( + [ + "cp", + f"{jax_dir}/jax/version.py", + f"{jax_dir}/build/lib/jax/version.py", + ] + ) + # It seemed that this might be the origin of flaky behaviour. + worker.check_exec( + ["sh", "-c", "echo 'test --cache_test_results=no' > /root/.bazelrc"] + ) + build_jax = [ + "build-jax.sh", + # Leave the editable /opt/jax[-source] installation alone. Otherwise + # test-jax.sh is broken by having a /usr/... installation directory. + "--jaxlib_only", + # Workaround bugs in old containers where the default was wrong. + "--src-path-jax", + jax_dir, + f"--bazel-cache={args.bazel_cache}", + ] + worker.check_exec(build_jax, workdir=jax_dir) + middle = time.monotonic() + logger.info(f"Build completed in {middle - before:.1f}s") + # Run the test + test_result = worker.exec(args.test_command) + test_time = time.monotonic() - middle + add_summary_record( + "commit", + { + "build_time": middle - before, + "container": container_url(range_end), + "jax": jax_commit, + "result": test_result.returncode == 0, + "test_time": test_time, + "xla": xla_commit, + }, + ) + logger.info(f"Test completed in {test_time:.1f}s") + logger.debug( + f"Test stdout:\n{test_result.stdout}\nTest stderr:\n{test_result.stderr}" + ) + return test_result.returncode == 0, test_result.stdout, test_result.stderr + + # Run the commit-level bisection + result = commit_search( + jax_commits=jax_commits, + xla_commits=xla_commits, + build_and_test=build_and_test, + logger=logger, + skip_precondition_checks=args.skip_precondition_checks, + ) + result["container"] = container_url(range_end) + add_summary_record("result", result, scalar=True) diff --git a/.github/triage/jax_toolbox_triage/utils.py b/.github/triage/jax_toolbox_triage/utils.py new file mode 100644 index 000000000..b8a0639f8 --- /dev/null +++ b/.github/triage/jax_toolbox_triage/utils.py @@ -0,0 +1,76 @@ +import datetime +import logging +import pathlib +import subprocess +import typing + + +def container_url(date: datetime.date, *, container: str) -> str: + """ + Construct the URL for --container on the given date. + + Arguments: + date: YYYY-MM-DD format. + """ + # Around 2024-02-09 the naming scheme changed. + if date > datetime.date(year=2024, month=2, day=9): + return f"ghcr.io/nvidia/jax:{container}-{date.isoformat()}" + else: + return f"ghcr.io/nvidia/{container}:nightly-{date.isoformat()}" + + +def container_exists( + date: datetime.date, *, container: str, logger: logging.Logger +) -> bool: + """ + Check if the given container exists. + """ + result = subprocess.run( + ["docker", "pull", container_url(date, container=container)], + stderr=subprocess.STDOUT, + stdout=subprocess.PIPE, + encoding="utf-8", + ) + logger.debug(result.stdout) + return result.returncode == 0 + + +def get_logger(output_prefix: pathlib.Path) -> logging.Logger: + output_prefix.mkdir() + logger = logging.getLogger("triage") + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + fmt="[%(levelname)s] %(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) + console = logging.StreamHandler() + trace_file = logging.FileHandler(filename=output_prefix / "info.log", mode="w") + debug_file = logging.FileHandler(filename=output_prefix / "debug.log", mode="w") + console.setLevel(logging.INFO) + trace_file.setLevel(logging.INFO) + debug_file.setLevel(logging.DEBUG) + console.setFormatter(formatter) + trace_file.setFormatter(formatter) + debug_file.setFormatter(formatter) + logger.addHandler(console) + logger.addHandler(trace_file) + logger.addHandler(debug_file) + return logger + + +def prepare_bazel_cache_mounts( + bazel_cache: str, +) -> typing.Sequence[typing.Tuple[pathlib.Path, pathlib.Path]]: + if ( + bazel_cache.startswith("http://") + or bazel_cache.startswith("https://") + or bazel_cache.startswith("grpc://") + ): + # Remote cache, no mount needed + return [] + elif (bazel_cache_path := pathlib.Path(bazel_cache)).is_absolute(): + bazel_cache_path.mkdir(exist_ok=True) + return [(bazel_cache_path, bazel_cache_path)] + else: + raise Exception( + "--bazel-cache should be an http/https/grpc URL or an absolute path" + ) diff --git a/.github/triage/pyproject.toml b/.github/triage/pyproject.toml new file mode 100644 index 000000000..b2d3c76ef --- /dev/null +++ b/.github/triage/pyproject.toml @@ -0,0 +1,8 @@ +[project] +name = "jax-toolbox-triage" +dynamic = ["version"] +# Because this script needs to run on compute clusters *outside* the containers that it +# orchestrates, it tries to tolerate old Python versions +requires-python = ">= 3.8" +[project.scripts] +jax-toolbox-triage = "jax_toolbox_triage:main" diff --git a/.github/triage/tests/test_triage_logic.py b/.github/triage/tests/test_triage_logic.py new file mode 100644 index 000000000..411740ded --- /dev/null +++ b/.github/triage/tests/test_triage_logic.py @@ -0,0 +1,387 @@ +import datetime +import itertools +import logging +import pytest +import random +from jax_toolbox_triage.logic import commit_search, container_search + + +def wrap(b): + return b, "", "" + + +@pytest.fixture +def logger(): + logger = logging.getLogger("triage-tests") + logger.setLevel(logging.DEBUG) + return logger + + +@pytest.mark.parametrize( + "dummy_test,expected", + [ + ( + lambda jax_commit, xla_commit: wrap(jax_commit == "oJ"), + {"xla_ref": "oX", "jax_bad": "mJ", "jax_good": "oJ"}, + ), + ( + lambda jax_commit, xla_commit: wrap(jax_commit != "nJ"), + {"xla_ref": "oX", "jax_bad": "nJ", "jax_good": "mJ"}, + ), + ( + lambda jax_commit, xla_commit: wrap(xla_commit == "oX"), + {"jax_ref": "nJ", "xla_bad": "nX", "xla_good": "oX"}, + ), + ], +) +def test_commit_search_explicit(logger, dummy_test, expected): + """ + Test the three possibilities in the hardcoded sequence below, where the container + level search yielded that (oJ, oX) passed and (nX, nJ) failed. mJ, nJ or nX could + be the culprit. + """ + jax_commits = [("oJ", 1), ("mJ", 2), ("nJ", 3)] + xla_commits = [("oX", 1), ("nX", 3)] + algorithm_result = commit_search( + build_and_test=dummy_test, + jax_commits=jax_commits, + logger=logger, + skip_precondition_checks=False, + xla_commits=xla_commits, + ) + assert algorithm_result == expected + + +start_date = datetime.datetime(2024, 10, 1) +step_size = datetime.timedelta(days=1) + + +@pytest.mark.parametrize("seed", range(10)) +@pytest.mark.parametrize("extra_commits", [0, 2, 7, 100]) +def test_commit_search(logger, extra_commits, seed): + """ + Generate random sequences of JAX/XLA commits and test that the commit-level search + algorithm yields the expected results. + + The minimal set of commits generated is (good, bad) for one component and (ref) for + the other, where the test passes for (good, ref) and fails for (bad, ref). + + Around `extra_commits` extra commits will be added across the two components around + these. + """ + rng = random.Random(seed) + + def random_hash(): + return hex(int(rng.uniform(1e10, 9e10)))[2:] + + def random_delay(): + return rng.randint(1, 10) * step_size + + # Randomise whether JAX or XLA is newer at the start of the range + commits = { + "jax": [(random_hash(), start_date)], + "xla": [(random_hash(), start_date + rng.randint(-2, +2) * step_size)], + } + + def append_random_commits(n): + for _ in range(n): + output = rng.choice(list(commits.values())) + output.append((random_hash(), output[-1][1] + random_delay())) + + # Noise + append_random_commits(extra_commits // 2) + + # Inject the bad commit + culprit, innocent = rng.choice([("jax", "xla"), ("xla", "jax")]) + good_commit, good_date = commits[culprit][-1] + bad_commit, bad_date = random_hash(), good_date + random_delay() + assert good_date < bad_date + commits[culprit].append((bad_commit, bad_date)) + + # Noise + append_random_commits(extra_commits // 2) + + def dummy_test(*, jax_commit, xla_commit): + jax_date = {sha: dt for sha, dt in commits["jax"]}[jax_commit] + xla_date = {sha: dt for sha, dt in commits["xla"]}[xla_commit] + return wrap(xla_date < bad_date if culprit == "xla" else jax_date < bad_date) + + algorithm_result = commit_search( + build_and_test=dummy_test, + jax_commits=commits["jax"], + logger=logger, + skip_precondition_checks=False, + xla_commits=commits["xla"], + ) + # Do not check the reference commit, it's a bit underspecified quite what it means, + # other than that the dummy_test assertions below should pass + innocent_ref = algorithm_result.pop(f"{innocent}_ref") + assert { + f"{culprit}_bad": bad_commit, + f"{culprit}_good": good_commit, + } == algorithm_result + if culprit == "jax": + assert not dummy_test(jax_commit=bad_commit, xla_commit=innocent_ref)[0] + assert dummy_test(jax_commit=good_commit, xla_commit=innocent_ref)[0] + else: + assert not dummy_test(jax_commit=innocent_ref, xla_commit=bad_commit)[0] + assert dummy_test(jax_commit=innocent_ref, xla_commit=good_commit)[0] + + +def other(project): + return "xla" if project == "jax" else "jax" + + +def create_commits(num_commits): + """ + Generate commits for test_commit_search_exhaustive. + """ + + def fake_hash(): + fake_hash.n += 1 + return str(fake_hash.n) + + fake_hash.n = 0 + for first_project in ["jax", "xla"]: + for commit_types in itertools.product(range(3), repeat=num_commits - 1): + commits = [(first_project, fake_hash(), start_date)] + # Cannot have all commits from the same project + if sum(commit_types) == 0: + continue + for commit_type in commit_types: + prev_project, _, prev_date = commits[-1] + if commit_type == 0: # same + commits.append((prev_project, fake_hash(), prev_date + step_size)) + elif commit_type == 1: # diff + commits.append( + (other(prev_project), fake_hash(), prev_date + step_size) + ) + else: + assert commit_type == 2 # diff-concurrent + commits.append((other(prev_project), fake_hash(), prev_date)) + assert len(commits) == num_commits + + # The commits for a each project must have increasing timestamps + def increasing(project): + project_dates = list( + map(lambda t: t[2], filter(lambda t: t[0] == project, commits)) + ) + return all(x < y for x, y in zip(project_dates, project_dates[1:])) + + if not increasing("jax") or not increasing("xla"): + continue + + for bad_commit_index in range( + 1, # bad commit cannot be the first one + num_commits, + ): + bad_project, _, _ = commits[bad_commit_index] + # there must be a good commit before the last one + if not any( + project == bad_project + for project, _, _ in commits[:bad_commit_index] + ): + continue + yield bad_commit_index, commits + + +@pytest.mark.parametrize("commits", create_commits(5)) +def test_commit_search_exhaustive(logger, commits): + """ + Exhaustive search over combinations of a small number of commits + """ + bad_index, merged_commits = commits + bad_project, bad_commit, bad_date = merged_commits[bad_index] + good_project = other(bad_project) + split_commits = { + p: [(commit, date) for proj, commit, date in merged_commits if proj == p] + for p in ("jax", "xla") + } + good_commit, _ = list( + filter(lambda t: t[1] < bad_date, split_commits[bad_project]) + )[-1] + # in this test, there are no commit collisions + dates = {commit: date for _, commit, date in merged_commits} + assert all(len(v) for v in split_commits.values()) + assert len(split_commits[bad_project]) >= 2 + + def dummy_test(*, jax_commit, xla_commit): + return wrap( + dates[jax_commit if bad_project == "jax" else xla_commit] < bad_date + ) + + algorithm_result = commit_search( + build_and_test=dummy_test, + jax_commits=split_commits["jax"], + logger=logger, + skip_precondition_checks=False, + xla_commits=split_commits["xla"], + ) + # Do not check the reference commit, it's a bit underspecified quite what it means. + assert algorithm_result[f"{bad_project}_bad"] == bad_commit + assert algorithm_result[f"{bad_project}_good"] == good_commit + # Do check that the reference commit gives the expected results + assert not dummy_test( + **{ + f"{bad_project}_commit": bad_commit, + f"{good_project}_commit": algorithm_result[f"{good_project}_ref"], + } + )[0] + assert dummy_test( + **{ + f"{bad_project}_commit": good_commit, + f"{good_project}_commit": algorithm_result[f"{good_project}_ref"], + } + )[0] + + +@pytest.mark.parametrize( + "jax_commits,xla_commits", + [ + ([], [("", start_date)]), + ([("", start_date)], []), + ([("", start_date)], [("", start_date)]), + ], +) +def test_commit_search_no_commits(logger, jax_commits, xla_commits): + with pytest.raises(Exception, match="Not enough commits"): + commit_search( + build_and_test=lambda jax_commit, xla_commit: None, + jax_commits=jax_commits, + logger=logger, + skip_precondition_checks=False, + xla_commits=xla_commits, + ) + + +@pytest.mark.parametrize("value", [True, False]) +def test_commit_search_static_test_function(logger, value): + with pytest.raises(Exception, match="Could not reproduce"): + commit_search( + build_and_test=lambda jax_commit, xla_commit: wrap(value), + jax_commits=[("", start_date), ("", start_date + step_size)], + xla_commits=[("", start_date), ("", start_date + step_size)], + logger=logger, + skip_precondition_checks=False, + ) + + +far_future = datetime.date(year=2100, month=1, day=1) +further_future = datetime.date(year=2100, month=1, day=12) +assert far_future > datetime.date.today() +assert further_future > far_future +good_date = datetime.date(year=2100, month=1, day=1) +bad_date = datetime.date(year=2100, month=1, day=12) + + +@pytest.mark.parametrize( + "start_date,end_date,dates_that_exist,match_string", + [ + # Explicit start_date is later than explicit end date + ( + further_future, + far_future, + {far_future, further_future}, + "2100-01-12 must be before 2100-01-01", + ), + # Good order, but both invalid + (far_future, further_future, {}, "is not a valid container"), + # Good order, one invalid + (far_future, further_future, {far_future}, "is not a valid container"), + (far_future, further_future, {further_future}, "is not a valid container"), + # Valid end_date, but there are no valid earlier ones to be found + (None, far_future, {far_future}, "Could not find a valid nightly before"), + # Start from today, nothing valid to be found + (None, None, {}, "Could not find a valid container from"), + # Valid start, default end will not work + (far_future, None, {far_future}, "Could not find a valid container from"), + ], +) +def test_container_search_limits( + logger, start_date, end_date, dates_that_exist, match_string +): + """ + Test for failure if an invalid date is explicitly passed. + """ + with pytest.raises(Exception, match=match_string): + container_search( + container_exists=lambda dt: dt in dates_that_exist, + container_passes=lambda dt: False, + start_date=start_date, + end_date=end_date, + logger=logger, + skip_precondition_checks=False, + threshold_days=1, + ) + + +@pytest.mark.parametrize( + "start_date,end_date,dates_that_pass,match_string", + [ + # Test never passes + pytest.param( + far_future, + further_future, + {}, + "Could not find a passing nightly before", + marks=pytest.mark.xfail( + reason="No cutoff implemented if all dates exist but none pass" + ), + ), + # Test passes at the end of the range but not the start + ( + far_future, + further_future, + {further_future}, + "Could not reproduce failure in", + ), + # Test passes at both ends of the range + ( + far_future, + further_future, + {far_future, further_future}, + "Could not reproduce failure in", + ), + ], +) +def test_container_search_checks( + logger, start_date, end_date, dates_that_pass, match_string +): + """ + Test for failure if start/end dates are given that do not meet the preconditions. + """ + with pytest.raises(Exception, match=match_string): + container_search( + container_exists=lambda dt: True, + container_passes=lambda dt: dt in dates_that_pass, + start_date=start_date, + end_date=end_date, + logger=logger, + skip_precondition_checks=False, + threshold_days=1, + ) + + +@pytest.mark.parametrize("start_date", [None, datetime.date(year=2024, month=1, day=1)]) +@pytest.mark.parametrize( + "days_of_failure", [1, 2, 17, 19, 32, 64, 71, 113, 128, 256, 359] +) +@pytest.mark.parametrize("threshold_days", [1, 4, 15]) +def test_container_search(logger, start_date, days_of_failure, threshold_days): + end_date = datetime.date(year=2024, month=12, day=31) + one_day = datetime.timedelta(days=1) + threshold_date = end_date - days_of_failure * one_day + assert start_date is None or threshold_date >= start_date + good_date, bad_date = container_search( + container_exists=lambda dt: True, + container_passes=lambda dt: dt < threshold_date, + start_date=start_date, + end_date=end_date, + logger=logger, + skip_precondition_checks=False, + threshold_days=threshold_days, + ) + assert bad_date != good_date + assert bad_date - good_date <= datetime.timedelta(days=threshold_days) + assert good_date < threshold_date + assert bad_date >= threshold_date diff --git a/.github/workflows/triage-ci.yaml b/.github/workflows/triage-ci.yaml new file mode 100644 index 000000000..2fb92be1a --- /dev/null +++ b/.github/workflows/triage-ci.yaml @@ -0,0 +1,84 @@ +name: jax-toolbox-triage pure-Python CI + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +on: + pull_request: + types: + - opened + - reopened + - ready_for_review + - synchronize + paths-ignore: + - '**.md' + push: + branches: + - main + +env: + TRIAGE_PYTHON_FILES: .github/triage + +jobs: + mypy: + runs-on: ubuntu-24.04 + steps: + - name: Check out the repository under ${GITHUB_WORKSPACE} + uses: actions/checkout@v4 + with: + sparse-checkout: .github/triage + - name: "Setup Python 3.12" + uses: actions/setup-python@v5 + with: + python-version: '3.12' + - name: "Install mypy" + run: pip install mypy pytest + - name: "Run mypy checks" + shell: bash -x -e {0} + run: | + mypy .github/triage + pytest: + runs-on: ubuntu-24.04 + strategy: + matrix: + PYTHON_VERSION: ["3.8", "3.12"] + fail-fast: false + steps: + - name: Check out the repository under ${GITHUB_WORKSPACE} + uses: actions/checkout@v4 + with: + sparse-checkout: .github/triage + - name: "Setup Python ${{ matrix.PYTHON_VERSION}}" + uses: actions/setup-python@v5 + with: + python-version: '${{ matrix.PYTHON_VERSION }}' + - name: "Install jax-toolbox-triage" + run: pip install pytest .github/triage + - name: "Run tests" + shell: bash -x -e {0} + run: | + pytest .github/triage/tests + ruff: + runs-on: ubuntu-24.04 + steps: + - name: Check out the repository under ${GITHUB_WORKSPACE} + uses: actions/checkout@v4 + with: + sparse-checkout: .github/triage + - name: "Setup Python 3.12" + uses: actions/setup-python@v5 + with: + python-version: '3.12' + - name: "Install ruff" + run: pip install ruff + - name: "Run ruff checks" + shell: bash -x {0} + run: | + ruff check .github/triage + check_status=$? + ruff format --check .github/triage + format_status=$? + if [[ $format_status != 0 || $check_status != 0 ]]; then + exit 1 + fi diff --git a/docs/triage.md b/docs/triage.md index 27e10c617..9ec5eafe5 100644 --- a/docs/triage.md +++ b/docs/triage.md @@ -4,6 +4,12 @@ There is a Github Action Workflow called [_triage.yaml](../.github/workflows/_tr be used to help determine if a test failure was due to a change in (t5x or pax) or further-up, e.g., in (Jax or CUDA). This workflow is not the end-all, and further investigation is usually needed, but this automates the investigation of questions like "what state of library X works with Jax at state Y?" +__Note__: There is also a utility, [triage](../.github/triage/triage), which can be +used for more granular bisection of failures in specific tests. Run it with `--help` +for usage instructions. Given a test expression that can be run inside the nightly +containers (*e.g.* `test-jax.sh jet_test_gpu`), it first identifies the nightly +container where the failure first appeared, and second attributes the failure to a +specific commit of JAX or XLA. ## Algorithm The pseudocode for the triaging algorithm is as follows: