diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base index 7e037ea92..023576cb5 100644 --- a/.github/container/Dockerfile.base +++ b/.github/container/Dockerfile.base @@ -139,6 +139,7 @@ COPY --from=tcpx-installer /var/lib/tcpx/lib64 ${TCPX_LIBRARY_PATH} ############################################################################### ADD install-nsight.sh /usr/local/bin +ADD nsys-2024.5-tid-export.patch /opt/nvidia RUN install-nsight.sh ############################################################################### @@ -180,6 +181,7 @@ ENV PATH=/opt/amazon/efa/bin:${PATH} ADD install-nccl-sanity-check.sh /usr/local/bin ADD nccl-sanity-check.cu /opt RUN install-nccl-sanity-check.sh +ADD jax-nccl-test parallel-launch /usr/local/bin ############################################################################### ## Add the systemcheck to the entrypoint. @@ -203,7 +205,7 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/ ADD nsys-jax nsys-jax-combine /usr/local/bin/ ADD jax_nsys/ /opt/jax_nsys -ADD requirements-nsys-jax.in /opt/pip-tools.d/ +RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/ ############################################################################### diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index c85bee347..726656a7a 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -62,6 +62,7 @@ pip install ninja && rm -rf ~/.cache/pip # TransformerEngine now needs JAX at build time git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE} pushd ${SRC_PATH_TRANSFORMER_ENGINE} +export NVTE_BUILD_THREADS_PER_JOB=8 python setup.py bdist_wheel && rm -rf build ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist" EOF diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index fa4c055b8..8ff65ca99 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -316,6 +316,9 @@ pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUIL # jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib pip list | grep jax +# Ensure directories are readable by all for non-root users +chmod 755 $BUILD_PATH_JAXLIB/* + ## Cleanup pushd $SRC_PATH_JAX diff --git a/.github/container/install-nsight.sh b/.github/container/install-nsight.sh index c0207c6b7..73aee4163 100755 --- a/.github/container/install-nsight.sh +++ b/.github/container/install-nsight.sh @@ -12,24 +12,15 @@ export DEBIAN_FRONTEND=noninteractive export TZ=America/Los_Angeles apt-get update -# TODO: revert to nsight-systems-cli instead of explicitly pinning -apt-get install -y nsight-compute nsight-systems-cli-2024.4.1 +apt-get install -y nsight-compute nsight-systems-cli apt-get clean rm -rf /var/lib/apt/lists/* -# "Wrong event order has been detected when adding events to the collection" -# workaround during nsys report post-processing with 2024.1.1 and CUDA 12.3 -NSYS202411=/opt/nvidia/nsight-systems-cli/2024.1.1 -if [[ "${UBUNTU_ARCH}" == "amd64" && -d "${NSYS202411}" ]]; then - LIBCUPTI123=/opt/nvidia/nsight-compute/2023.3.0/host/target-linux-x64/libcupti.so.12.3 - if [[ ! -f "${LIBCUPTI123}" ]]; then - echo "2024.1.1 workaround expects to be running inside CUDA 12.3 container" - exit 1 - fi - # Use libcupti.so.12.3 because this is a CUDA 12.3 container - ln -s "${LIBCUPTI123}" "${NSYS202411}/target-linux-x64/libcupti.so.12.3" - mv "${NSYS202411}/target-linux-x64/libcupti.so.12.4" "${NSYS202411}/target-linux-x64/_libcupti.so.12.4" +NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1 +if [[ -d "${NSYS202451}" ]]; then + # * can match at least sbsa-armv8 and x86 + (cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch) fi # Install extra dependencies needed for `nsys recipe ...` commands. These are diff --git a/.github/container/jax-nccl-test b/.github/container/jax-nccl-test new file mode 100755 index 000000000..706713baf --- /dev/null +++ b/.github/container/jax-nccl-test @@ -0,0 +1,253 @@ +#!/usr/bin/env python +import argparse +from ctypes import byref, cdll, c_int, POINTER +from functools import partial +import jax +from jax.experimental.multihost_utils import sync_global_devices +from jax.experimental.shard_map import shard_map +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import os +import time + + +libcudart = cdll.LoadLibrary("libcudart.so") +cudaGetDeviceCount = libcudart.cudaGetDeviceCount +cudaGetDeviceCount.argtypes = [POINTER(c_int)] +cudaGetDeviceCount.restype = c_int +cudaProfilerStart = libcudart.cudaProfilerStart +cudaProfilerStop = libcudart.cudaProfilerStop + + +def visible_device_count() -> int: + """ + Query the number of local devices visible to this process. + """ + count = c_int() + assert cudaGetDeviceCount(byref(count)) == 0 + return count.value + + +def int_or_env(value) -> int: + try: + return int(value) + except ValueError: + return int(os.environ[value]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Pure-JAX implementation of a NCCL performance test" + ) + parser.add_argument( + "--coordinator-address", + help="Distributed coordinator address:port; used if --distributed is passed.", + ) + parser.add_argument( + "--distributed", + action="store_true", + help="Run jax.distributed.initialize()", + ) + parser.add_argument( + "--gpus-per-process", + help=( + "Number of GPUs driven by each controller process. " + "Defaults to 1 with --distributed and all of them otherwise." + ), + type=int, + ) + parser.add_argument( + "--process-count", + help=( + "When --distributed is passed this gives the total number of processes. " + "This can either be an integer of the name of an environment variable." + ), + type=int_or_env, + ) + parser.add_argument( + "--process-id", + help=( + "When --distributed is passed this gives the global index of this process." + "This can either be an integer or the name of an environment variable." + ), + type=int_or_env, + ) + args = parser.parse_args() + + assert ( + args.process_id is None or args.distributed + ), "--process-id is only relevant with --distributed" + if args.distributed: + null_args = { + args.coordinator_address is None, + args.gpus_per_process is None, + args.process_count is None, + args.process_id is None, + } + if all(null_args): + # Use default behaviour + jax.distributed.initialize() + else: + assert not any(null_args), ( + "All of --coordinator-address, --gpus-per-process, --process-count and " + "--process-id must be passed if any of them are." + ) + visible_devices = visible_device_count() + local_processes, rem = divmod(visible_devices, args.gpus_per_process) + assert rem == 0, ( + f"--gpus-per-process={args.gpus_per_process} does not divide the " + "visible device count {visible_devices}" + ) + # assume processes within a node are globally numbered contiguously + local_process_id = args.process_id % local_processes + first_local_device = local_process_id * args.gpus_per_process + local_device_ids = list( + range(first_local_device, first_local_device + args.gpus_per_process) + ) + print( + f"Rank {args.process_id} has local rank {local_process_id} and " + f"devices {local_device_ids} from a total of {visible_devices} " + f"visible on this node, {args.process_count} processes and " + f"{args.process_count*args.gpus_per_process} total devices.", + flush=True, + ) + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + local_device_ids=local_device_ids, + num_processes=args.process_count, + process_id=args.process_id, + ) + elif args.gpus_per_process is not None: + # Respect --gpus-per-process even without --distributed + jax.config.update( + "jax_cuda_visible_devices", + ",".join(str(x) for x in range(args.gpus_per_process)), + ) + + if jax.process_index() == 0: + print(f"JAX devices: {jax.devices()}") + n_devices = jax.device_count() + assert ( + args.gpus_per_process is None + or jax.local_device_count() == args.gpus_per_process + ), ( + f"Got {jax.local_device_count()} local devices despite " + f"--gpus-per-process={args.gpus_per_process}" + ) + mesh = Mesh(jax.devices(), axis_names=("i",)) + min_size_power = 0 + max_size_power = 30 + max_elements = 2**32 + sharding = partial( + shard_map, + mesh=mesh, + in_specs=(P("i"), P("i", None), None), + check_rep=False, + out_specs=P("i"), + ) + + @partial(jax.jit, static_argnames="collective") + @sharding + def measure_collective(sync, big_input, collective): + with jax.named_scope(collective): + output = 1.0 + big_input = big_input * jax.lax.psum(sync, "i") + assert big_input.shape == (1, 2**max_size_power), big_input.shape + for size in range(max_size_power + 1): + values_per_device = 2**size + input = output * jax.lax.slice( + big_input, (0, 0), (1, values_per_device) + ) + assert input.shape == (1, values_per_device), input.shape + result = None + # Trigger the collective we want to measure + if collective == "all_gather": + if input.size * n_devices < max_elements: + result = jax.lax.all_gather(input, "i") + assert result.shape == (n_devices, *input.shape), result.shape + elif collective == "all_reduce": + if input.size < max_elements: + result = jax.lax.psum(input, "i") + assert result.shape == (1, values_per_device), result.shape + elif collective == "broadcast": + if input.size < max_elements: + # FIXME: need https://github.com/google/jax/pull/20705 re-land + result = jax.lax.pbroadcast(input, "i", 0) + assert result.shape == (1, values_per_device), result.shape + elif collective == "permute": + if input.size < max_elements: + # TODO: make this sensitive to whether the permutation does or + # does not cross NVLink domain boundaries + permutation = [ + (i, (i + 1) % n_devices) for i in range(n_devices) + ] + result = jax.lax.ppermute(input, "i", permutation) + assert result.shape == (1, values_per_device), result.shape + else: + assert collective == "reduce_scatter", collective + if values_per_device >= n_devices: + # Need to be able to scatter at least 1 value of the result on + # each device. This results in the largest message size (NCCL + # convention) for reduce-scatter being a factor `n_devices` + # smaller than the other collectives + result = jax.lax.psum_scatter( + input, "i", scatter_dimension=1, tiled=True + ) + assert result.shape == ( + 1, + values_per_device // n_devices, + ), result.shape + # Do something with the results to stop them getting combined/removed + if result is not None: + output *= 1.5 + jnp.tanh(jnp.mean(result)) # scale by [0.5, 1.5] + return jnp.array([output]) + + def measure(sync, input, host_timer=False): + for op in ["all_gather", "all_reduce", "permute", "reduce_scatter"]: + start = time.time() + result = measure_collective(sync, input, op) + if host_timer: + result.block_until_ready() + if jax.process_index() == 0: + print(f"First {op} duration {time.time()-start:.2f}s") + return result + + def device_put_local(x: jax.Array): + return [jax.device_put(x, d) for d in jax.local_devices()] + + # This helper is used to trigger a small barrier before the main measurement, again + # to improve measurement quality. It's always the same and is sharded with one + # value per device. + sync = jax.make_array_from_single_device_arrays( + (n_devices,), + NamedSharding(mesh, P("i")), + device_put_local(jnp.ones((1,))), + ) + input = jax.make_array_from_single_device_arrays( + (n_devices, 2**max_size_power), + NamedSharding(mesh, P("i")), + device_put_local(jax.random.normal(jax.random.key(1), (1, 2**max_size_power))), + ) + if jax.process_index() == 0: + print(f"Data for pre-measurement synchronisation {sync.shape}") + jax.debug.visualize_array_sharding(sync) + print(f"Data for collective measurements {input.shape}") + jax.debug.visualize_array_sharding(input) + + start = time.time() + sync_global_devices("init") + sync_time = time.time() - start + if jax.process_index() == 0: + print(f"Barrier time (NCCL init): {sync_time:.2f}s") + + measure(sync, input, host_timer=True) + sync_global_devices("warmup_done") + cudaProfilerStart() + sync_global_devices("profiling_started") + for _ in range(10): + measure(sync, input) + sync_global_devices("measurements_completed") + cudaProfilerStop() + sync_global_devices("profiling_ended") + if jax.process_index() == 0: + print("Exiting...") diff --git a/.github/container/jax_nsys/Analysis.ipynb b/.github/container/jax_nsys/Analysis.ipynb index 82e32263d..3224c940b 100644 --- a/.github/container/jax_nsys/Analysis.ipynb +++ b/.github/container/jax_nsys/Analysis.ipynb @@ -649,7 +649,10 @@ "source": [ "if len(steady_state.communication):\n", " fig, grid = plt.subplots(\n", - " nrows=len(top_module_ids), figsize=[15, 5], squeeze=False, tight_layout=True\n", + " nrows=len(top_module_ids),\n", + " figsize=[15, 5 * len(top_module_ids)],\n", + " squeeze=False,\n", + " tight_layout=True,\n", " )\n", " time_df = steady_state.thunk.loc[\n", " ~steady_state.thunk[\"Communication\"], (\"ProjStartMs\", \"ProjDurMs\")\n", diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py index 8af610425..9e3aaee4f 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py @@ -58,6 +58,8 @@ def align_profiler_data_timestamps( # Apply these corrections to the device-side timestamps for k in ["communication", "module", "thunk"]: df = getattr(frames, k) + if df is None: + continue df["ProjStartMs"] -= median_device_skews setattr(frames, k, df) return frames, { diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py index 5180366bc..6c25cb2ee 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py @@ -131,10 +131,13 @@ def _load_nvtx_gpu_proj_trace_single( "Children Count": "NumChild", "Range ID": "RangeId", "Parent ID": "ParentId", + "PID": "PID", "Range Stack": "RangeStack", "Stack Level": "Lvl", + "TID": "TID", } if set(df.columns) == alt_rename_map.keys(): + tsl_prefix = "" df = df.rename( columns={k: v for k, v in alt_rename_map.items() if v is not None} ) @@ -145,12 +148,25 @@ def _load_nvtx_gpu_proj_trace_single( ) # TODO: add OrigDurMs, OrigStartMs else: + tsl_prefix = "TSL:" + df = df.drop(columns=["Style"]) df["OrigDurMs"] = 1e-6 * df.pop("Orig Duration") df["OrigStartMs"] = 1e-6 * df.pop("Orig Start") df["ProjDurMs"] = 1e-6 * df.pop("Projected Duration") df["ProjStartMs"] = 1e-6 * df.pop("Projected Start") df = df.dropna(subset=["RangeId"]) - df = df.set_index(df.pop("RangeId").astype(np.int32), verify_integrity=True) + try: + df = df.set_index(df.pop("RangeId").astype(np.int32), verify_integrity=True) + except ValueError: + print( + "A duplicate key related error may indicate that you are using " + "Nsight Systems 2024.5 and have CUDA graphs enabled; as noted on " + "https://github.com/NVIDIA/JAX-Toolbox/blob/main/docs/profiling.md " + "you may want to disable CUDA graphs by adding " + "--xla_gpu_enable_command_buffer= to the XLA_FLAGS environment " + "variable." + ) + raise # Due to idiosyncracies of how Nsight tracks CUDA graphs, and because # thunks can be nested, the NVTX hierarchy generally looks like: # Iteration -> XlaModule:A [-> XlaModule:B] -> Thunk:C [-> Thunk:D ...] @@ -160,16 +176,16 @@ def _load_nvtx_gpu_proj_trace_single( # Get all of the Thunks in the profile. Note that we want to discard some # of these, like Thunk:C in the example above, for not being the most # deeply nested. - thunk_prefix = "TSL:Thunk:#" + thunk_prefix = f"{tsl_prefix}Thunk:#" all_thunks = df["Name"].str.startswith(thunk_prefix) # If profile collection started while an XlaModule was executing, there may # be Thunk ranges without XlaModule parents. We treat those as edge effects # and ignore them. - module_prefix = "TSL:XlaModule:" + module_prefix = f"{tsl_prefix}XlaModule:" all_modules = df["Name"].str.startswith(module_prefix) - first_module_orig_time = df.loc[all_modules, "OrigStartMs"].min() - thunks_without_modules = all_thunks & (df["OrigStartMs"] < first_module_orig_time) + first_module_start_time = df.loc[all_modules, "ProjStartMs"].min() + thunks_without_modules = all_thunks & (df["ProjStartMs"] < first_module_start_time) if thunks_without_modules.sum(): print(f"Ignoring {thunks_without_modules.sum()} thunks without modules") all_thunks &= ~thunks_without_modules @@ -221,20 +237,23 @@ def _load_nvtx_gpu_proj_trace_single( # XlaModule calculate the mean and standard deviation of the number of GPU # operations in all but the last occurence, and see if the last occurence # is an outlier. TODO: if we processed the SQLite database directly, we - # would know if the current XlaModule range had actually been closed. - for mod_name, mod_name_df in df.loc[mod_ids, :].groupby("Name"): - gpu_ops = mod_name_df["NumGPUOps"].array - not_last, last = gpu_ops[:-1], gpu_ops[-1] - if last < np.mean(not_last) - np.std(not_last): - print( - "Skipping last occurence of {} because it only had {} GPU operations, compared to {} +/- {} before".format( - mod_name, last, np.mean(not_last), np.std(not_last) + # would know if the current XlaModule range had actually been closed. TODO: + # provide an implementation that works with the 2024.5 output format. + if "NumGPUOps" in df.columns: + for mod_name, mod_name_df in df.loc[mod_ids, :].groupby("Name"): + gpu_ops = mod_name_df["NumGPUOps"].array + not_last, last = gpu_ops[:-1], gpu_ops[-1] + if last < np.mean(not_last) - np.std(not_last): + print( + "Skipping last occurence of {} because it only had {} GPU operations, compared to {} +/- {} before".format( + mod_name, last, np.mean(not_last), np.std(not_last) + ) ) - ) - mod_id = mod_name_df.index[-1] - mod_ids.remove(mod_id) - # Also remove its thunks from all_thunks - all_thunks &= df["ModuleId"] != mod_id + mod_id = mod_name_df.index[-1] + mod_ids.remove(mod_id) + # Also remove its thunks from all_thunks + all_thunks &= df["ModuleId"] != mod_id + df = df.drop(columns=["NumGPUOps"]) # Parse the numerical program ID out of the name of each XlaModule. # program_id is not set in all cases, although this could be fixed in XLA. @@ -243,7 +262,11 @@ def _load_nvtx_gpu_proj_trace_single( # propagated to the GpuExecutable that emits the XlaModule annotation. # Those are probably not interesting, so setting the ProgramId to -1 in # such cases is acceptable. - module_re = r"^TSL:XlaModule:#(?:prefix=(.*?),|)hlo_module=([a-z0-9._-]+)(?:,program_id=(\d+)|)#$" + module_re = ( + "^" + + tsl_prefix + + r"XlaModule:#(?:prefix=(.*?),|)hlo_module=([a-z0-9._-]+)(?:,program_id=(\d+)|)#$" + ) mod_program_ids = ( df.loc[mod_ids, "Name"] .str.replace( @@ -307,11 +330,9 @@ def clean_data_frame(d): "Lvl", "ModuleId", "NumChild", - "NumGPUOps", "ParentId", "PID", "RangeStack", - "Style", "TID", ] ).astype({"ProgramExecution": np.int32, "ProgramId": np.int32}) @@ -321,7 +342,7 @@ def clean_data_frame(d): # At this point there should be no need to look beyond the rows for individual thunks + the protobuf data, and we can further clean up the data thunk_df = clean_data_frame(df[all_thunks]) thunk_df["Name"] = thunk_df["Name"].replace( - to_replace="^TSL:Thunk:#(?:name=(.*?),|)hlo_op=([a-z0-9._-]+)#$", + to_replace=f"^{tsl_prefix}Thunk:#(?:name=(.*?),|)hlo_op=([a-z0-9._-]+)#$", value=r"\2", regex=True, ) diff --git a/.github/container/jax_nsys/python/jax_nsys/pyproject.toml b/.github/container/jax_nsys/python/jax_nsys/pyproject.toml index cc3f3981a..4c5ca9600 100644 --- a/.github/container/jax_nsys/python/jax_nsys/pyproject.toml +++ b/.github/container/jax_nsys/python/jax_nsys/pyproject.toml @@ -5,6 +5,9 @@ dependencies = [ "ipython", "numpy", "pandas", + "protobuf", # a compatible version of protoc needs to be installed out-of-band "pyarrow", + "requests", # for install-protoc + "uncertainties", # communication analysis recipe ] requires-python = ">= 3.10" diff --git a/.github/container/jax_nsys/python/jax_nsys_analysis/communication.py b/.github/container/jax_nsys/python/jax_nsys_analysis/communication.py new file mode 100644 index 000000000..ef02f5c1b --- /dev/null +++ b/.github/container/jax_nsys/python/jax_nsys_analysis/communication.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python +import argparse +from collections import defaultdict +from jax_nsys import ( + align_profiler_data_timestamps, + apply_warmup_heuristics, + ensure_compiled_protos_are_importable, + load_profiler_data, +) +from math import sqrt +import pathlib +from uncertainties import ufloat # type: ignore + + +def main(): + parser = argparse.ArgumentParser( + description="Summarise communication in an nsys-jax report" + ) + parser.add_argument("prefix", type=pathlib.Path) + args = parser.parse_args() + # Make sure that the .proto files under protos/ have been compiled to .py, and + # that those generated .py files are importable. + ensure_compiled_protos_are_importable(prefix=args.prefix) + # Load the profiler data; the compilation part is needed for the warmup heuristics + all_data = load_profiler_data(args.prefix, frames={"communication", "compile"}) + # Align timestamps + all_data, alignment_metadata = align_profiler_data_timestamps(all_data) + # TODO: make this pretty + # print(alignment_metadata) + # Partition the profile data into initialisation and steady-state running + _, steady_state = apply_warmup_heuristics(all_data) + assert len(steady_state.communication), ( + "Communication summary was requested but no steady-state communication was " + "identified." + ) + collective_types = set() + summary_data = defaultdict(dict) + for (collective, message_size), df in steady_state.communication.groupby( + ["Collective", "MessageSize"] + ): + collective_types.add(collective) + # This grouped data frame will have a row for each device that is participating + # in this instance of the collective. + devices = df.groupby(["ProgramId", "ProgramExecution", "ThunkIndex"]) + # Take the fastest device bandwidth. Rationale: the slower devices appear + # slower because they spend some time waiting for the last device, and then all + # devices complete the collective at the same time. The fastest device is + # therefore the last one to join the collective and its bandwidth estimate does + # not contain a wait time component. The .mean() is over the different + # (ProgramId, ProgramExecution, ThunkIndex) values. + bandwidth = devices["BusBandwidthGBPerSec"].agg("max") + summary_data[message_size][collective] = ufloat( + bandwidth.mean(), bandwidth.std() / sqrt(len(bandwidth)) + ) + collective_types = sorted(collective_types) + collective_widths = { + collective: max( + len(collective), + max( + len(f"{data[collective]:S}") + for data in summary_data.values() + if collective in data + ), + ) + for collective in collective_types + } + size_heading = "Size [B]" + size_width = max(len(size_heading), max(len(f"{s:,}") for s in summary_data.keys())) + print(f"{'':<{size_width}} | Bus bandwidth [GB/s]") + print( + " | ".join( + [f"{size_heading:<{size_width}}"] + + [f"{coll:<{collective_widths[coll]}}" for coll in collective_types] + ) + ) + + def format_message_size(message_size): + return f"{message_size:<{size_width},}" + + def format_bandwidth(data, collective): + width = collective_widths[collective] + if collective not in data: + return "-" * width + return f"{data[collective]:>{width}S}" + + for message_size in sorted(summary_data.keys()): + data = summary_data[message_size] + print( + " | ".join( + [format_message_size(message_size)] + + [ + format_bandwidth(data, collective) + for collective in collective_types + ] + ) + ) + + +if __name__ == "__main__": + main() diff --git a/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py b/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py index e7168998b..978c041fa 100755 --- a/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py +++ b/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py @@ -50,48 +50,48 @@ def dump(fname, df): df.to_json(ofile, orient="split") dump("module-stats", module_stats) + print(f" === MODULE EXECUTION SUMMARY ===\n{module_stats}") compilation_stats = generate_compilation_statistics(init.compile) - total_compile_time = compilation_stats["DurNonChildMs"].sum() - compilation_stats["DurNonChildPercent"] = ( - 100 * compilation_stats["DurNonChildMs"] / total_compile_time - ) - # Dump before dropping - dump("compilation-ranges", compilation_stats) - compilation_stats = compilation_stats.drop(columns=["DurChildMs"]) - top_n = 10 - top_n_ranges = compilation_stats.iloc[:top_n] - - # All XlaPass ranges combined into a single XlaPasses range, XlaPassPipeline ranges ignored - def remove_xlapass_xlapasspipeline_detail(name): - if name.startswith("XlaPass:#"): - return "XlaPasses" - elif name.startswith("XlaPassPipeline:#"): - return "XlaPassPipelines" - else: - return name + if len(compilation_stats): + total_compile_time = compilation_stats["DurNonChildMs"].sum() + compilation_stats["DurNonChildPercent"] = ( + 100 * compilation_stats["DurNonChildMs"] / total_compile_time + ) + # Dump before dropping + dump("compilation-ranges", compilation_stats) + compilation_stats = compilation_stats.drop(columns=["DurChildMs"]) + top_n = 10 + top_n_ranges = compilation_stats.iloc[:top_n] - no_pass_detail = ( - compilation_stats.groupby(remove_xlapass_xlapasspipeline_detail) - .agg("sum") - .sort_values("DurNonChildMs", ascending=False) - ) - dump("compilation-high-level", no_pass_detail) - # Top few passes, with the percentages re-scaled to be relative to XlaPasses above - pass_df = compilation_stats[ - compilation_stats.index.to_series().str.startswith("XlaPass:#") - ] - pass_df["DurNonChildPercent"] = ( - 100 * pass_df["DurNonChildMs"] / pass_df["DurNonChildMs"].sum() - ) - dump("compilation-passes", pass_df) + # All XlaPass ranges combined into a single XlaPasses range, XlaPassPipeline ranges ignored + def remove_xlapass_xlapasspipeline_detail(name): + if name.startswith("XlaPass:#"): + return "XlaPasses" + elif name.startswith("XlaPassPipeline:#"): + return "XlaPassPipelines" + else: + return name - print(f" === MODULE EXECUTION SUMMARY ===\n{module_stats}") - print(f" === COMPILATION TIME -- TOP {top_n} RANGES ===\n{top_n_ranges}") - print(f" === COMPILATION TIME -- NO PASS DETAIL ===\n{no_pass_detail}") - print( - f" === COMPILATION TIME -- TOP {top_n} XLA PASSES ===\n{pass_df.iloc[:top_n]}" - ) + no_pass_detail = ( + compilation_stats.groupby(remove_xlapass_xlapasspipeline_detail) + .agg("sum") + .sort_values("DurNonChildMs", ascending=False) + ) + dump("compilation-high-level", no_pass_detail) + # Top few passes, with the percentages re-scaled to be relative to XlaPasses above + pass_df = compilation_stats[ + compilation_stats.index.to_series().str.startswith("XlaPass:#") + ] + pass_df["DurNonChildPercent"] = ( + 100 * pass_df["DurNonChildMs"] / pass_df["DurNonChildMs"].sum() + ) + dump("compilation-passes", pass_df) + print(f" === COMPILATION TIME -- TOP {top_n} RANGES ===\n{top_n_ranges}") + print(f" === COMPILATION TIME -- NO PASS DETAIL ===\n{no_pass_detail}") + print( + f" === COMPILATION TIME -- TOP {top_n} XLA PASSES ===\n{pass_df.iloc[:top_n]}" + ) if __name__ == "__main__": diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 419d3bfd0..60ef1a001 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -49,8 +49,7 @@ praxis: patches: pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas. pull/36/head: file://patches/praxis/PR-36.patch # adds Transformer Engine support - pull/74/head: file://patches/praxis/PR-74.patch # experimental support for using TE FMHA in GQA - pull/83/head: file://patches/praxis/PR-83.patch # Fix unindex tuple error for PP + pull/84/head: file://patches/praxis/PR-84.patch # experimental support for using TE FMHA in GQA lingvo: # Used only in ARM pax builds url: https://github.com/tensorflow/lingvo.git diff --git a/.github/container/nsys-2024.5-tid-export.patch b/.github/container/nsys-2024.5-tid-export.patch new file mode 100644 index 000000000..f19d35e27 --- /dev/null +++ b/.github/container/nsys-2024.5-tid-export.patch @@ -0,0 +1,24 @@ +diff --git a/nsys_recipe/lib/nvtx.py b/nsys_recipe/lib/nvtx.py +index 2470043..7abf892 100644 +--- a/nsys_recipe/lib/nvtx.py ++++ b/nsys_recipe/lib/nvtx.py +@@ -161,6 +161,7 @@ def _compute_gpu_projection_df( + "start": list(nvtx_gpu_start_dict.values()) + starts, + "end": list(nvtx_gpu_end_dict.values()) + ends, + "pid": nvtx_df.loc[list(nvtx_gpu_end_dict.keys()) + indices, "pid"], ++ "tid": nvtx_df.loc[list(nvtx_gpu_end_dict.keys()) + indices, "tid"], + } + ) + +diff --git a/nsys_recipe/recipes/nvtx_gpu_proj_trace/nvtx_gpu_proj_trace.py b/nsys_recipe/recipes/nvtx_gpu_proj_trace/nvtx_gpu_proj_trace.py +index cd60bf4..37e0d0d 100644 +--- a/nsys_recipe/recipes/nvtx_gpu_proj_trace/nvtx_gpu_proj_trace.py ++++ b/nsys_recipe/recipes/nvtx_gpu_proj_trace/nvtx_gpu_proj_trace.py +@@ -96,6 +96,7 @@ class NvtxGpuProjTrace(recipe.Recipe): + "start": "Start", + "end": "End", + "pid": "PID", ++ "tid": "TID", + "stackLevel": "Stack Level", + "childrenCount": "Children Count", + "rangeId": "Range ID", diff --git a/.github/container/nsys-jax b/.github/container/nsys-jax index 4b25debe2..104732306 100755 --- a/.github/container/nsys-jax +++ b/.github/container/nsys-jax @@ -16,7 +16,6 @@ import subprocess import sys import tempfile import time -import virtualenv # type: ignore import zipfile @@ -362,12 +361,11 @@ def compress_and_archive(prefix, file, output_queue): output_queue.put((file + ".xz", lzma.compress(ifile.read()), compress_none)) -def run_nsys_stats_report(report, report_file, tmp_dir, output_queue, wait_on): +def run_nsys_stats_report(report, report_file, tmp_dir, output_queue): """ Run a stats recipe on an .nsys-rep file (that has probably already been exported to .sqlite). """ - wait_on.result() start = time.time() subprocess.run( [ @@ -377,8 +375,11 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue, wait_on): report, "--input", report_file, + # avoid race conditions with other reports/etc. + "--sqlite", + osp.splitext(report_file)[0] + "-" + report + ".sqlite", "--output", - ".", + osp.join(tmp_dir, "report"), ] + (["--force-overwrite"] if nsys_force_overwrite else []), check=True, @@ -388,15 +389,29 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue, wait_on): print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s") -def save_device_stream_thread_names(tmp_dir, report, output_queue, wait_on): +def save_device_stream_thread_names(tmp_dir, report, output_queue): """ Extract extra information from the SQLite dump that is needed to map projected NVTX ranges to global device IDs. """ - wait_on.result() start = time.time() assert report.endswith(".nsys-rep"), f"{report} had an unexpected suffix" - db_file = report.removesuffix(".nsys-rep") + ".sqlite" + db_file = report.removesuffix(".nsys-rep") + "-metadata.sqlite" + subprocess.run( + [ + "nsys", + "export", + "--type", + "sqlite", + "--tables", + "StringIds,TARGET_INFO_GPU,TARGET_INFO_NVTX_CUDA_DEVICE,TARGET_INFO_SYSTEM_ENV,ThreadNames", + "--output", + db_file, + report, + ], + check=True, + ) + assert os.path.exists(db_file) con = sqlite3.connect(db_file) cur = con.cursor() @@ -528,18 +543,6 @@ def execute_analysis_scripts(mirror_dir, analysis_scripts): return [], 0 assert mirror_dir is not None - venv_dir = osp.join(mirror_dir, "venv") - virtualenv.cli_run([venv_dir, "--python", sys.executable, "--system-site-packages"]) - subprocess.run( - [ - osp.join(venv_dir, "bin", "pip"), - "--disable-pip-version-check", - "install", - "-e", - osp.join(mirror_dir, "python", "jax_nsys"), - ], - check=True, - ) output = [] exit_code = 0 used_slugs = set() @@ -558,7 +561,7 @@ def execute_analysis_scripts(mirror_dir, analysis_scripts): candidates = list(filter(osp.exists, search)) assert len(candidates), f"Could not find analysis script, tried {search}" args.append(mirror_dir) - analysis_command = [osp.join(venv_dir, "bin", "python"), candidates[0]] + args + analysis_command = [sys.executable, candidates[0]] + args # Derive a unique name slug from the analysis script name slug = osp.basename(candidates[0]).removesuffix(".py") n, suffix = 1, "" @@ -736,14 +739,15 @@ with ThreadPoolExecutor() as executor, output_thread(executor): ) ) # Convert .nsys-rep -> .parquet and queue the latter for archival - gpu_proj_future = executor.submit( - run_nsys_recipe, - "nvtx_gpu_proj_trace", - tmp_rep, - tmp_dir, - files_to_archive, + futures.append( + executor.submit( + run_nsys_recipe, + "nvtx_gpu_proj_trace", + tmp_rep, + tmp_dir, + files_to_archive, + ) ) - futures.append(gpu_proj_future) # Copy /opt/jax_nsys into the archive futures.append( executor.submit(copy_jax_nsys_files, "/opt/jax_nsys", files_to_archive) @@ -764,8 +768,6 @@ with ThreadPoolExecutor() as executor, output_thread(executor): futures, ) ) - # Don't run this in parallel with gpu_proj_future because the two recipes - # implicitly create the same .sqlite export on demand. futures.append( executor.submit( run_nsys_stats_report, @@ -773,7 +775,6 @@ with ThreadPoolExecutor() as executor, output_thread(executor): tmp_rep, tmp_dir, files_to_archive, - gpu_proj_future, # for dependency purposes only ) ) # Do some custom post-processing of the .sqlite export generated by gpu_proj_future @@ -783,7 +784,6 @@ with ThreadPoolExecutor() as executor, output_thread(executor): tmp_dir, tmp_rep, files_to_archive, - gpu_proj_future, # for dependency purposes only ) ) # Wait for errors/completion of `futures`; note that this does not include diff --git a/.github/container/nsys-jax-combine b/.github/container/nsys-jax-combine index 911866f52..00e53efb6 100755 --- a/.github/container/nsys-jax-combine +++ b/.github/container/nsys-jax-combine @@ -4,7 +4,11 @@ from collections import defaultdict import copy import os import pathlib +import shlex import shutil +import subprocess +import sys +import tempfile import zipfile parser = argparse.ArgumentParser( @@ -15,6 +19,45 @@ parser = argparse.ArgumentParser( "of an application, checking consistency and removing duplicated data." ), ) +parser.add_argument( + "--analysis", + action="append", + help=( + "Post-processing analysis script to execute after merging. This can be the " + "name of a recipe bundled in the inpit files, or the path to a Python script. " + "The script will be passed any arguments specified via --analysis-arg, " + "followed by a single positional argument, which is the path to a directory " + "of the same structure as the extracted output archive." + ), + type=lambda x: ("script", x), +) +parser.add_argument( + "--analysis-arg", + action="append", + dest="analysis", + help="Extra arguments to pass to analysis scripts specified via --analysis", + type=lambda x: ("arg", x), +) + + +def shuffle_analysis_arg(analysis): + if analysis is None: + return [] + # [Script(A), Arg(A1), Arg(A2), Script(B), Arg(B1)] becomes [[A, A1, A2], [B, B1]] + out, current = [], [] + for t, x in analysis: + if t == "script": + if len(current): + out.append(current) + current = [x] + else: + assert t == "arg" and len(current) + current.append(x) + if len(current): + out.append(current) + return out + + parser.add_argument( "-f", "--force-overwrite", @@ -52,6 +95,7 @@ parser.add_argument( ) # TODO: derive a default output path from the input paths args = parser.parse_args() +args.analysis = shuffle_analysis_arg(args.analysis) if args.output.suffix != ".zip": args.output = args.output.with_suffix(".zip") if os.path.exists(args.output) and not args.force_overwrite: @@ -65,6 +109,7 @@ for input in args.input: for member in ifile.infolist(): hashes[member.filename].add(member.CRC) +mirror_dir = pathlib.Path(tempfile.mkdtemp()) if len(args.analysis) else None with zipfile.ZipFile(args.output, "w") as ofile: for n_input, input in enumerate(args.input): first_input = n_input == 0 @@ -81,8 +126,15 @@ with zipfile.ZipFile(args.output, "w") as ofile: def write(dst_info): assert dst_info.filename not in set(ofile.namelist()) - with ifile.open(member) as src, ofile.open(dst_info, "w") as dst: - shutil.copyfileobj(src, dst) + with ifile.open(member) as src: + with ofile.open(dst_info, "w") as dst: + shutil.copyfileobj(src, dst) + if mirror_dir is not None: + dst_path = mirror_dir / dst_info.filename + os.makedirs(dst_path.parent, exist_ok=True) + src.seek(0) + with open(dst_path, "wb") as dst: + shutil.copyfileobj(src, dst) if filename.endswith(".nsys-rep"): assert len(seen_hashes) == 1 @@ -104,3 +156,44 @@ with zipfile.ZipFile(args.output, "w") as ofile: dst_info = copy.copy(member) dst_info.filename = filename + "/" + input.stem write(dst_info) + if len(args.analysis): + assert mirror_dir is not None + used_slugs = set() + for analysis in args.analysis: + # Execute post-processing recipes and add any outputs to `ofile` + script, script_args = analysis[0], analysis[1:] + # If --analysis is the name of a bundled analysis script, use that. Otherwise it should be a file that exists. + search = [ + mirror_dir / "python" / "jax_nsys_analysis" / (script + ".py"), + pathlib.Path(script), + ] + candidates = list(filter(lambda p: p.exists(), search)) + assert len(candidates), f"Could not find analysis script, tried {search}" + analysis_command = ( + [sys.executable, candidates[0]] + script_args + [mirror_dir] + ) + # Derive a unique name slug from the analysis script name + slug = os.path.basename(candidates[0]).removesuffix(".py") + n, suffix = 1, "" + while slug + suffix in used_slugs: + suffix = f"-{n}" + n += 1 + slug += suffix + used_slugs.add(slug) + working_dir = mirror_dir / "analysis" / slug + os.makedirs(working_dir, exist_ok=True) + print( + f"Running analysis script: {shlex.join(map(str, analysis_command))} in {working_dir}" + ) + subprocess.run( + analysis_command, + check=True, + cwd=working_dir, + ) + # Gather output files of the scrpt + for path in working_dir.rglob("*"): + with open(working_dir / path, "rb") as src, ofile.open( + str(path.relative_to(mirror_dir)), "w" + ) as dst: + # https://github.com/python/mypy/issues/15031 ? + shutil.copyfileobj(src, dst) # type: ignore diff --git a/.github/container/parallel-launch b/.github/container/parallel-launch new file mode 100755 index 000000000..e238e2aeb --- /dev/null +++ b/.github/container/parallel-launch @@ -0,0 +1,27 @@ +#!/bin/bash +set -e +if (( $# < 3 )); then + echo "Usage: $0 " + echo "launches command 'num_processes' times in parallel with 'var_name' set to 0..num_processes-1" + exit 1 +fi +VAR_NAME=$1 +shift +NPROCS=$1 +shift +positive_integer='^[1-9][0-9]*$' +if ! [[ $NPROCS =~ $positive_integer ]]; then + echo "Second argument must be a positive number of processes; got $NPROCS" + exit 1 +fi +pids=() +echo "Launching $@ $NPROCS times in parallel" +for (( i=0; i<$NPROCS; i++ )); do + export $VAR_NAME=$i + "$@" & + pids+=($!) +done +for pid in ${pids[*]}; do + wait $pid +done +jobs diff --git a/.github/container/pip-finalize.sh b/.github/container/pip-finalize.sh index 4828af35e..1149d7638 100755 --- a/.github/container/pip-finalize.sh +++ b/.github/container/pip-finalize.sh @@ -51,6 +51,6 @@ pip-sync --pip-args '--no-deps --src /opt' requirements.txt rm -rf ~/.cache/* -# protobuf will be installed at least due to requirements-nsys-jax.in in the base +# protobuf will be installed at least as a dependency of jax_nsys in the base # image, but the installed version is likely to be influenced by other packages. install-protoc /usr/local diff --git a/.github/container/requirements-nsys-jax.in b/.github/container/requirements-nsys-jax.in deleted file mode 100644 index 223881f6d..000000000 --- a/.github/container/requirements-nsys-jax.in +++ /dev/null @@ -1,5 +0,0 @@ -# No version constraint; a compatible version of protoc will be installed later -protobuf -# Used by install-protoc -requests -virtualenv diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 21591c91c..0dc26c8c1 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -223,18 +223,16 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN} export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION} export CUDA_DEVICE_MAX_CONNECTIONS=1 -export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true +export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 - --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true - --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false + --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization} @@ -268,4 +266,4 @@ fi echo "Command: python3 $RUN_SETTINGS" python3 $RUN_SETTINGS -echo "Output at ${OUTPUT}" \ No newline at end of file +echo "Output at ${OUTPUT}" diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 350d64685..426764323 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -305,6 +305,84 @@ jobs: test-gpu.log secrets: inherit + test-nsys-jax: + needs: build-jax + if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a + uses: ./.github/workflows/_test_unit.yaml + with: + TEST_NAME: nsys-jax + EXECUTE: | + set -o pipefail + num_tests=0 + num_failures=0 + GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU') + for mode in 1-process 2-process process-per-gpu; do + DOCKER="docker run --shm-size=1g --gpus all --env XLA_FLAGS=--xla_gpu_enable_command_buffer= --env XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 -v ${PWD}:/opt/output ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}" + if [[ "${mode}" == "1-process" ]]; then + PROCESS_COUNT=1 + ARGS="" + elif [[ "${mode}" == "2-process" ]]; then + # Use two processes with GPUS_PER_NODE/2 GPUs per process in the hope that + # this will flush out more bugs than process-per-node or process-per-GPU. + PROCESS_COUNT=2 + ARGS="--process-id RANK --process-count ${PROCESS_COUNT} --coordinator-address 127.0.0.1:12345 --gpus-per-process $((GPUS_PER_NODE/2)) --distributed" + else + PROCESS_COUNT=${GPUS_PER_NODE} + ARGS="--process-id RANK --process-count ${PROCESS_COUNT} --coordinator-address 127.0.0.1:12345 --gpus-per-process 1 --distributed" + fi + for collection in full partial; do + NSYS_JAX="nsys-jax" + if [[ "${mode}" == "1-process" ]]; then + # We will not run nsys-jax-combine, so run analyses eagerly + NSYS_JAX+=" --nsys-jax-analysis communication --nsys-jax-analysis summary" + fi + NSYS_JAX+=" --output=/opt/output/${mode}-${collection}-execution-%q{RANK}" + if [[ "${collection}" == "partial" ]]; then + NSYS_JAX+=" --capture-range=cudaProfilerApi --capture-range-end=stop" + # nvbug/4801401 + NSYS_JAX+=" --sample=none" + fi + set +e + ${DOCKER} parallel-launch RANK ${PROCESS_COUNT} ${NSYS_JAX} \ + -- jax-nccl-test ${ARGS} |& tee ${mode}-${collection}-execution.log + num_failures=$((num_failures + ($? != 0))) + set -e + num_tests=$((num_tests + 1)) + done + if [[ "${mode}" != "1-process" ]]; then + # Run nsys-jax-combine + NSYS_JAX_COMBINE="nsys-jax-combine --analysis communication --analysis summary --output=/opt/output/${mode}-${collection}-execution.zip" + for (( i=0; i> $GITHUB_ENV + echo "NSYS_JAX_FAIL_COUNT=${num_failures}" >> $GITHUB_ENV + exit $num_failures + STATISTICS_SCRIPT: | + num_passed=$(( NSYS_JAX_TEST_COUNT - NSYS_JAX_FAIL_COUNT )) + num_errors=0 + echo "TOTAL_TESTS=${NSYS_JAX_TEST_COUNT}" >> $GITHUB_OUTPUT + echo "ERRORS=0" >> $GITHUB_OUTPUT + echo "PASSED_TESTS=${num_passed}" >> $GITHUB_OUTPUT + echo "FAILED_TESTS=${NSYS_JAX_FAIL_COUNT}" >> $GITHUB_OUTPUT + ARTIFACTS: | + # nsys-jax logfiles + *process-*-execution.log + # nsys-jax output for the case that doesn't use nsys-jax-combine + 1-process-*-execution-0.zip + # nsys-jax-combine output/logfiles + *process*-*-execution.zip + *-execution-combine.log + secrets: inherit + # test-equinox: # needs: build-equinox # if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a diff --git a/.github/workflows/baselines/test_maxtext_metrics.py b/.github/workflows/baselines/test_maxtext_metrics.py index bd180ecfe..a130c86c6 100644 --- a/.github/workflows/baselines/test_maxtext_metrics.py +++ b/.github/workflows/baselines/test_maxtext_metrics.py @@ -19,7 +19,7 @@ def test_loss(baseline_filename): baseline_filepath = os.path.join(baselines_dir, baseline_filename) test_config = baseline_filename.split(".")[0] - event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*") + event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*") event_file = glob.glob(event_file)[0] with open(baseline_filepath, "r") as baseline_file: end_step = json.load(baseline_file)["end_step"] @@ -31,7 +31,7 @@ def test_loss(baseline_filename): def test_step_time(baseline_filename): baseline_filepath = os.path.join(baselines_dir, baseline_filename) test_config = baseline_filename.split(".")[0] - event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*") + event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*") event_file = glob.glob(event_file)[0] with open(baseline_filepath, "r") as baseline_file: step_time_avg_expected = json.load(baseline_file)["step_time_avg"] diff --git a/.github/workflows/nsys-jax.yaml b/.github/workflows/nsys-jax.yaml index 3b7d1e6e6..aa91f870c 100644 --- a/.github/workflows/nsys-jax.yaml +++ b/.github/workflows/nsys-jax.yaml @@ -22,6 +22,7 @@ env: JAX-Toolbox/.github/container/nsys-jax JAX-Toolbox/.github/container/nsys-jax-combine JAX-Toolbox/.github/container/jax_nsys + JAX-Toolbox/.github/container/jax-nccl-test jobs: mypy: @@ -37,18 +38,11 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.10' - - name: "Create virtual environment" - run: | - pip install virtualenv - virtualenv venv - - name: "Install google.protobuf and protoc" - run: | - ./venv/bin/pip install -r ./JAX-Toolbox/.github/container/requirements-nsys-jax.in - ./venv/bin/python ./JAX-Toolbox/.github/container/jax_nsys/install-protoc ./venv - - name: "Install jax_nsys Python package" - run: ./venv/bin/pip install -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys - - name: "Install mypy" - run: ./venv/bin/pip install matplotlib mypy nbconvert types-protobuf + # jax is just a CPU-only build of the latest release for type-checking purposes + - name: "Install jax / jax_nsys / mypy" + run: pip install jax -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys matplotlib mypy nbconvert types-protobuf + - name: "Install protoc" + run: ./JAX-Toolbox/.github/container/jax_nsys/install-protoc local_protoc - name: "Fetch XLA .proto files" uses: actions/checkout@v4 with: @@ -63,19 +57,19 @@ jobs: mkdir compiled_protos compiled_stubs protos mv -v xla/third_party/tsl/tsl protos/ mv -v xla/xla protos/ - ./venv/bin/python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')" + PATH=${PWD}/local_protoc/bin:$PATH python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')" touch compiled_stubs/py.typed - name: "Convert .ipynb to .py" shell: bash -x -e {0} run: | for notebook in $(find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb'); do - ./venv/bin/jupyter nbconvert --to script ${notebook} + jupyter nbconvert --to script ${notebook} done - name: "Run mypy checks" shell: bash -x -e {0} run: | export MYPYPATH="${PWD}/compiled_stubs" - ./venv/bin/mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES} + mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES} # Test nsys-jax-combine and notebook execution; in future perhaps upload the rendered # notebook from here too. These input files were generated with something like @@ -89,25 +83,42 @@ jobs: steps: - name: Check out the repository under ${GITHUB_WORKSPACE} uses: actions/checkout@v4 - - name: Use nsys-jax-combine to merge profiles from multiple nsys processes - shell: bash -x -e {0} - # TODO: when nsys-jax-combine supports --nsys-jax-analysis, exercise that here. - run: .github/container/nsys-jax-combine -o .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc.zip .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip - - name: Mock up the structure of an extracted .zip file - run: unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc.zip - name: "Setup Python 3.12" uses: actions/setup-python@v5 with: python-version: '3.12' + # TODO: a modern nsys-jax-combine with old .zip input should probably produce a + # .zip with a modern jax_nsys/ + - name: Add modern jax_nsys/ files to static .zip inputs + run: | + cd .github/container/jax_nsys + for zip in ../../workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip; do + zip -ur "${zip}" . + zipinfo "${zip}" + done + - name: Use nsys-jax-combine to merge profiles from multiple nsys processes + shell: bash -x -e {0} + run: | + pip install -e .github/container/jax_nsys/python/jax_nsys + python .github/container/jax_nsys/install-protoc local_protoc + PATH=${PWD}/local_protoc/bin:$PATH .github/container/nsys-jax-combine \ + --analysis summary \ + --analysis communication \ + -o pax_fsdp4_4proc.zip \ + .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip + - name: Extract the output .zip file + run: | + mkdir combined/ + unzip -d combined/ pax_fsdp4_4proc.zip - name: Run the install script, but skip launching Jupyter Lab shell: bash -x -e {0} run: | pip install virtualenv - NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh + NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./combined/install.sh - name: Test the Jupyter Lab installation and execute the notebook shell: bash -x -e {0} run: | - pushd .github/container/jax_nsys + pushd combined/ ./nsys_jax_venv/bin/python -m jupyterlab --version # Run with ipython for the sake of getting a clear error message ./nsys_jax_venv/bin/ipython Analysis.ipynb diff --git a/README.md b/README.md index af1f0f36b..1764c5f00 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,13 @@ +
+ + + + + + @@ -270,7 +277,7 @@ We currently support the following frameworks and models. More details about eac | :--- | :---: | :---: | :---: | | [Paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` | | [T5X](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` | -| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02` | +| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` | | [Big Vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` | | levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` | | maxtext| LLaMA, Gemma | pretraining | `ghcr.io/nvidia/jax:maxtext` | @@ -293,6 +300,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33). +For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page. + ## Profiling JAX programs on GPU See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU. diff --git a/rosetta/docs/GPU_performance.md b/rosetta/docs/GPU_performance.md index c5456e3c4..fabbc6963 100644 --- a/rosetta/docs/GPU_performance.md +++ b/rosetta/docs/GPU_performance.md @@ -128,6 +128,10 @@ Fine-grain control to improve performance by initializing a NCCL communicator to - --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX) - --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging) +## Previously used XLA Flags - +The following flags were used previously used but no longer required. +- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed +- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default +- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used diff --git a/rosetta/docs/NATIVE_FP8.md b/rosetta/docs/NATIVE_FP8.md index cb26ea2df..069b06fdd 100644 --- a/rosetta/docs/NATIVE_FP8.md +++ b/rosetta/docs/NATIVE_FP8.md @@ -111,14 +111,11 @@ Enabling this feature is effortless. Users only need to include the option `--fd In addition to the suggested XLA flags mentioned in [this section](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags), we also recommend setting these following XLA flags. The execution script should look like: ```bash export XLA_FLAGS=" \ - --xla_gpu_enable_reduction_epilogue_fusion=false \ --xla_gpu_enable_triton_gemm=false \ - --xla_gpu_enable_cudnn_fmha=false \ - --xla_gpu_enable_cudnn_layer_norm=true \ - --xla_gpu_enable_cublaslt=true \ - --xla_gpu_enable_latency_hiding_scheduler=true \ - --xla_gpu_enable_highest_priority_async_stream=true \ - --xla_gpu_all_reduce_combine_threshold_bytes=51200 " + --xla_gpu_enable_pipelined_all_reduce=false \ + --xla_gpu_enable_pipelined_all_gather=false \ + --xla_gpu_enable_pipelined_reduce_scatter=false \ +" export ENABLE_TE=0 python -m paxml.main \ ... @@ -126,8 +123,7 @@ python -m paxml.main \ ... ``` -Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions. - +Please not that disabling the triton gemm and pipelined collectives is essential for enabling the FP8 functionality and performance. ## Transformer Engine vs Native FP8 Support Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components. diff --git a/rosetta/docs/PGLE.md b/rosetta/docs/PGLE.md index 86882dd37..2425ddffe 100644 --- a/rosetta/docs/PGLE.md +++ b/rosetta/docs/PGLE.md @@ -63,7 +63,6 @@ In order to get the best performance with PGLE, here is a list of all recommende export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 ---xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824 @@ -71,7 +70,6 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true ---xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization diff --git a/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh b/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh index fce89244e..a5eaf9aa0 100644 --- a/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh +++ b/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh @@ -1 +1,2 @@ -export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_disable_async_collectives=allreduce,allgather,reducescatter,collectivebroadcast,alltoall,collectivepermute ${XLA_FLAGS}" +# These XLA flags are meant to be used with the JAX version in the imagen container +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}" diff --git a/rosetta/rosetta/projects/imagen/README.md b/rosetta/rosetta/projects/imagen/README.md index 4959a4118..136f913d6 100644 --- a/rosetta/rosetta/projects/imagen/README.md +++ b/rosetta/rosetta/projects/imagen/README.md @@ -17,7 +17,7 @@ For maximum flexibility and low disk requirements, this repo supports a **distri We provide [scripts](scripts) to run [interactively](scripts/singlenode_inf_train.sh) or on [SLURM](scripts/example_slurm_inf_train.sub). ### Container -We provide a fully built and ready-to-use container here: `ghcr.io/nvidia/t5x:imagen-2023-10-02`. +We provide a fully built and ready-to-use container here: `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3`. We do not currently have custom-built container workflows, but are actively working on supporting this, stay tuned for updates! Imagen will also be available in our T5x container in future releases. @@ -37,7 +37,7 @@ You will need to acquire the LLM checkpoint for T5 (for multimodal training) fro **Note**: this should only be done with singlenode jobs ```bash -CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02 +CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 docker run --rm --gpus=all -it --net=host --ipc=host -v ${PWD}:/opt/rosetta -v ${DATASET_PATH}:/mnt/datasets --privileged $CONTAINER bash ``` @@ -99,15 +99,27 @@ sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \ You can find example sampling scripts that use the 500M base model and EfficientUnet SR models in [scripts](scripts). Prompts should be specified as in [example](../diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt) #### Sampling 256x256 images -Defaults to [imagen_256_sample.gin](configs/imagen_256_sample.gin) config (can be adjusted in script) +Defaults to [imagen_256_sample.gin](configs/imagen_256_sample.gin) config (can be adjusted in script, e.g., [imagen_256_sample_2b.gin](configs/imagen_256_sample_2b.gin)). ``` -CUDA_VISIBLE_DEVICES= CFG=5.0 BASE_PATH= SR1_PATH= PROMPT_TEXT_FILES= ./rosetta/projects/imagen/scripts/sample_imagen_256.sh +CUDA_VISIBLE_DEVICES= CFG=5.0 GLOBAL_BATCH_SIZE= GEN_PER_PROMPT=1 BASE_PATH= SR1_PATH= PROMPT_TEXT_FILES= ./rosetta/projects/imagen/scripts/sample_imagen_256.sh +``` + +Here is an example: +``` +# Note: +# - the quoting of double quotes wrapping single quotes is necessary. +# - BASE_PATH/SR1_PATH are checkpoint dirs, and are expected to contain a `checkpoint` file, e.g., the file $BASE_PATH/checkpoint should exist +# - GLOBAL_BATCH_SIZE should be set with number of GPUs in mind. For instance GLOBAL_BATCH_SIZE >= num gpus, +# to ensure at least one example is sent to each GPU. +# - Currently there is a limitation where the number of lines in PROMPT_TEXT_FILES should be divisible by the number of GPUs. +# The easiest way to ensure that is just to pad the files with dummy prompts until it is divisible +CUDA_VISIBLE_DEVICES=0,1 CFG=5.0 GLOBAL_BATCH_SIZE=4 GEN_PER_PROMPT=1 BASE_PATH='"/mnt/imagen_ckpt/checkpoint_585000"' SR1_PATH='"/mnt/sr1_ckpt/checkpoint_5000"' PROMPT_TEXT_FILES='"./rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt"' ./rosetta/projects/imagen/scripts/sample_imagen_256.sh ``` #### Sampling 1024x1024 images -Defaults to [imagen_1024_sample.gin](configs/imagen_1024_sample.gin) config (can be adjusted in script). +Defaults to [imagen_1024_sample.gin](configs/imagen_1024_sample.gin) config (can be adjusted in script, e.g., [imagen_1024_sample_2b.gin](configs/imagen_1024_sample_2b.gin)). ``` -CUDA_VISIBLE_DEVICES= CFG=5.0 BASE_PATH= SR1_PATH= SR2_PATH= PROMPT_TEXT_FILES= ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh +CUDA_VISIBLE_DEVICES= CFG=5.0 GLOBAL_BATCH_SIZE= GEN_PER_PROMPT=1 BASE_PATH= SR1_PATH= SR2_PATH= PROMPT_TEXT_FILES= ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh ``` diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin b/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin new file mode 100644 index 000000000..101e2773e --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin @@ -0,0 +1,78 @@ +# Imagen Sampling pipeline +include "rosetta/projects/imagen/configs/imagen_256_sample_2b.gin" + +from __gin__ import dynamic_registration +import __main__ as sample_script +from t5x import gin_utils +from t5x import utils +from t5x import partitioning + +from rosetta.projects.imagen import network_sr +from rosetta.projects.diffusion import models +from rosetta.projects.diffusion import denoisers +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import losses +from rosetta.projects.diffusion import augmentations + +#---------------- SR1024 Model ------------------------------------------------- + +# ------------------- Model ---------------------------------------------------- +SR1024 = @sr1024/models.DenoisingDiffusionModel() +SIGMA_DATA = 0.5 +sr1024/models.DenoisingDiffusionModel: + denoiser= @sr1024/denoisers.EDMTextConditionedSuperResDenoiser() + diffusion_loss= None + diffusion_sampler= @sr1024/samplers.EDMSampler() + optimizer_def = None + +# |--- Denoiser +sr1024/denoisers.EDMTextConditionedSuperResDenoiser: + raw_model= @sr1024/network_sr.ImagenEfficientUNet() + +sr1024/samplers.EDMSampler: + dim_noise_scalar = 4. + +# ------------------- Network specification ------------------------------------ +sr1024/network_sr.ImagenEfficientUNet.config = @sr1024/network_sr.ImagenEfficientUNetConfig() +sr1024/network_sr.ImagenEfficientUNetConfig: + dtype = %DTYPE + model_dim = 128 + cond_dim = 1024 + resblocks_per_level = (2, 4, 8, 8, 8) + width_multipliers = (1, 2, 4, 6, 6) + attn_resolutions_divs = {16: 'cross'} + mha_head_dim = 64 + attn_heads = 8 + resblock_activation = 'silu' + resblock_zero_out = True + resblock_scale_skip = True + dropout_rate = %DROPOUT_RATE + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True + +sr1024/samplers.CFGSamplingConfig: + num_steps=30 + cf_guidance_weight=0.0 + cf_guidance_nulls={'text': None, 'text_mask': None} + +sr1024/partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +sr1024/utils.RestoreCheckpointConfig: + mode = 'specific' + dtype = 'bfloat16' + +sr1024/sample_script.DiffusionModelSetupData: + model = %SR1024 + sampling_cfg = @sr1024/samplers.CFGSamplingConfig() + restore_checkpoint_cfg = @sr1024/utils.RestoreCheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + input_shapes = {'samples': (1, 1024, 1024, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 256, 256, 3)} + input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'} + +sample_script.sample: + sr1024_setupdata = @sr1024/sample_script.DiffusionModelSetupData() diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin b/rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin new file mode 100644 index 000000000..13272ebbe --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin @@ -0,0 +1,220 @@ +# Imagen Sampling pipeline +from __gin__ import dynamic_registration + +import __main__ as sample_script +from t5x import gin_utils +from t5x import utils +from t5x import partitioning + +SAVE_DIR='generations' +PROMPT_TEXT_FILE='custom_text.txt' +GLOBAL_BATCH_SIZE=32 +MAX_GENERATE=50000000 +GEN_PER_PROMPT=2 +NOISE_COND_AUG=0.002 + +TXT_SHAPE=(1, 128, 4096) #T5 xxl, seqlen x embed_dim +TXT_SEQLEN=(1, 128, ) +TXT_SEQLEN_SINGLE=128 +DTYPE='bfloat16' +DROPOUT_RATE=0 +RESUME_FROM=0 #Sampling count to resume from +#---------------- Base Model ------------------------------------------------- +from rosetta.projects.imagen import network +from rosetta.projects.imagen import network_sr +from rosetta.projects.diffusion import models +from rosetta.projects.diffusion import denoisers +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import losses +from rosetta.projects.diffusion import augmentations + +# ------------------- Model ---------------------------------------------------- +BASE = @base_model/models.DenoisingDiffusionModel() +base_model/models.DenoisingDiffusionModel: + denoiser= @base_model/denoisers.EDMTextConditionedDenoiser() + diffusion_loss = None + diffusion_sampler= @base_model/samplers.EDMSampler() + optimizer_def = None + +# |--- Denoiser +base_model/denoisers.EDMTextConditionedDenoiser: + raw_model= @base_model/network.ImagenUNet() + +# ------------------- Network specification ------------------------------------ +base_model/network.ImagenUNet.config = @base_model/network.DiffusionConfig() +base_model/network.DiffusionConfig: + dtype = %DTYPE + model_dim = 512 + attn_cond_dim = 2048 + cond_dim = 2048 + resblocks_per_level = 3 + width_multipliers = (1, 2, 3, 4) + attn_resolutions = (32, 16, 8) + mha_head_dim = 64 + attn_heads = 8 + resblock_activation = 'silu' + dropout_rate = %DROPOUT_RATE + upsample_mode = 'shuffle' + downsample_mode = 'shuffle' + spatial_skip = False + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits = False + text_conditionable = True + + +BASE_SAMPLING_CONFIG = @base_model/samplers.CFGSamplingConfig() +base_model/samplers.CFGSamplingConfig: + num_steps=50 + cf_guidance_weight=5.00 + cf_guidance_nulls=None + +base_model/partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +base_model/utils.RestoreCheckpointConfig: + mode = 'specific' + dtype = 'bfloat16' + +base_model/sample_script.DiffusionModelSetupData: + model = %BASE + sampling_cfg = @base_model/samplers.CFGSamplingConfig() + restore_checkpoint_cfg = @base_model/utils.RestoreCheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + input_shapes = {'samples': (1, 64, 64, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN} + input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int'} + +#---------------- SR256 Model ------------------------------------------------- + +# ------------------- Model ---------------------------------------------------- +SR256 = @sr256/models.DenoisingDiffusionModel() +SIGMA_DATA = 0.5 +sr256/models.DenoisingDiffusionModel: + denoiser= @sr256/denoisers.EDMTextConditionedSuperResDenoiser() + diffusion_loss= None + diffusion_sampler= @sr256/samplers.EDMSampler() + optimizer_def = None + +# |--- Denoiser +sr256/denoisers.EDMTextConditionedSuperResDenoiser: + raw_model= @sr256/network_sr.ImagenEfficientUNet() + +sr256/samplers.EDMSampler: + dim_noise_scalar = 4. + +# ------------------- Network specification ------------------------------------ +sr256/network_sr.ImagenEfficientUNet.config = @sr256/network_sr.ImagenEfficientUNetConfig() +sr256/network_sr.ImagenEfficientUNetConfig: + dtype = %DTYPE + model_dim = 128 + cond_dim = 512 + attn_cond_dim = 1024 + resblocks_per_level = (2, 4, 8, 8, 2) + width_multipliers = (1, 2, 4, 8, 8) + attn_resolutions_divs = {8: 'fused', 16: 'fused'} + mha_head_dim = 64 + attn_heads = 8 + resblock_activation = 'silu' + resblock_zero_out = True + resblock_scale_skip = True + dropout_rate = %DROPOUT_RATE + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True + +sr256/samplers.CFGSamplingConfig: + num_steps=50 + cf_guidance_weight=4 + cf_guidance_nulls={'text': None, 'text_mask': None} + +sr256/partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +sr256/utils.RestoreCheckpointConfig: + mode = 'specific' + dtype = 'bfloat16' + +sr256/sample_script.DiffusionModelSetupData: + model = %SR256 + sampling_cfg = @sr256/samplers.CFGSamplingConfig() + restore_checkpoint_cfg = @sr256/utils.RestoreCheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + input_shapes = {'samples': (1, 256, 256, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 64, 64, 3)} + input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'} + +#---------------- Text Model ------------------------------------------------- +import seqio +from rosetta.projects.inference_serving.t5 import network as t5x_network +from rosetta.projects.inference_serving.t5 import models as t5x_models + +# ===================================== +# === T5 Encoder only configuration === +# ===================================== +T5_CHECKPOINT_PATH = "/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl" +BATCH_SIZE = 256 # Will be overridden +SEQ_LEN = 128 # MAX seqlen + +# Vocabulary +VOCABULARY = @seqio.SentencePieceVocabulary() +seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" +TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use. + +# --------------- Model ------------------ +TEXT_ENC = @text_enc/t5x_models.EncoderOnlyModel() +text_enc/t5x_models.EncoderOnlyModel: + module = @t5x_network.TransformerEncoderOnly() + input_vocabulary = %VOCABULARY + output_vocabulary = %VOCABULARY + optimizer_def = None + z_loss = 0.0001 + label_smoothing = 0.0 + loss_normalizing_factor = None + +# -------- Network specification --------- +t5x_network.TransformerEncoderOnly.config = @t5x_network.T5Config() +t5x_network.T5Config: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 4096 + num_heads = 64 + num_encoder_layers = 24 + num_decoder_layers = 0 + head_dim = 64 + mlp_dim = 10240 + mlp_activations = ('gelu', 'linear') + dropout_rate = 0.0 + +text_enc/partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +text_enc/utils.RestoreCheckpointConfig: + path = %T5_CHECKPOINT_PATH + mode = 'specific' + dtype = 'bfloat16' + +text_enc/sample_script.setup_text_enc: + model=%TEXT_ENC + restore_checkpoint_cfg=@text_enc/utils.RestoreCheckpointConfig() + partitioner=@text_enc/partitioning.PjitPartitioner() + batch_size=1 + seq_len=%TXT_SEQLEN_SINGLE + vocab = %VOCABULARY + +sample_script.sample: + base_setupdata = @base_model/sample_script.DiffusionModelSetupData() + sr256_setupdata = @sr256/sample_script.DiffusionModelSetupData() + sr1024_setupdata = None + out_dir = %SAVE_DIR + gen_per_prompt = %GEN_PER_PROMPT + prompt_file = %PROMPT_TEXT_FILE + batch_size = %GLOBAL_BATCH_SIZE + max_images = %MAX_GENERATE + text_enc_infer = @text_enc/sample_script.setup_text_enc() + noise_conditioning_aug = %NOISE_COND_AUG + resume_from = %RESUME_FROM diff --git a/rosetta/rosetta/projects/imagen/imagen_pipe.py b/rosetta/rosetta/projects/imagen/imagen_pipe.py index fb96d7ffd..8995234e4 100644 --- a/rosetta/rosetta/projects/imagen/imagen_pipe.py +++ b/rosetta/rosetta/projects/imagen/imagen_pipe.py @@ -19,6 +19,7 @@ import functools from typing import Mapping, Any, Optional, Callable, Sequence import logging +import time import numpy as np import jax @@ -37,6 +38,8 @@ _DEFAULT_GIN_SEARCH_PATHS = [ os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ] +# This prevents issues where filenames go longer than the max length allowed in unix +MAX_FILENAME_LENGTH = 150 @dataclasses.dataclass class DiffusionModelSetupData: @@ -194,8 +197,9 @@ def sample( sampled_ctr = 0 rng = jax.random.PRNGKey(0) + start_time = time.time() for start_idx in range(resume_from, max_images, batch_size // gen_per_prompt): - if start_idx > prompt_ct: + if start_idx >= prompt_ct: break prompt_batch = prompts[start_idx: start_idx + (batch_size // gen_per_prompt)] * gen_per_prompt rng, rng_base, rng_sr, rng_sr2, rng_aug = jax.random.split(rng, 5) @@ -213,7 +217,7 @@ def sample( base_batch = {'samples': base_img_inputs, 'text': encoded_text, 'text_mask': text_mask} base_out = base_fn(base_params, base_batch, rng_base) for i in range(base_out.shape[0]): - matimg.imsave(os.path.join(base_dir, sanitize_filename(f'{prompt_batch[i]}_{sampled_ctr + i}.png')), np.clip(base_out[i], a_min=0, a_max=1)) + matimg.imsave(os.path.join(base_dir, sanitize_filename(f'{prompt_batch[i][:MAX_FILENAME_LENGTH]}_{sampled_ctr + i}.png')), np.clip(base_out[i], a_min=0, a_max=1)) # Stage 2: Super Resolution (64-> 256) base_aug = (base_out * 2 - 1) @@ -222,7 +226,7 @@ def sample( sr_out = sr256_fn(sr256_params, sr256_batch, rng_sr) sr_out = jnp.clip(sr_out, a_min = 0, a_max = 1) for i in range(sr_out.shape[0]): - matimg.imsave(os.path.join(sr_dir, sanitize_filename(f'{prompt_batch[i]}_{sampled_ctr + i}.png')), sr_out[i]) + matimg.imsave(os.path.join(sr_dir, sanitize_filename(f'{prompt_batch[i][:MAX_FILENAME_LENGTH]}_{sampled_ctr + i}.png')), sr_out[i]) # Stage 3: Super Resolution (256-> 1024) if sr1024_setupdata is not None: @@ -232,9 +236,12 @@ def sample( sr_out = sr1024_fn(sr1024_params, sr1024_batch, rng_sr2) sr_out = jnp.clip(sr_out, a_min = 0, a_max = 1) for i in range(sr_out.shape[0]): - matimg.imsave(os.path.join(sr2_dir, sanitize_filename(f'{prompt_batch[i]}_{sampled_ctr + i}.png')), sr_out[i]) + matimg.imsave(os.path.join(sr2_dir, sanitize_filename(f'{prompt_batch[i][:MAX_FILENAME_LENGTH]}_{sampled_ctr + i}.png')), sr_out[i]) sampled_ctr += sr_out.shape[0] + print(f'total samples generated={sampled_ctr}') + print(f'batch sec/sample={(time.time() - start_time) / sr_out.shape[0]}') + start_time = time.time() if __name__ == '__main__': diff --git a/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub b/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub index 6da0ccde1..4beb45afb 100755 --- a/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub +++ b/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub @@ -14,11 +14,9 @@ set -x # File system and volume glue code #------------------------------------------------------------------------------- # << CHANGE ! >> -SLURM_ACCOUNT= -USERID= # << CHANGE ! >> -CONTAINER= +CONTAINER=${CONTAINER:-ghcr.io#nvidia/t5x:imagen-2023-10-02.v3} # << CHANGE ! >> BASE_ROSETTA_DIR="/jax-toolbox-mirror/rosetta/" # path to your clone of the repo diff --git a/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh b/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh index d4d9ce63d..afd1f48b6 100755 --- a/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh +++ b/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. CFG=${CFG:=2} +GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE:-4} +GEN_PER_PROMPT=${GEN_PER_PROMPT:-1} +SAMPLING_GIN=${SAMPLING_GIN:-/opt/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample.gin} BASE_PATH=${BASE_PATH:=\"/opt/rosetta/runs/imagen_base/checkpoint_5000\"} SR1_PATH=${SR1_PATH:=\"/opt/rosetta/runs/efficient_sr1/checkpoint_5000\"} SR2_PATH=${SR1_PATH:=\"/opt/rosetta/runs/efficient_sr2/checkpoint_5000\"} @@ -20,14 +23,15 @@ PROMPT_TEXT_FILE=${PROMPT_TEXT_FILE:=\"/opt/rosetta/rosetta/projects/diffusion/t export DISABLE_TE=True python /opt/rosetta/rosetta/projects/imagen/imagen_pipe.py \ - --gin_file="/opt/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample.gin" \ + --gin_file="${SAMPLING_GIN}" \ --gin.base_model/utils.RestoreCheckpointConfig.path="${BASE_PATH}" \ --gin.sr256/utils.RestoreCheckpointConfig.path="${SR1_PATH}" \ --gin.sr1024/utils.RestoreCheckpointConfig.path="${SR2_PATH}" \ --gin.T5_CHECKPOINT_PATH="\"/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl\"" \ --gin.base_model/samplers.CFGSamplingConfig.cf_guidance_weight=${CFG} \ --gin.PROMPT_TEXT_FILE=${PROMPT_TEXT_FILE} \ - --gin.GLOBAL_BATCH_SIZE=4 \ + --gin.GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} \ --gin.SAVE_DIR="\"generations/generations-${CFG}\"" \ - --gin.GEN_PER_PROMPT=1 \ - --gin.RESUME_FROM=0 \ No newline at end of file + --gin.GEN_PER_PROMPT=${GEN_PER_PROMPT} \ + --gin.RESUME_FROM=0 \ + $@ diff --git a/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh b/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh index 19d19be73..4dc46fe87 100755 --- a/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh +++ b/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh @@ -13,19 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. CFG=${CFG:=2} +GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE:-4} +GEN_PER_PROMPT=${GEN_PER_PROMPT:-1} +SAMPLING_GIN=${SAMPLING_GIN:-/opt/rosetta/rosetta/projects/imagen/configs/imagen_256_sample.gin} BASE_PATH=${BASE_PATH:=\"/opt/rosetta/runs/imagen_base/checkpoint_5000\"} SR1_PATH=${SR1_PATH:=\"/opt/rosetta/runs/efficient_sr1/checkpoint_5000\"} PROMPT_TEXT_FILE=${PROMPT_TEXT_FILE:=\"/opt/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt\"} export DISABLE_TE=True python /opt/rosetta/rosetta/projects/imagen/imagen_pipe.py \ - --gin_file="/opt/rosetta/rosetta/projects/imagen/configs/imagen_256_sample.gin" \ + --gin_file="${SAMPLING_GIN}" \ --gin.base_model/utils.RestoreCheckpointConfig.path="${BASE_PATH}" \ --gin.sr256/utils.RestoreCheckpointConfig.path="${SR1_PATH}" \ --gin.T5_CHECKPOINT_PATH="\"/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl\"" \ --gin.base_model/samplers.CFGSamplingConfig.cf_guidance_weight=${CFG} \ --gin.PROMPT_TEXT_FILE=${PROMPT_TEXT_FILE}\ - --gin.GLOBAL_BATCH_SIZE=4 \ + --gin.GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} \ --gin.SAVE_DIR="\"generations/generations-${CFG}\"" \ - --gin.GEN_PER_PROMPT=1 \ - --gin.RESUME_FROM=0 \ No newline at end of file + --gin.GEN_PER_PROMPT=${GEN_PER_PROMPT} \ + --gin.RESUME_FROM=0 \ + $@ diff --git a/rosetta/rosetta/projects/maxtext/README.md b/rosetta/rosetta/projects/maxtext/README.md index 8486ee566..2320a7ed9 100644 --- a/rosetta/rosetta/projects/maxtext/README.md +++ b/rosetta/rosetta/projects/maxtext/README.md @@ -67,13 +67,9 @@ In order to obtain the best performance, please set the appropriate XLA flags. W The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. These are the recommended XLA flags to get good performance for MaxText. ``` -XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true - --xla_gpu_enable_async_all_gather=true - --xla_gpu_enable_async_reduce_scatter=true +XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false - --xla_gpu_graph_level=0 - --xla_gpu_enable_async_all_reduce=true - --xla_gpu_enable_highest_priority_async_stream=true + --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 @@ -81,7 +77,6 @@ XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization" diff --git a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub index 45cb5da2c..0ca3fd802 100644 --- a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub +++ b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub @@ -53,12 +53,8 @@ export NCCL_IB_SL=1 # Set XLA Flags export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true - --xla_gpu_enable_async_all_gather=true - --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 - --xla_gpu_enable_async_all_reduce=true - --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 @@ -66,12 +62,9 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false - --xla_disable_hlo_passes=rematerialization - --xla_gpu_enable_custom_fusions=false - --xla_gpu_enable_address_computation_fusion=false" + --xla_disable_hlo_passes=rematerialization" # Make directories that may not exist mkdir -p $BASE_WORKSPACE_DIR diff --git a/rosetta/rosetta/projects/pax/README.md b/rosetta/rosetta/projects/pax/README.md index a2dbd1cf0..d1829b847 100644 --- a/rosetta/rosetta/projects/pax/README.md +++ b/rosetta/rosetta/projects/pax/README.md @@ -138,11 +138,10 @@ The [GPU Performance document](../../../docs/GPU_performance.md) provides a deta For the the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_threshold_bytes=33554432`, which is different from the value recommended in `paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh`. To overwrite the default XLA flags set in the script, set the `BASE_XLA_FLAGS` environment variable prior to running `run_pile_multinode` as follows: ``` -BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false - --xla_gpu_enable_async_all_gather=true - --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_highest_priority_async_stream=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432 - --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true" bash run_pile_multinode.sh ... +BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_triton_gemm=false + --xla_gpu_all_reduce_combine_threshold_bytes=33554432 + --xla_gpu_graph_level=0" bash run_pile_multinode.sh ... ``` # Configs