diff --git a/ddtrace/debugging/_debugger.py b/ddtrace/debugging/_debugger.py index 6c991dc8711..a93d2077cae 100644 --- a/ddtrace/debugging/_debugger.py +++ b/ddtrace/debugging/_debugger.py @@ -45,7 +45,6 @@ from ddtrace.debugging._probe.remoteconfig import ProbePollerEventType from ddtrace.debugging._probe.remoteconfig import ProbeRCAdapter from ddtrace.debugging._probe.status import ProbeStatusLogger -from ddtrace.debugging._safety import get_args from ddtrace.debugging._signal.collector import SignalCollector from ddtrace.debugging._signal.collector import SignalContext from ddtrace.debugging._signal.metric_sample import MetricSample @@ -175,7 +174,6 @@ def _open_contexts(self) -> None: frame = self.__frame__ assert frame is not None # nosec - args = list(get_args(frame)) thread = threading.current_thread() signal: Optional[Signal] = None @@ -200,7 +198,6 @@ def _open_contexts(self) -> None: probe=probe, frame=frame, thread=thread, - args=args, trace_context=trace_context, meter=self._probe_meter, ) @@ -209,7 +206,6 @@ def _open_contexts(self) -> None: probe=probe, frame=frame, thread=thread, - args=args, trace_context=trace_context, ) elif isinstance(probe, SpanFunctionProbe): @@ -217,7 +213,6 @@ def _open_contexts(self) -> None: probe=probe, frame=frame, thread=thread, - args=args, trace_context=trace_context, ) elif isinstance(probe, SpanDecorationFunctionProbe): @@ -225,7 +220,6 @@ def _open_contexts(self) -> None: probe=probe, frame=frame, thread=thread, - args=args, ) else: log.error("Unsupported probe type: %s", type(probe)) diff --git a/ddtrace/debugging/_probe/model.py b/ddtrace/debugging/_probe/model.py index f6b98e1366f..d96b98eaf46 100644 --- a/ddtrace/debugging/_probe/model.py +++ b/ddtrace/debugging/_probe/model.py @@ -8,6 +8,7 @@ from typing import Callable from typing import Dict from typing import List +from typing import Mapping from typing import Optional from typing import Tuple from typing import Union @@ -198,7 +199,7 @@ class MetricFunctionProbe(Probe, FunctionLocationMixin, MetricProbeMixin, ProbeC @dataclass class TemplateSegment(abc.ABC): @abc.abstractmethod - def eval(self, _locals: Dict[str, Any]) -> str: + def eval(self, scope: Mapping[str, Any]) -> str: pass @@ -206,7 +207,7 @@ def eval(self, _locals: Dict[str, Any]) -> str: class LiteralTemplateSegment(TemplateSegment): str_value: str - def eval(self, _locals: Dict[str, Any]) -> Any: + def eval(self, _scope: Mapping[str, Any]) -> Any: return self.str_value @@ -214,8 +215,8 @@ def eval(self, _locals: Dict[str, Any]) -> Any: class ExpressionTemplateSegment(TemplateSegment): expr: DDExpression - def eval(self, _locals: Dict[str, Any]) -> Any: - return self.expr.eval(_locals) + def eval(self, scope: Mapping[str, Any]) -> Any: + return self.expr.eval(scope) @dataclass @@ -223,11 +224,11 @@ class StringTemplate: template: str segments: List[TemplateSegment] - def render(self, _locals: Dict[str, Any], serializer: Callable[[Any], str]) -> str: + def render(self, scope: Mapping[str, Any], serializer: Callable[[Any], str]) -> str: def _to_str(value): return value if _isinstance(value, str) else serializer(value) - return "".join([_to_str(s.eval(_locals)) for s in self.segments]) + return "".join([_to_str(s.eval(scope)) for s in self.segments]) @dataclass diff --git a/ddtrace/debugging/_signal/metric_sample.py b/ddtrace/debugging/_signal/metric_sample.py index 896735e818d..c14f0173734 100644 --- a/ddtrace/debugging/_signal/metric_sample.py +++ b/ddtrace/debugging/_signal/metric_sample.py @@ -40,14 +40,14 @@ def exit(self, retval, exc_info, duration) -> None: return probe = self.probe - _locals = self._enrich_locals(retval, exc_info, duration) + full_scope = self.get_full_scope(retval, exc_info, duration) - if probe.evaluate_at != ProbeEvaluateTimingForMethod.EXIT: + if probe.evaluate_at is not ProbeEvaluateTimingForMethod.EXIT: return - if not self._eval_condition(_locals): + if not self._eval_condition(full_scope): return - self.sample(_locals) + self.sample(full_scope) self.state = SignalState.DONE def line(self) -> None: diff --git a/ddtrace/debugging/_signal/model.py b/ddtrace/debugging/_signal/model.py index 137568ecc0e..64cb0bff555 100644 --- a/ddtrace/debugging/_signal/model.py +++ b/ddtrace/debugging/_signal/model.py @@ -9,8 +9,8 @@ from typing import Any from typing import Dict from typing import List +from typing import Mapping from typing import Optional -from typing import Tuple from typing import Union from typing import cast from uuid import uuid4 @@ -22,6 +22,8 @@ from ddtrace.debugging._probe.model import LineLocationMixin from ddtrace.debugging._probe.model import Probe from ddtrace.debugging._probe.model import ProbeConditionMixin +from ddtrace.debugging._safety import get_args +from ddtrace.internal.compat import ExcInfoType from ddtrace.internal.rate_limiter import RateLimitExceeded @@ -52,13 +54,12 @@ class Signal(abc.ABC): frame: FrameType thread: Thread trace_context: Optional[Union[Span, Context]] = None - args: Optional[List[Tuple[str, Any]]] = None state: str = SignalState.NONE errors: List[EvaluationError] = field(default_factory=list) timestamp: float = field(default_factory=time.time) uuid: str = field(default_factory=lambda: str(uuid4()), init=False) - def _eval_condition(self, _locals: Optional[Dict[str, Any]] = None) -> bool: + def _eval_condition(self, scope: Optional[Mapping[str, Any]] = None) -> bool: """Evaluate the probe condition against the collected frame.""" probe = cast(ProbeConditionMixin, self.probe) condition = probe.condition @@ -66,7 +67,7 @@ def _eval_condition(self, _locals: Optional[Dict[str, Any]] = None) -> bool: return True try: - if bool(condition.eval(_locals or self.frame.f_locals)): + if bool(condition.eval(scope)): return True except DDExpressionEvaluationError as e: self.errors.append(EvaluationError(expr=e.dsl, message=e.error)) @@ -80,19 +81,22 @@ def _eval_condition(self, _locals: Optional[Dict[str, Any]] = None) -> bool: return False - def _enrich_locals(self, retval, exc_info, duration): + def get_full_scope(self, retval: Any, exc_info: ExcInfoType, duration: float) -> Mapping[str, Any]: frame = self.frame - _locals = dict(frame.f_locals) - _locals["@duration"] = duration / 1e6 # milliseconds + extra: Dict[str, Any] = {"@duration": duration / 1e6} # milliseconds exc = exc_info[1] if exc is not None: - _locals["@exception"] = exc + extra["@exception"] = exc else: - _locals["@return"] = retval + extra["@return"] = retval - # Include the frame globals. - return ChainMap(_locals, frame.f_globals) + # Include the frame locals and globals. + return ChainMap(extra, frame.f_locals, frame.f_globals) + + @property + def args(self): + return dict(get_args(self.frame)) @abc.abstractmethod def enter(self): diff --git a/ddtrace/debugging/_signal/snapshot.py b/ddtrace/debugging/_signal/snapshot.py index 4902ea65721..292304413cc 100644 --- a/ddtrace/debugging/_signal/snapshot.py +++ b/ddtrace/debugging/_signal/snapshot.py @@ -1,14 +1,14 @@ +from collections import ChainMap from dataclasses import dataclass from dataclasses import field +from itertools import chain import sys +from types import FrameType from typing import Any from typing import Dict -from typing import List from typing import Optional -from typing import Tuple from typing import cast -from ddtrace.debugging import _safety from ddtrace.debugging._expressions import DDExpressionEvaluationError from ddtrace.debugging._probe.model import DEFAULT_CAPTURE_LIMITS from ddtrace.debugging._probe.model import CaptureLimits @@ -22,6 +22,9 @@ from ddtrace.debugging._probe.model import TemplateSegment from ddtrace.debugging._redaction import REDACTED_PLACEHOLDER from ddtrace.debugging._redaction import DDRedactedExpressionError +from ddtrace.debugging._safety import get_args +from ddtrace.debugging._safety import get_globals +from ddtrace.debugging._safety import get_locals from ddtrace.debugging._signal import utils from ddtrace.debugging._signal.model import EvaluationError from ddtrace.debugging._signal.model import LogSignal @@ -35,11 +38,13 @@ CAPTURE_TIME_BUDGET = 0.2 # seconds +_NOTSET = object() + + def _capture_context( - arguments: List[Tuple[str, Any]], - _locals: List[Tuple[str, Any]], - _globals: List[Tuple[str, Any]], + frame: FrameType, throwable: ExcInfoType, + retval: Any = _NOTSET, limits: CaptureLimits = DEFAULT_CAPTURE_LIMITS, ) -> Dict[str, Any]: with HourGlass(duration=CAPTURE_TIME_BUDGET) as hg: @@ -47,6 +52,16 @@ def _capture_context( def timeout(_): return not hg.trickling() + arguments = get_args(frame) + _locals = get_locals(frame) + _globals = get_globals(frame) + + _, exc, _ = throwable + if exc is not None: + _locals = chain(_locals, [("@exception", exc)]) + elif retval is not _NOTSET: + _locals = chain(_locals, [("@return", retval)]) + return { "arguments": utils.capture_pairs( arguments, limits.max_level, limits.max_len, limits.max_size, limits.max_fields, timeout @@ -67,13 +82,7 @@ def timeout(_): } -_EMPTY_CAPTURED_CONTEXT = _capture_context( - arguments=[], - _locals=[], - _globals=[], - throwable=(None, None, None), - limits=DEFAULT_CAPTURE_LIMITS, -) +_EMPTY_CAPTURED_CONTEXT: Dict[str, Any] = {"arguments": {}, "locals": {}, "staticFields": {}, "throwable": None} @dataclass @@ -117,12 +126,13 @@ def enter(self): probe = self.probe frame = self.frame - _args = list(self.args or _safety.get_args(frame)) if probe.evaluate_at == ProbeEvaluateTimingForMethod.EXIT: return - if not self._eval_condition(dict(_args)): + _args = self.args + + if not self._eval_condition(_args): return if probe.limiter.limit() is RateLimitExceeded: @@ -130,16 +140,10 @@ def enter(self): return if probe.take_snapshot: - self.entry_capture = _capture_context( - _args, - [], - [], - (None, None, None), - limits=probe.limits, - ) + self.entry_capture = _capture_context(frame, (None, None, None), limits=probe.limits) if probe.evaluate_at == ProbeEvaluateTimingForMethod.ENTER: - self._eval_message(dict(_args)) + self._eval_message(_args) self.state = SignalState.DONE def exit(self, retval, exc_info, duration): @@ -147,10 +151,10 @@ def exit(self, retval, exc_info, duration): return probe = self.probe - _locals = self._enrich_locals(retval, exc_info, duration) + full_scope = self.get_full_scope(retval, exc_info, duration) if probe.evaluate_at == ProbeEvaluateTimingForMethod.EXIT: - if not self._eval_condition(_locals): + if not self._eval_condition(full_scope): return if probe.limiter.limit() is RateLimitExceeded: self.state = SignalState.SKIP_RATE @@ -158,30 +162,19 @@ def exit(self, retval, exc_info, duration): elif self.state not in {SignalState.NONE, SignalState.DONE}: return - _pure_locals = list(_safety.get_locals(self.frame)) - _, exc, tb = exc_info - if exc is None: - _pure_locals.append(("@return", retval)) - else: - _pure_locals.append(("@exception", exc)) - if probe.take_snapshot: - self.return_capture = _capture_context( - self.args or _safety.get_args(self.frame), - _pure_locals, - _safety.get_globals(self.frame), - exc_info, - limits=probe.limits, - ) + self.return_capture = _capture_context(self.frame, exc_info, retval=retval, limits=probe.limits) + self.duration = duration self.state = SignalState.DONE if probe.evaluate_at != ProbeEvaluateTimingForMethod.ENTER: - self._eval_message(dict(_locals)) + self._eval_message(full_scope) stack = utils.capture_stack(self.frame) # Fix the line number of the top frame. This might have been mangled by # the instrumented exception handling of function probes. + tb = exc_info[2] while tb is not None: frame = tb.tb_frame if frame == self.frame: @@ -206,15 +199,9 @@ def line(self): self.state = SignalState.SKIP_RATE return - self.line_capture = _capture_context( - self.args or _safety.get_args(frame), - _safety.get_locals(frame), - _safety.get_globals(frame), - sys.exc_info(), - limits=probe.limits, - ) + self.line_capture = _capture_context(frame, sys.exc_info(), limits=probe.limits) - self._eval_message(frame.f_locals) + self._eval_message(ChainMap(frame.f_locals, frame.f_globals)) self._stack = utils.capture_stack(frame) diff --git a/ddtrace/debugging/_signal/tracing.py b/ddtrace/debugging/_signal/tracing.py index 94d73015287..226e0230d90 100644 --- a/ddtrace/debugging/_signal/tracing.py +++ b/ddtrace/debugging/_signal/tracing.py @@ -44,7 +44,7 @@ def enter(self) -> None: log.debug("Dynamic span entered with non-span probe: %s", self.probe) return - if not self._eval_condition(dict(self.args) if self.args else {}): + if not self._eval_condition(self.args): return self._span_cm = ddtrace.tracer.trace( @@ -78,7 +78,7 @@ def line(self): class SpanDecoration(LogSignal): """Decorate a span.""" - def _decorate_span(self, _locals: t.Dict[str, t.Any]) -> None: + def _decorate_span(self, scope: t.Mapping[str, t.Any]) -> None: probe = t.cast(SpanDecorationMixin, self.probe) if probe.target_span == SpanDecorationTargetSpan.ACTIVE: @@ -93,7 +93,7 @@ def _decorate_span(self, _locals: t.Dict[str, t.Any]) -> None: log.debug("Decorating span %r according to span decoration probe %r", span, probe) for d in probe.decorations: try: - if not (d.when is None or d.when(_locals)): + if not (d.when is None or d.when(scope)): continue except DDExpressionEvaluationError as e: self.errors.append( @@ -102,7 +102,7 @@ def _decorate_span(self, _locals: t.Dict[str, t.Any]) -> None: continue for tag in d.tags: try: - tag_value = tag.value.render(_locals, serialize) + tag_value = tag.value.render(scope, serialize) except DDExpressionEvaluationError as e: span.set_tag_str( "_dd.di.%s.evaluation_error" % tag.name, ", ".join([serialize(v) for v in e.args]) @@ -117,8 +117,8 @@ def enter(self) -> None: log.debug("Span decoration entered with non-span decoration probe: %s", self.probe) return - if probe.evaluate_at == ProbeEvaluateTimingForMethod.ENTER: - self._decorate_span(dict(self.args) if self.args else {}) + if probe.evaluate_at is ProbeEvaluateTimingForMethod.ENTER: + self._decorate_span(self.args) self.state = SignalState.DONE def exit(self, retval: t.Any, exc_info: ExcInfoType, duration: float) -> None: @@ -128,8 +128,8 @@ def exit(self, retval: t.Any, exc_info: ExcInfoType, duration: float) -> None: log.debug("Span decoration exited with non-span decoration probe: %s", self.probe) return - if probe.evaluate_at == ProbeEvaluateTimingForMethod.EXIT: - self._decorate_span(self._enrich_locals(retval, exc_info, duration)) + if probe.evaluate_at is ProbeEvaluateTimingForMethod.EXIT: + self._decorate_span(self.get_full_scope(retval, exc_info, duration)) self.state = SignalState.DONE def line(self): diff --git a/tests/debugging/signal/test_collector.py b/tests/debugging/signal/test_collector.py index dd39726f7dc..c7b9e752662 100644 --- a/tests/debugging/signal/test_collector.py +++ b/tests/debugging/signal/test_collector.py @@ -36,7 +36,6 @@ def foo(a=42): condition=DDExpression("a not null", lambda _: _["a"] is not None), ), frame=sys._getframe(), - args=[("a", 42)], thread=threading.current_thread(), ) snapshot.line() @@ -53,7 +52,6 @@ def bar(b=None): condition=DDExpression("b not null", lambda _: _["b"] is not None), ), frame=sys._getframe(), - args=[("b", None)], thread=threading.current_thread(), ) snapshot.line() diff --git a/tests/debugging/signal/test_model.py b/tests/debugging/signal/test_model.py index ad58773d0c9..d99500e093f 100644 --- a/tests/debugging/signal/test_model.py +++ b/tests/debugging/signal/test_model.py @@ -7,7 +7,7 @@ def test_enriched_args_locals_globals(): duration = 123456 - _locals = dict( + full_scope = dict( Snapshot( probe=create_log_function_probe( probe_id="test_duration_millis", @@ -18,19 +18,19 @@ def test_enriched_args_locals_globals(): ), frame=sys._getframe(), thread=current_thread(), - )._enrich_locals(None, (None, None, None), duration) + ).get_full_scope(None, (None, None, None), duration) ) # Check for globals - assert "__file__" in _locals + assert "__file__" in full_scope # Check for locals - assert "duration" in _locals + assert "duration" in full_scope def test_duration_millis(): duration = 123456 - _locals = Snapshot( + full_scope = Snapshot( probe=create_log_function_probe( probe_id="test_duration_millis", module="foo", @@ -40,6 +40,6 @@ def test_duration_millis(): ), frame=sys._getframe(), thread=current_thread(), - )._enrich_locals(None, (None, None, None), duration) + ).get_full_scope(None, (None, None, None), duration) - assert _locals["@duration"] == duration / 1e6 + assert full_scope["@duration"] == duration / 1e6 diff --git a/tests/debugging/test_encoding.py b/tests/debugging/test_encoding.py index 0c150da503b..489c996ceda 100644 --- a/tests/debugging/test_encoding.py +++ b/tests/debugging/test_encoding.py @@ -135,15 +135,23 @@ def c(): assert serialized["message"] == "'bad'" +def capture_context(*args, **kwargs): + return _capture_context(sys._getframe(1), *args, **kwargs) + + def test_capture_context_default_level(): - context = _capture_context([("self", tree)], [], [], (None, None, None), CaptureLimits(max_level=0)) - self = context["arguments"]["self"] + def _(self=tree): + return capture_context((None, None, None), limits=CaptureLimits(max_level=0)) + + self = _()["arguments"]["self"] assert self["fields"]["root"]["notCapturedReason"] == "depth" def test_capture_context_one_level(): - context = _capture_context([("self", tree)], [], [], (None, None, None), CaptureLimits(max_level=1)) - self = context["arguments"]["self"] + def _(self=tree): + return capture_context((None, None, None), limits=CaptureLimits(max_level=1)) + + self = _()["arguments"]["self"] assert self["fields"]["root"]["fields"]["left"] == {"notCapturedReason": "depth", "type": "Node"} @@ -152,13 +160,18 @@ def test_capture_context_one_level(): def test_capture_context_two_level(): - context = _capture_context([("self", tree)], [], [], (None, None, None), CaptureLimits(max_level=2)) - self = context["arguments"]["self"] + def _(self=tree): + return capture_context((None, None, None), limits=CaptureLimits(max_level=2)) + + self = _()["arguments"]["self"] assert self["fields"]["root"]["fields"]["left"]["fields"]["right"] == {"notCapturedReason": "depth", "type": "Node"} def test_capture_context_three_level(): - context = _capture_context([("self", tree)], [], [], (None, None, None), CaptureLimits(max_level=3)) + def _(self=tree): + return capture_context((None, None, None), limits=CaptureLimits(max_level=3)) + + context = _() self = context["arguments"]["self"] assert self["fields"]["root"]["fields"]["left"]["fields"]["right"]["fields"]["right"]["isNull"], context assert self["fields"]["root"]["fields"]["left"]["fields"]["right"]["fields"]["left"]["isNull"], context @@ -169,13 +182,15 @@ def test_capture_context_exc(): try: raise Exception("test", "me") except Exception: - context = _capture_context([], [], [], sys.exc_info()) + + def _(): + return capture_context(sys.exc_info()) + + context = _() + exc = context.pop("throwable") - assert context == { - "arguments": {}, - "locals": {}, - "staticFields": {}, - } + assert context["arguments"] == {} + assert context["locals"] == {"@exception": {"type": "Exception", "fields": {}}} assert exc["message"] == "'test', 'me'" assert exc["type"] == "Exception"