From 803ce507f1d150ad222989c8cb3a3e717ad60e69 Mon Sep 17 00:00:00 2001
From: James Wu
Date: Thu, 19 Sep 2024 16:11:38 +0000
Subject: [PATCH] Log structured logging overhead to dynamo compile (kinda)
 (#136142)

Summary:
X-link: https://github.com/pytorch/benchmark/pull/2454

This adds structured logging overhead, on a per-compile basis, to
compilation metrics. To do so, we track the frame_id_frame_compile_id
that trace_structured uses to categorize compiles, and use it as the
key in our timing table.

Implementation notes:
- If trace_structured is ever called without a compile id, that time is
  not measured. There is no good way around that today, given that
  compilation metrics are keyed by compile id; Strobelight remains the
  best way to measure overhead on a per-job basis.
- We do not measure the time it takes to log the compilation metrics
  event itself. Fundamentally, that cannot be measured correctly while
  the number is stored *in* compilation metrics, since there is no way
  to measure the logging before performing it (unless we accept a
  discrepancy between dynamo_compile and tlparse, which seems
  suboptimal). For a large job, the cost of logging compilation metrics
  itself should hopefully be small.
- I wanted to use frame_phase_timing here, but there are a bunch of ids
  to iron out, and I don't really want to deal with that headache.
  compilation_time_metrics is close to what I want, but it isn't keyed
  by frame/compile id, so it is also a bit off. Keeping this in
  torch._logging as a separate table, so that logging tracks its own
  overhead, seems fine.

Test Plan:
Run benchmarks/nanogpt with the staging logger and check that the new
compilation metric is logged to the staged dynamo_compile table:
https://fburl.com/scuba/logger_staging_jjwu_30582a48f1ff9cf5f4ac50a4c40af/xazjg5xq

Note that sum(structured_logging_overhead_s) /
sum(entire_frame_compile_time) = 8.387 / 124.278 ≈ 6.7%, which seems
like reasonable overhead for a compilation this small. You can also
look at individual samples for a more detailed log.
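For readers who want the mechanics at a glance: the bookkeeping reduces
to a per-compile-id timing table. Below is a minimal standalone sketch
of that idea (the names overhead_by_compile and timed_emit are
illustrative only, not the real API; the actual implementation lives in
torch/_logging/_internal.py in the diff):

    import time
    from collections import defaultdict
    from typing import Callable, Dict, Optional

    # Stand-in for the real table keyed by f"{frame_id}_{frame_compile_id}".
    overhead_by_compile: Dict[str, float] = defaultdict(float)

    def timed_emit(key: Optional[str], emit: Callable[[], None]) -> None:
        # Time the actual logging work and attribute it to the compile id.
        start = time.time_ns()
        emit()
        if key is not None:  # time outside any compile id is not attributed
            overhead_by_compile[key] += (time.time_ns() - start) / 1e9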
Reviewed By: oulgen

Differential Revision: D62643611

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136142
Approved by: https://github.com/bobrenjc93
---
 torch/_dynamo/convert_frame.py |  5 ++++
 torch/_dynamo/utils.py         | 11 +++++++++
 torch/_logging/__init__.py     |  1 +
 torch/_logging/_internal.py    | 45 ++++++++++++++++++++++++++++++++++
 4 files changed, 62 insertions(+)

diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 171af02e564b3..1b71c42b9ac5a 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -1028,6 +1028,10 @@ def format_guard_failures() -> str:
                 possibly_missed_reinplacing_opportunities = None
                 remote_cache_time_saved = None
 
+            structured_logging_overhead_s = (
+                torch._logging.get_structured_logging_overhead()
+            )
+
             metrics = CompilationMetrics(
                 str(compile_id),
                 frame_key,
@@ -1057,6 +1061,7 @@ def format_guard_failures() -> str:
                 guarded_code is not None,
                 possibly_missed_reinplacing_opportunities,
                 remote_cache_time_saved,
+                structured_logging_overhead_s,
             )
             record_compilation_metrics(metrics)
             torch._dynamo.callback_handler.run_end_callbacks()
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index e086f11e174cd..7d34671b11a3c 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -353,6 +353,9 @@ def dynamo_timed(
             inductor_compile_time = None
             code_gen_time = None
             remote_cache_time_saved = None
+            structured_logging_overhead_s = (
+                torch._logging.get_structured_logging_overhead()
+            )
             metrics = BwdCompilationMetrics(
                 compile_id,
                 inductor_compile_time,
@@ -360,6 +363,7 @@ def dynamo_timed(
                 fail_type,
                 fail_reason,
                 remote_cache_time_saved,
+                structured_logging_overhead_s,
             )
             record_compilation_metrics(metrics)
 
@@ -799,6 +803,7 @@ class CompilationMetrics:
     has_guarded_code: bool
     possibly_missed_reinplacing_opportunities: Optional[int]
     remote_cache_time_saved_s: Optional[float]
+    structured_logging_overhead_s: Optional[float]
 
 
 @dataclasses.dataclass
@@ -809,6 +814,7 @@ class BwdCompilationMetrics:
     fail_type: Optional[str]
     fail_reason: Optional[str]
     remote_cache_time_saved_s: Optional[float]
+    structured_logging_overhead_s: Optional[float]
 
 
 DEFAULT_COMPILATION_METRICS_LIMIT = 64
@@ -834,6 +840,11 @@ def record_compilation_metrics(
             k: list(v) if isinstance(v, set) else v
             for k, v in dataclasses.asdict(compilation_metrics).items()
         },
+        # NB: Because compilation metrics *includes* the logging overhead time,
+        # we can't measure the logging overhead of compilation metrics itself
+        # without making it inconsistent with compilation metrics, so we
+        # ignore the (hopefully small) time spent logging compilation metrics
+        record_logging_overhead=False,
     )
     if config.log_compilation_metrics:
         log_compilation_event(compilation_metrics)
diff --git a/torch/_logging/__init__.py b/torch/_logging/__init__.py
index 0531869ae2fdf..5acf175c27522 100644
--- a/torch/_logging/__init__.py
+++ b/torch/_logging/__init__.py
@@ -9,6 +9,7 @@
 from ._internal import (
     _init_logs,
     DEFAULT_LOGGING,
+    get_structured_logging_overhead,
     getArtifactLogger,
     LazyString,
     set_logs,
diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py
index f11af08e3b768..f78396545c1f7 100644
--- a/torch/_logging/_internal.py
+++ b/torch/_logging/_internal.py
@@ -10,6 +10,8 @@
 import re
 import sys
 import tempfile
+import time
+from collections import defaultdict
 from dataclasses import dataclass, field
 from importlib import __import__
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -1091,6 +1093,42 @@ def __str__(self):
         return self.func(*self.args, **self.kwargs)
 
 
+# Logs the time it takes to do structured logging by frame/compile id
+# key is always {frame_id}_{frame_compile_id}
+structured_logging_overhead: Dict[str, float] = defaultdict(float)
+
+
+# Same principle as add_remote_cache_time_saved, but do it for structured logging
+def add_structured_logging_overhead(time_spent: float) -> None:
+    global structured_logging_overhead
+    key = None
+    if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None:
+        frame_id = trace_id.compile_id.frame_id
+        frame_compile_id = trace_id.compile_id.frame_compile_id
+        # Why not trace_id.attempt, like structured logging?
+        # We aggregate across all attempts because
+        # a compilation metric is logged per successful attempt
+        key = f"{frame_id}_{frame_compile_id}"
+    # TODO: deal with structured logging that occurs outside of specific compile ids
+    # It's hard to figure out where we would log that if we want it in compilation metrics
+    # itself.
+    if key is not None:
+        key = str(key)
+        structured_logging_overhead[key] += time_spent
+
+
+def get_structured_logging_overhead() -> Optional[float]:
+    key = None
+    if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None:
+        frame_id = trace_id.compile_id.frame_id
+        frame_compile_id = trace_id.compile_id.frame_compile_id
+        key = f"{frame_id}_{frame_compile_id}"
+    if key is not None:
+        return structured_logging_overhead.get(key)
+    else:
+        return None
+
+
 def trace_structured(
     name: str,
     # NB: metadata expected to be dict so adding more info is forward compatible
@@ -1100,6 +1138,7 @@
     payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
     suppress_context: bool = False,
     expect_trace_id: bool = True,  # Whether or not we expect to have a current trace id
+    record_logging_overhead: bool = True,  # Whether or not to record the time spent on structured logging
 ):
     """
     metadata is an arbitrary JSON compatible struct, but it's expected to not be
@@ -1118,6 +1157,7 @@
     # trace_log never propagates and is ALWAYS DEBUG, so also check that there
     # are handlers instead of checking the log level
     if trace_log.handlers:
+        start_time = time.time_ns()
         record: Dict[str, object] = {}
         record[name] = metadata_fn()
         if not suppress_context:
@@ -1156,6 +1196,11 @@
             )
 
         log_trace_structured_event(name, record)
+        if record_logging_overhead:
+            # Convert to seconds from nanoseconds, add it to the frame compile total
+            structured_logging_overhead_s = (time.time_ns() - start_time) / 1e9
+            add_structured_logging_overhead(structured_logging_overhead_s)
+
 
 import torch._guards
 import torch._utils_internal
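As a rough local check (a sketch, not part of this patch): run a toy
compile with trace logging enabled, e.g. TORCH_TRACE=/tmp/trace so that
trace_log has handlers installed, then read the recorded metrics back.
get_compilation_metrics() here is the existing accessor in
torch/_dynamo/utils.py; the rest is illustrative:

    import torch
    from torch._dynamo.utils import get_compilation_metrics

    @torch.compile
    def f(x):
        return x.sin() + x.cos()

    f(torch.randn(8))

    for m in get_compilation_metrics():
        # structured_logging_overhead_s is None when no structured logging
        # was attributed to that compile id (e.g. no trace handlers installed).
        print(type(m).__name__, m.structured_logging_overhead_s)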