Skip to content

Commit

Permalink
Add tensor tracing in eager mode (#251)
Browse files Browse the repository at this point in the history
We don't trace in eager mode in `ops.iree.trace_tensor`.

This adds the ability to set a tensor trace callback with a default
implementation that records the tensor to a npy file.

Signed-off-by: Boian Petkantchin <[email protected]>
  • Loading branch information
sogartar authored Nov 5, 2024
1 parent ee62366 commit 8d02bc0
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
7 changes: 7 additions & 0 deletions iree/turbine/ops/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

"""Custom ops for built-in IREE functionality."""
from typing import cast
import numpy as np
import os

from ..support.ir_imports import (
Attribute,
Expand All @@ -24,6 +26,8 @@
def_library,
)

from ..support import debugging

__all__ = [
"trace",
]
Expand All @@ -46,6 +50,9 @@ def select(self, ksel: KernelSelection):
ksel.attr_str(0)
ksel.arg_tensor(1, inplace_tied=True)

def eager_execute(self, key, tensor):
debugging.trace_tensor_callback(key, tensor)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
key = cast(AttrArg, ksel.arg_descs[0])
_emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]])
Expand Down
20 changes: 19 additions & 1 deletion iree/turbine/support/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@

"""Debug flags and settings."""

from typing import Optional
from typing import Callable, Optional
from dataclasses import dataclass
import logging
import re
import os
import torch
import numpy as np

__all__ = [
"default_trace_tensor_callback",
"flags",
"NDEBUG",
"trace_tensor_callback",
"trace_tensor_to_npy",
"TraceTensorCallback",
]

# We use the native logging vs our .logging setup because our logging depends
Expand Down Expand Up @@ -99,3 +105,15 @@ def parse_from_env() -> "DebugFlags":


flags = DebugFlags.parse_from_env()

TraceKey = str
TraceTensorCallback = Callable[[TraceKey, torch.Tensor], None]


def trace_tensor_to_npy(key: TraceKey, tensor: torch.Tensor):
if flags.runtime_trace_dir is not None:
np.save(os.path.join(flags.runtime_trace_dir, f"{key}.npy"), tensor)


default_trace_tensor_callback = trace_tensor_to_npy
trace_tensor_callback = default_trace_tensor_callback
16 changes: 14 additions & 2 deletions tests/ops/iree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,32 @@

import logging
import unittest

import tempfile
import torch
import torch.nn as nn
import numpy as np
import os

import iree.turbine.aot as aot
import iree.turbine.ops as ops
import iree.turbine.support.debugging as debugging


# See runtime/op_reg/kernel_aot_test.py for additional tests of the trace
# op.
class TraceTensorTest(unittest.TestCase):
def testEager(self):
t = torch.randn(3, 4)
ops.iree.trace_tensor("TEST", t)
with tempfile.TemporaryDirectory() as tmp_dir:
stashed_runtime_trace_dir = debugging.flags.runtime_trace_dir
debugging.flags.runtime_trace_dir = tmp_dir

ops.iree.trace_tensor("TEST", t)
recorded_tensor = np.load(os.path.join(tmp_dir, "TEST.npy"))
np.testing.assert_equal(recorded_tensor, t)

# recover the original so we don't influence other tests.
debugging.flags.runtime_trace_dir = stashed_runtime_trace_dir

def testAOT(self):
class MyModule(nn.Module):
Expand Down

0 comments on commit 8d02bc0

Please sign in to comment.