Fix logging issue for unsupported torch.compile devices #3077

Open · wants to merge 1 commit into base: main
132 changes: 59 additions & 73 deletions recipes_source/torch_logs.py
@@ -1,96 +1,82 @@
"""
(beta) Using TORCH_LOGS python API with torch.compile
==========================================================================================
**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

import torch
import logging

######################################################################
#
# This tutorial introduces the ``TORCH_LOGS`` environment variable, as well as the Python API, and
# demonstrates how to apply it to observe the phases of ``torch.compile``.
#
# .. note::
#
#   This tutorial requires PyTorch 2.2.0 or later.
#
#
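
######################################################################
# As a quick preview (an illustrative sketch only, not executed as part of
# this recipe), the Python API can switch on several of these logs in a
# single call; each keyword below corresponds to one of the phases observed
# later in this recipe:
#
# .. code-block:: python
#
#    import logging
#    import torch
#
#    # dynamo tracing, traced graph, fusion decisions, and inductor output code
#    torch._logging.set_logs(
#        dynamo=logging.DEBUG, graph=True, fusion=True, output_code=True
#    )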


######################################################################
# Setup
# ~~~~~~~~~~~~~~~~~~~~~
# In this example, we'll set up a simple Python function which performs an elementwise
# add, configure enhanced logging for devices that don't support ``torch.compile``,
# and observe the compilation process with the ``TORCH_LOGS`` Python API.
#
# .. note::
#
#   There is also an environment variable ``TORCH_LOGS``, which can be used to
#   change logging settings at the command line. The equivalent environment
#   variable setting is shown for each example.
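
######################################################################
# For instance (an illustrative sketch, not executed by this recipe), the
# setting shown as ``TORCH_LOGS="+dynamo"`` alongside the first example
# below is the command-line counterpart of the Python call used there;
# multiple settings can be combined in one variable, for example
# ``TORCH_LOGS="+dynamo,graph"``:
#
# .. code-block:: python
#
#    import logging
#    import torch
#
#    # same effect as launching the script with TORCH_LOGS="+dynamo"
#    torch._logging.set_logs(dynamo=logging.DEBUG)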

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Check device capabilities and handle cases where torch.compile is not supported
def check_device_and_log():
    """Check device capability and log whether torch.compile is supported."""
    if torch.cuda.is_available():
        # Get CUDA device capability
        capability = torch.cuda.get_device_capability()
        logger.info(f"CUDA Device Capability: {capability}")

        if capability < (7, 0):
            # Log the reason why torch.compile is not supported
            logger.warning(
                "torch.compile is not supported on devices with a CUDA capability less than 7.0."
            )
            return False
        logger.info("Device supports torch.compile.")
        return True
    else:
        logger.info("No CUDA device found. Using CPU.")
        return False


# Function to apply torch.compile only if supported
def fn(x, y):
"""Simple function to add two tensors and return the result."""
z = x + y
return z + 2


def run_example_with_logging():
    """Run example using torch.compile if supported, with logging."""
    # Check if the device supports torch.compile
    compile_supported = check_device_and_log()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = (
        torch.ones(2, 2, device=device),
        torch.zeros(2, 2, device=device),
    )

    if compile_supported:
        logger.info("Compiling the function with torch.compile...")
        compiled_fn = torch.compile(fn)
    else:
        logger.info("Running the function without compilation...")
        compiled_fn = fn  # Use the uncompiled function

    # Print separator and reset dynamo between each example
    def separator(name):
        print(f"==================={name}=========================")
        torch._dynamo.reset()

    separator("Dynamo Tracing")
    # View dynamo tracing
    # TORCH_LOGS="+dynamo"
    torch._logging.set_logs(dynamo=logging.DEBUG)
    compiled_fn(*inputs)

    separator("Traced Graph")
    # View traced graph
    # TORCH_LOGS="graph"
    torch._logging.set_logs(graph=True)
    compiled_fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    # TORCH_LOGS="fusion"
    torch._logging.set_logs(fusion=True)
    compiled_fn(*inputs)

    separator("Output Code")
    # View output code generated by inductor
    # TORCH_LOGS="output_code"
    torch._logging.set_logs(output_code=True)
    compiled_fn(*inputs)

    separator("")

######################################################################
# Conclusion
# ~~~~~~~~~~
#
# In this tutorial we introduced the ``TORCH_LOGS`` environment variable and Python API
# by experimenting with a small number of the available logging options.
# To view descriptions of all available options, run any Python script
# which imports torch and set ``TORCH_LOGS`` to ``"help"``.
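#
# For example (one possible way to do this, assuming ``TORCH_LOGS`` is read
# when ``torch`` is imported):
#
# .. code-block:: python
#
#    import os
#
#    os.environ["TORCH_LOGS"] = "help"  # request the logging help text
#    import torch  # descriptions of the available options are shown when logging initializes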
#
# Alternatively, you can view the `torch._logging documentation`_ to see
# descriptions of all available logging options.
#
# For more information on torch.compile, see the `torch.compile tutorial`_.
#
# .. _torch._logging documentation: https://pytorch.org/docs/main/logging.html
# .. _torch.compile tutorial: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html

if __name__ == "__main__":
    run_example_with_logging()