[RFC] Log PT2 chromium events to scuba (#133859)
Summary:
X-link: pytorch/benchmark#2424

Pull Request resolved: #133859

This diff implements a bunch of views for internal scuba viewing.

TODOS that I might punt to another diff:
- Saving cache stats via counters is definitely sus here, but there's not really a good way to track "fx graph cache hit for this compile phase" right now. Will think about this more.
- We should definitely log frame id, compile id, etc.
- We should definitely be logging configs, so that we can A/B test based on whether a config is turned on.
- I'm not sure yet what to do with compile_uuid, but it's useful when you want to look at samples for a single run. If we had mast job info this field wouldn't be needed, but it's nice to be able to drill down to a single run and get its chrome trace view or icicle view (see the sketch after this list).
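For illustration only (not one of the views in this diff), here is a minimal sketch of how a per-run uuid lets you filter events down to a single run; the field names and rows are made up, not the real scuba schema:

```python
import uuid

run_id = str(uuid.uuid4())  # one id generated per python run, like ChromiumEventLogger.id_

# Hypothetical event rows; the real column names and logging sink are internal
# and not part of this diff.
events = [
    {"compile_uuid": run_id, "name": "entire_frame_compile", "ph": "B"},
    {"compile_uuid": run_id, "name": "inductor_compile", "ph": "B"},
    {"compile_uuid": run_id, "name": "inductor_compile", "ph": "E"},
    {"compile_uuid": run_id, "name": "entire_frame_compile", "ph": "E"},
]

# Drilling down to one run is just filtering on its uuid; from these rows you
# can rebuild that run's chrome trace or icicle view.
one_run = [e for e in events if e["compile_uuid"] == run_id]
```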

Test Plan:
All of the above views were generated by running the nanogpt benchmark:

```
buck run mode/opt caffe2/benchmarks/dynamo:torchbench -- --training --backend=inductor --only nanogpt --performance
```
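
If you don't have buck available, a roughly equivalent OSS invocation (assuming the torchbench dependencies for nanogpt are installed) would be:

```
python benchmarks/dynamo/torchbench.py --training --backend=inductor --only nanogpt --performance
```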

Reviewed By: ezyang

Differential Revision: D61392607
jamesjwu authored and facebook-github-bot committed Aug 19, 2024
1 parent 432638f commit 70de7ba
Showing 6 changed files with 107 additions and 19 deletions.
2 changes: 1 addition & 1 deletion test/dynamo/test_misc.py
@@ -10336,7 +10336,7 @@ def pow(x):
msg="Encountered an unexpected fallback to 'aten pow' in dynamo compiled code",
)

def test_graph_break_compilation_metrics(self):
def test_graph_break_compilation_metrics_inner(self):
def fn(x):
x.cos()
torch._dynamo.graph_break()
3 changes: 3 additions & 0 deletions torch/_dynamo/convert_frame.py
@@ -95,6 +95,7 @@
format_bytecode,
frame_phase_timing,
gen_record_file_name,
get_chromium_event_logger,
increment_frame,
is_namedtuple,
istype,
@@ -869,6 +870,8 @@ def format_guard_failures() -> str:
# torch/_logging/_internal.py:1064 in trace_structured
# torch/_dynamo/convert_frame.py:780 in <lambda>
convert_frame_intern = structured.intern_string(__file__)
# Initialize the ChromiumEventLogger on start
get_chromium_event_logger()
torch._logging.trace_structured(
"dynamo_start",
lambda: {
107 changes: 93 additions & 14 deletions torch/_dynamo/utils.py
@@ -26,6 +26,7 @@
import time
import types
import typing
import uuid
import warnings
import weakref
from contextlib import contextmanager
@@ -64,7 +65,7 @@
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_compilation_event
from torch._utils_internal import log_chromium_event_internal, log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._triton import has_triton, has_triton_package
@@ -212,6 +213,16 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
frame_phase_timing[key][phase_name] += time_spent


def get_cache_stats() -> Dict[str, Any]:
"""Get a bunch of metadata about cache hits and misses to use in chromium events"""
cache_stats = {
"fxgraph_cache_hit": counters["inductor"]["fxgraph_cache_hit"],
"fxgraph_cache_miss": counters["inductor"]["fxgraph_cache_miss"],
"fxgraph_cache_bypass": counters["inductor"]["fxgraph_cache_bypass"],
}
return cache_stats


# dynamo_timed is a context manager
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
# where the key is the function's name.
@@ -245,22 +256,34 @@ def dynamo_timed(
phase_name: Optional[str] = None,
fwd_only: bool = True,
):
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
if key not in compilation_time_metrics:
compilation_time_metrics[key] = []

fail_type: Optional[str] = None
fail_reason: Optional[str] = None
time_spent = float("-inf")
if phase_name == "entire_frame_compile":
chromium_log.reset()
try:
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
t0 = time.time()
ChromiumEventLogger.log_event_start(key, time.time_ns())
start = time.time_ns()
chromium_log.log_event_start(key, start, None)
if phase_name:
ChromiumEventLogger.log_event_start(phase_name, time.time_ns())
chromium_log.log_event_start(phase_name, start)
yield

if phase_name:
ChromiumEventLogger.log_event_end(phase_name, time.time_ns())
ChromiumEventLogger.log_event_end(key, time.time_ns())
chromium_log.log_event_end(
phase_name,
time.time_ns(),
{"cache_stats": get_cache_stats()},
start,
)
chromium_log.log_event_end(
key, time.time_ns(), {"cache_stats": get_cache_stats()}, start
)
time_spent = time.time() - t0
compilation_time_metrics[key].append(time_spent)
except Exception as e:
@@ -814,8 +837,17 @@ class ChromiumEventLogger:
a specification of the Chromium Event JSON format.
"""

@staticmethod
def __init__(self):
self.stack = ["__start__"]
# Generate a unique id for this logger, which we can use in scuba to filter down
# to a single python run.
self.id_ = str(uuid.uuid4())

# TODO: log to init/id tlparse after I add support for it
log.info("ChromiumEventLogger initialized with id %s", self.id_)

def log_event_start(
self,
event_name: str,
time_ns: int,
metadata: Optional[Dict[str, Any]] = None,
@@ -826,18 +858,24 @@ def log_event_start(
:param time_ns Timestamp in nanoseconds
:param metadata: Any extra metadata associated with this event
"""
ChromiumEventLogger._log_timed_event(
event = self._log_timed_event(
event_name,
time_ns,
"B",
metadata,
)
log_chromium_event_internal(event, self.stack, self.id_)
self.stack.append(event_name)

def reset(self) -> None:
self.stack = ["__start__"]

@staticmethod
def log_event_end(
self,
event_name: str,
time_ns: int,
metadata: Optional[Dict[str, Any]] = None,
start_time_ns: Optional[int] = None,
) -> None:
"""
Logs the end of a single event. This function should only be
@@ -846,28 +884,53 @@ def log_event_end(
:param time_ns: Timestamp in nanoseconds
:param metadata: Any extra metadata associated with this event
"""
ChromiumEventLogger._log_timed_event(
# These stack health checks should currently never fire,
# but they're written this way to future-proof against any
# weird event overlaps.
if event_name not in self.stack:
# Something went wrong, we never called start on this event,
# or it was skipped due to overlapping events below
log.warning("ChromiumEventLogger: Start event not in stack, ignoring")
return

event = self._log_timed_event(
event_name,
time_ns,
"E",
metadata,
)

@staticmethod
while event_name != self.stack[-1]:
# If the event isn't the most recent one to end, pop
# off the stack until it is.
# Since event_name in self.stack, this pop is always safe
log.warning(
"ChromiumEventLogger: Detected overlapping events, fixing stack"
)
self.stack.pop()

log_chromium_event_internal(event, self.stack, self.id_, start_time_ns)
# Finally pop the actual event off the stack
self.stack.pop()

def _log_timed_event(
self,
event_name: str,
time_ns: int,
phase: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
) -> Dict[str, Any]:
"""
Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
"""
event = {
"name": event_name,
"ts": time_ns / 1000, # Chromium events are in ms
"ts": time_ns / 1000, # Chromium events are in micro seconds
"args": metadata,
"ph": phase,
# These categories are needed in all chromium traces
"cat": "dynamo_timed",
"tid": 0,
"pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id
}
torch._logging.trace_structured(
@@ -876,9 +939,10 @@ def _log_timed_event(
suppress_context=False,
expect_trace_id=False, # Not every chromium event will have a trace_id
)
return event

@staticmethod
def log_instant_event(
self,
event_name: str,
time_ns: int,
metadata: Optional[Dict[str, Any]] = None,
@@ -895,7 +959,10 @@ def log_instant_event(
"ts": time_ns / 1000,
"args": metadata,
"ph": "i",
"pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id
# These categories are needed in all chromium traces
"cat": "dynamo_timed",
"tid": 0,
"pid": 0,
"s": "p", # We use "process" level instant events so they all appear on the same row in the trace.
}
torch._logging.trace_structured(
@@ -904,6 +971,18 @@ def log_instant_event(
suppress_context=False,
expect_trace_id=True,
)
# Log an instant event with the same start and end time
log_chromium_event_internal(event, self.stack, self.id_)


chromium_event_log = None

Check failure on line 978 in torch/_dynamo/utils.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [var-annotated]: Need type annotation for "chromium_event_log" (hint: "chromium_event_log: <type> | None = ...")


def get_chromium_event_logger() -> ChromiumEventLogger:
global chromium_event_log
if chromium_event_log is None:
chromium_event_log = ChromiumEventLogger()
return chromium_event_log


@dataclasses.dataclass
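Putting the utils.py changes together, here is a minimal usage sketch of the new logger (assuming this diff is applied; the event names are only examples). Calls to log_event_start push onto the logger's stack and calls to log_event_end pop; mismatched or overlapping ends are warned about and the stack is repaired:

```python
import time

from torch._dynamo.utils import get_chromium_event_logger

chromium_log = get_chromium_event_logger()  # lazily created, process-wide singleton

start = time.time_ns()
chromium_log.log_event_start("entire_frame_compile", start, None)
chromium_log.log_event_start("backend_compile", time.time_ns(), None)
# ... compilation work happens here ...
chromium_log.log_event_end("backend_compile", time.time_ns(), None, start)
chromium_log.log_event_end("entire_frame_compile", time.time_ns(), None, start)
```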
5 changes: 3 additions & 2 deletions torch/_functorch/_aot_autograd/autograd_cache.py
@@ -14,7 +14,7 @@
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch._dynamo.utils import ChromiumEventLogger, counters
from torch._dynamo.utils import counters, get_chromium_event_logger
from torch._functorch import config
from torch._inductor.codecache import (
_ident,
@@ -502,7 +502,8 @@ def load(
"cache_state": cache_state,
"components": debug_lines,
}
ChromiumEventLogger.log_instant_event(
chromium_log = get_chromium_event_logger()
chromium_log.log_instant_event(
f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_args
)
torch._logging.trace_structured(
5 changes: 3 additions & 2 deletions torch/_inductor/codecache.py
@@ -53,7 +53,7 @@
import torch
import torch.distributed as dist
from torch import SymInt, Tensor
from torch._dynamo.utils import ChromiumEventLogger, counters, dynamo_timed
from torch._dynamo.utils import counters, dynamo_timed, get_chromium_event_logger
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.codegen.rocm.compile_command import (
@@ -1356,7 +1356,8 @@ def load( # type: ignore[no-untyped-def]
)
assert compiled_graph is not None
cache_info["cache_state"] = cache_state
ChromiumEventLogger.log_instant_event(
chromium_log = get_chromium_event_logger()
chromium_log.log_instant_event(
f"fx_graph_cache_{cache_state}", cache_event_time, metadata=cache_info
)
torch._logging.trace_structured(
4 changes: 4 additions & 0 deletions torch/_utils_internal.py
@@ -237,3 +237,7 @@ def max_clock_rate():
def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]:
print("Uploading profile stats (fb-only otherwise no-op)")
return None


def log_chromium_event_internal(event, stack, logger_uuid, start_timestamp=None):
return None
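
In OSS, log_chromium_event_internal is a no-op hook. Purely to illustrate the call contract (this is not the actual internal implementation, which logs to scuba), a local sink could look like:

```python
import json


def log_chromium_event_internal(event, stack, logger_uuid, start_timestamp=None):
    # Hypothetical sink: append one JSON line per event to a local file.
    record = {
        "event": event,  # the chromium-format event dict
        "stack": list(stack),  # names of the enclosing events, e.g. ["__start__", ...]
        "compile_uuid": logger_uuid,  # per-run id from ChromiumEventLogger
        "start_timestamp": start_timestamp,  # only passed for end events
    }
    with open("/tmp/pt2_chromium_events.jsonl", "a") as f:
        f.write(json.dumps(record) + "\n")
```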
