Commit

Remove implicit dependency
olupton committed Aug 8, 2024
1 parent f78938a commit 4e6449a
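
In outline, the script previously ordered its ThreadPoolExecutor tasks by passing one task's future into another as `wait_on` and blocking on `wait_on.result()`; after this change each task creates the SQLite export it needs itself, so the ordering argument disappears. A minimal sketch of the two patterns (the task bodies and the `profile.nsys-rep` path are illustrative stand-ins, not the script's real code):

    from concurrent.futures import ThreadPoolExecutor

    # Illustrative stand-ins; the real tasks run `nsys stats` / `nsys export`.
    def make_sqlite_export(report):
        db_file = report + ".sqlite"
        # ... the real code would invoke nsys here to produce db_file ...
        return db_file

    def old_style_consumer(report, wait_on):
        # Old pattern: block until another task has (implicitly) produced the
        # .sqlite export, then read it.
        wait_on.result()
        print("reading", report + ".sqlite")

    def new_style_consumer(report):
        # New pattern: produce exactly the export this task needs, so there is
        # no hidden ordering requirement between tasks.
        print("reading", make_sqlite_export(report))

    with ThreadPoolExecutor() as executor:
        # Old: the consumer has to be handed the producer's future.
        producer = executor.submit(make_sqlite_export, "profile.nsys-rep")
        executor.submit(old_style_consumer, "profile.nsys-rep", producer)
        # New: tasks are independent and can be submitted in any order.
        executor.submit(new_style_consumer, "profile.nsys-rep")
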
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions .github/container/nsys-jax
@@ -325,12 +325,11 @@ def copy_proto_files_to_tmp(tmp_dir, xla_dir="/opt/xla"):
     return proto_dir, proto_files


-def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue, wait_on):
+def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
     """
     Post-process a .nsys-rep file into a .parquet file for offline analysis.
     This is currently implemented using the given nsys recipe.
     """
-    wait_on.result()
     start = time.time()
     recipe_output = osp.join(tmp_dir, recipe)
     subprocess.run(
@@ -377,8 +376,11 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
             report,
             "--input",
             report_file,
+            # avoid race conditions with other reports/etc.
+            "--sqlite",
+            osp.splitext(report_file)[0] + "-" + report + ".sqlite",
             "--output",
-            ".",
+            "report",
         ]
         + (["--force-overwrite"] if nsys_force_overwrite else []),
         check=True,
@@ -388,15 +390,28 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
     print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")


-def save_device_stream_thread_names(tmp_dir, report, output_queue, wait_on):
+def save_device_stream_thread_names(tmp_dir, report, output_queue):
     """
     Extract extra information from the SQLite dump that is needed to map projected NVTX
     ranges to global device IDs.
     """
-    wait_on.result()
     start = time.time()
     assert report.endswith(".nsys-rep"), f"{report} had an unexpected suffix"
-    db_file = report.removesuffix(".nsys-rep") + ".sqlite"
+    db_file = report.removesuffix(".nsys-rep") + "-metadata.sqlite"
+    subprocess.run(
+        [
+            "nsys",
+            "export",
+            "--type",
+            "sqlite",
+            "--tables",
+            "StringIds,TARGET_INFO_GPU,TARGET_INFO_NVTX_CUDA_DEVICE,TARGET_INFO_SYSTEM_ENV,ThreadNames",
+            "--output",
+            db_file,
+            report,
+        ],
+        check=True,
+    )
     assert os.path.exists(db_file)
     con = sqlite3.connect(db_file)
     cur = con.cursor()
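
Once the -metadata.sqlite file produced by the explicit `nsys export` above exists, it can be read back with the standard sqlite3 module, as the context lines at the end of this hunk begin to do. A minimal sketch of such a lookup, assuming the usual Nsight Systems export schema in which `ThreadNames.nameId` references `StringIds.id` (the column names and the `profile-metadata.sqlite` path are assumptions, not taken from this commit):

    import sqlite3

    def read_thread_names(db_file):
        """Return {globalTid: thread name} from an nsys SQLite export.

        Assumes ThreadNames(nameId, globalTid) and StringIds(id, value) exist,
        matching the --tables list requested in the export above.
        """
        con = sqlite3.connect(db_file)
        try:
            cur = con.cursor()
            cur.execute(
                "SELECT ThreadNames.globalTid, StringIds.value "
                "FROM ThreadNames JOIN StringIds ON ThreadNames.nameId = StringIds.id"
            )
            return dict(cur.fetchall())
        finally:
            con.close()

    # Example (hypothetical path): read_thread_names("profile-metadata.sqlite")
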
@@ -736,6 +751,16 @@ with ThreadPoolExecutor() as executor, output_thread(executor):
             compress_deflate,
         )
     )
+    # Convert .nsys-rep -> .parquet and queue the latter for archival
+    futures.append(
+        executor.submit(
+            run_nsys_recipe,
+            "nvtx_gpu_proj_trace",
+            tmp_rep,
+            tmp_dir,
+            files_to_archive,
+        )
+    )
     # Copy /opt/jax_nsys into the archive
     futures.append(
         executor.submit(copy_jax_nsys_files, "/opt/jax_nsys", files_to_archive)
@@ -756,36 +781,22 @@ with ThreadPoolExecutor() as executor, output_thread(executor):
             futures,
         )
     )
-    # This implicitly creates the .sqlite export file; in nsys before 2024.5 then so
-    # did the nvtx_gpu_proj_trace recipe, but not anymore.
     futures.append(
-        sqlite_exists_future := executor.submit(
+        executor.submit(
             run_nsys_stats_report,
             "nvtx_pushpop_trace",
             tmp_rep,
             tmp_dir,
             files_to_archive,
         )
     )
-    # Convert .nsys-rep -> .parquet and queue the latter for archival
-    futures.append(
-        executor.submit(
-            run_nsys_recipe,
-            "nvtx_gpu_proj_trace",
-            tmp_rep,
-            tmp_dir,
-            files_to_archive,
-            sqlite_exists_future,  # for dependency purposes only
-        )
-    )
     # Do some custom post-processing of the .sqlite export generated by gpu_proj_future
     futures.append(
         executor.submit(
             save_device_stream_thread_names,
             tmp_dir,
             tmp_rep,
             files_to_archive,
-            sqlite_exists_future,  # for dependency purposes only
         )
     )
     # Wait for errors/completion of `futures`; note that this does not include
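
The truncated comment at the end of the hunk refers to waiting on the accumulated `futures`; one conventional way to wait and surface any worker exception is sketched below (an illustration, not the script's actual error handling):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def wait_and_raise(futures):
        # Block until every submitted task finishes, re-raising the first
        # exception encountered so failures are not silently swallowed.
        for future in as_completed(futures):
            future.result()

    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(print, n) for n in range(3)]
        wait_and_raise(futures)
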
