Skip to content

Commit

Permalink
refactor(di): simplify context capturing API
Browse files Browse the repository at this point in the history
We clean-up the internal context capturing API after the recent changes
that have introduced support for local variables in function probes and
the exposure of globals to expression/condition evaluations.
  • Loading branch information
P403n1x87 committed Aug 23, 2024
1 parent d0244e3 commit c5b32a9
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 105 deletions.
6 changes: 0 additions & 6 deletions ddtrace/debugging/_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from ddtrace.debugging._probe.remoteconfig import ProbePollerEventType
from ddtrace.debugging._probe.remoteconfig import ProbeRCAdapter
from ddtrace.debugging._probe.status import ProbeStatusLogger
from ddtrace.debugging._safety import get_args
from ddtrace.debugging._signal.collector import SignalCollector
from ddtrace.debugging._signal.collector import SignalContext
from ddtrace.debugging._signal.metric_sample import MetricSample
Expand Down Expand Up @@ -175,7 +174,6 @@ def _open_contexts(self) -> None:
frame = self.__frame__
assert frame is not None # nosec

args = list(get_args(frame))
thread = threading.current_thread()

signal: Optional[Signal] = None
Expand All @@ -200,7 +198,6 @@ def _open_contexts(self) -> None:
probe=probe,
frame=frame,
thread=thread,
args=args,
trace_context=trace_context,
meter=self._probe_meter,
)
Expand All @@ -209,23 +206,20 @@ def _open_contexts(self) -> None:
probe=probe,
frame=frame,
thread=thread,
args=args,
trace_context=trace_context,
)
elif isinstance(probe, SpanFunctionProbe):
signal = DynamicSpan(
probe=probe,
frame=frame,
thread=thread,
args=args,
trace_context=trace_context,
)
elif isinstance(probe, SpanDecorationFunctionProbe):
signal = SpanDecoration(
probe=probe,
frame=frame,
thread=thread,
args=args,
)
else:
log.error("Unsupported probe type: %s", type(probe))
Expand Down
13 changes: 7 additions & 6 deletions ddtrace/debugging/_probe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Union
Expand Down Expand Up @@ -198,36 +199,36 @@ class MetricFunctionProbe(Probe, FunctionLocationMixin, MetricProbeMixin, ProbeC
@dataclass
class TemplateSegment(abc.ABC):
@abc.abstractmethod
def eval(self, _locals: Dict[str, Any]) -> str:
def eval(self, scope: Mapping[str, Any]) -> str:
pass


@dataclass
class LiteralTemplateSegment(TemplateSegment):
str_value: str

def eval(self, _locals: Dict[str, Any]) -> Any:
def eval(self, _scope: Mapping[str, Any]) -> Any:
return self.str_value


@dataclass
class ExpressionTemplateSegment(TemplateSegment):
expr: DDExpression

def eval(self, _locals: Dict[str, Any]) -> Any:
return self.expr.eval(_locals)
def eval(self, scope: Mapping[str, Any]) -> Any:
return self.expr.eval(scope)


@dataclass
class StringTemplate:
template: str
segments: List[TemplateSegment]

def render(self, _locals: Dict[str, Any], serializer: Callable[[Any], str]) -> str:
def render(self, scope: Mapping[str, Any], serializer: Callable[[Any], str]) -> str:
def _to_str(value):
return value if _isinstance(value, str) else serializer(value)

return "".join([_to_str(s.eval(_locals)) for s in self.segments])
return "".join([_to_str(s.eval(scope)) for s in self.segments])


@dataclass
Expand Down
8 changes: 4 additions & 4 deletions ddtrace/debugging/_signal/metric_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def exit(self, retval, exc_info, duration) -> None:
return

probe = self.probe
_locals = self._enrich_locals(retval, exc_info, duration)
full_scope = self.get_full_scope(retval, exc_info, duration)

if probe.evaluate_at != ProbeEvaluateTimingForMethod.EXIT:
if probe.evaluate_at is not ProbeEvaluateTimingForMethod.EXIT:
return
if not self._eval_condition(_locals):
if not self._eval_condition(full_scope):
return

self.sample(_locals)
self.sample(full_scope)
self.state = SignalState.DONE

def line(self) -> None:
Expand Down
26 changes: 15 additions & 11 deletions ddtrace/debugging/_signal/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Union
from typing import cast
from uuid import uuid4
Expand All @@ -22,6 +22,8 @@
from ddtrace.debugging._probe.model import LineLocationMixin
from ddtrace.debugging._probe.model import Probe
from ddtrace.debugging._probe.model import ProbeConditionMixin
from ddtrace.debugging._safety import get_args
from ddtrace.internal.compat import ExcInfoType
from ddtrace.internal.rate_limiter import RateLimitExceeded


Expand Down Expand Up @@ -52,21 +54,20 @@ class Signal(abc.ABC):
frame: FrameType
thread: Thread
trace_context: Optional[Union[Span, Context]] = None
args: Optional[List[Tuple[str, Any]]] = None
state: str = SignalState.NONE
errors: List[EvaluationError] = field(default_factory=list)
timestamp: float = field(default_factory=time.time)
uuid: str = field(default_factory=lambda: str(uuid4()), init=False)

def _eval_condition(self, _locals: Optional[Dict[str, Any]] = None) -> bool:
def _eval_condition(self, scope: Optional[Mapping[str, Any]] = None) -> bool:
"""Evaluate the probe condition against the collected frame."""
probe = cast(ProbeConditionMixin, self.probe)
condition = probe.condition
if condition is None:
return True

try:
if bool(condition.eval(_locals or self.frame.f_locals)):
if bool(condition.eval(scope)):
return True
except DDExpressionEvaluationError as e:
self.errors.append(EvaluationError(expr=e.dsl, message=e.error))
Expand All @@ -80,19 +81,22 @@ def _eval_condition(self, _locals: Optional[Dict[str, Any]] = None) -> bool:

return False

def _enrich_locals(self, retval, exc_info, duration):
def get_full_scope(self, retval: Any, exc_info: ExcInfoType, duration: float) -> ChainMap[str, Any]:
frame = self.frame
_locals = dict(frame.f_locals)
_locals["@duration"] = duration / 1e6 # milliseconds
extra: Dict[str, Any] = {"@duration": duration / 1e6} # milliseconds

exc = exc_info[1]
if exc is not None:
_locals["@exception"] = exc
extra["@exception"] = exc
else:
_locals["@return"] = retval
extra["@return"] = retval

# Include the frame globals.
return ChainMap(_locals, frame.f_globals)
# Include the frame locals and globals.
return ChainMap(extra, frame.f_locals, frame.f_globals)

@property
def args(self):
return dict(get_args(self.frame))

@abc.abstractmethod
def enter(self):
Expand Down
83 changes: 35 additions & 48 deletions ddtrace/debugging/_signal/snapshot.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from collections import ChainMap
from dataclasses import dataclass
from dataclasses import field
from itertools import chain
import sys
from types import FrameType
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import cast

from ddtrace.debugging import _safety
from ddtrace.debugging._expressions import DDExpressionEvaluationError
from ddtrace.debugging._probe.model import DEFAULT_CAPTURE_LIMITS
from ddtrace.debugging._probe.model import CaptureLimits
Expand All @@ -22,6 +22,9 @@
from ddtrace.debugging._probe.model import TemplateSegment
from ddtrace.debugging._redaction import REDACTED_PLACEHOLDER
from ddtrace.debugging._redaction import DDRedactedExpressionError
from ddtrace.debugging._safety import get_args
from ddtrace.debugging._safety import get_globals
from ddtrace.debugging._safety import get_locals
from ddtrace.debugging._signal import utils
from ddtrace.debugging._signal.model import EvaluationError
from ddtrace.debugging._signal.model import LogSignal
Expand All @@ -35,18 +38,30 @@
CAPTURE_TIME_BUDGET = 0.2 # seconds


_NOTSET = object()


def _capture_context(
arguments: List[Tuple[str, Any]],
_locals: List[Tuple[str, Any]],
_globals: List[Tuple[str, Any]],
frame: FrameType,
throwable: ExcInfoType,
retval: Any = _NOTSET,
limits: CaptureLimits = DEFAULT_CAPTURE_LIMITS,
) -> Dict[str, Any]:
with HourGlass(duration=CAPTURE_TIME_BUDGET) as hg:

def timeout(_):
return not hg.trickling()

arguments = get_args(frame)
_locals = get_locals(frame)
_globals = get_globals(frame)

_, exc, _ = throwable
if exc is not None:
_locals = chain(_locals, [("@exception", exc)])
elif retval is not _NOTSET:
_locals = chain(_locals, [("@return", retval)])

return {
"arguments": utils.capture_pairs(
arguments, limits.max_level, limits.max_len, limits.max_size, limits.max_fields, timeout
Expand All @@ -67,13 +82,7 @@ def timeout(_):
}


_EMPTY_CAPTURED_CONTEXT = _capture_context(
arguments=[],
_locals=[],
_globals=[],
throwable=(None, None, None),
limits=DEFAULT_CAPTURE_LIMITS,
)
_EMPTY_CAPTURED_CONTEXT: Dict[str, Any] = {"arguments": {}, "locals": {}, "staticFields": {}, "throwable": None}


@dataclass
Expand Down Expand Up @@ -117,71 +126,55 @@ def enter(self):

probe = self.probe
frame = self.frame
_args = list(self.args or _safety.get_args(frame))

if probe.evaluate_at == ProbeEvaluateTimingForMethod.EXIT:
return

if not self._eval_condition(dict(_args)):
_args = self.args

if not self._eval_condition(_args):
return

if probe.limiter.limit() is RateLimitExceeded:
self.state = SignalState.SKIP_RATE
return

if probe.take_snapshot:
self.entry_capture = _capture_context(
_args,
[],
[],
(None, None, None),
limits=probe.limits,
)
self.entry_capture = _capture_context(frame, (None, None, None), limits=probe.limits)

if probe.evaluate_at == ProbeEvaluateTimingForMethod.ENTER:
self._eval_message(dict(_args))
self._eval_message(_args)
self.state = SignalState.DONE

def exit(self, retval, exc_info, duration):
if not isinstance(self.probe, LogFunctionProbe):
return

probe = self.probe
_locals = self._enrich_locals(retval, exc_info, duration)
full_scope = self.get_full_scope(retval, exc_info, duration)

if probe.evaluate_at == ProbeEvaluateTimingForMethod.EXIT:
if not self._eval_condition(_locals):
if not self._eval_condition(full_scope):
return
if probe.limiter.limit() is RateLimitExceeded:
self.state = SignalState.SKIP_RATE
return
elif self.state not in {SignalState.NONE, SignalState.DONE}:
return

_pure_locals = list(_safety.get_locals(self.frame))
_, exc, tb = exc_info
if exc is None:
_pure_locals.append(("@return", retval))
else:
_pure_locals.append(("@exception", exc))

if probe.take_snapshot:
self.return_capture = _capture_context(
self.args or _safety.get_args(self.frame),
_pure_locals,
_safety.get_globals(self.frame),
exc_info,
limits=probe.limits,
)
self.return_capture = _capture_context(self.frame, exc_info, retval=retval, limits=probe.limits)

self.duration = duration
self.state = SignalState.DONE
if probe.evaluate_at != ProbeEvaluateTimingForMethod.ENTER:
self._eval_message(dict(_locals))
self._eval_message(full_scope)

stack = utils.capture_stack(self.frame)

# Fix the line number of the top frame. This might have been mangled by
# the instrumented exception handling of function probes.
tb = exc_info[2]
while tb is not None:
frame = tb.tb_frame
if frame == self.frame:
Expand All @@ -206,15 +199,9 @@ def line(self):
self.state = SignalState.SKIP_RATE
return

self.line_capture = _capture_context(
self.args or _safety.get_args(frame),
_safety.get_locals(frame),
_safety.get_globals(frame),
sys.exc_info(),
limits=probe.limits,
)
self.line_capture = _capture_context(frame, sys.exc_info(), limits=probe.limits)

self._eval_message(frame.f_locals)
self._eval_message(ChainMap(frame.f_locals, frame.f_globals))

self._stack = utils.capture_stack(frame)

Expand Down
Loading

0 comments on commit c5b32a9

Please sign in to comment.