diff --git a/ddtrace/appsec/_asm_request_context.py b/ddtrace/appsec/_asm_request_context.py index 26d1727a297..8d1bf16498e 100644 --- a/ddtrace/appsec/_asm_request_context.py +++ b/ddtrace/appsec/_asm_request_context.py @@ -22,6 +22,7 @@ from ddtrace.appsec._iast._utils import _is_iast_enabled from ddtrace.appsec._utils import get_triggers from ddtrace.internal import core +from ddtrace.internal._exceptions import BlockingException # from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.constants import REQUEST_PATH_PARAMS @@ -60,6 +61,8 @@ class ASM_Environment: def __init__(self, span: Optional[Span] = None): self.root = not in_context() + if self.root: + core.add_suppress_exception(BlockingException) if span is None: self.span: Span = core.get_item("call") else: diff --git a/ddtrace/internal/core/__init__.py b/ddtrace/internal/core/__init__.py index f31e7d0cb7d..a196d8b3689 100644 --- a/ddtrace/internal/core/__init__.py +++ b/ddtrace/internal/core/__init__.py @@ -101,7 +101,7 @@ def _on_jsonify_context_started_flask(ctx): The names of these events follow the pattern ``context.[started|ended].``. """ -from contextlib import contextmanager +from contextlib import AbstractContextManager import logging import sys from typing import TYPE_CHECKING # noqa:F401 @@ -163,8 +163,8 @@ def _deprecate_span_kwarg(span): ) -class ExecutionContext: - __slots__ = ["identifier", "_data", "_parents", "_span", "_token"] +class ExecutionContext(AbstractContextManager): + __slots__ = ["identifier", "_data", "_parents", "_span", "_token", "_suppress_exceptions"] def __init__(self, identifier, parent=None, span=None, **kwargs): _deprecate_span_kwarg(span) @@ -172,10 +172,12 @@ def __init__(self, identifier, parent=None, span=None, **kwargs): self._data = {} self._parents = [] self._span = span + self._suppress_exceptions = [] if parent is not None: self.addParent(parent) self._data.update(kwargs) + def __enter__(self): if self._span is None and "_CURRENT_CONTEXT" in globals(): self._token = _CURRENT_CONTEXT.set(self) dispatch("context.started.%s" % self.identifier, (self,)) @@ -192,8 +194,8 @@ def parents(self): def parent(self): return self._parents[0] if self._parents else None - def end(self): - dispatch_result = dispatch_with_results("context.ended.%s" % self.identifier, (self,)) + def __exit__(self, exc_type, exc_value, traceback): + dispatch("context.ended.%s" % self.identifier, (self,)) if self._span is None: try: _CURRENT_CONTEXT.reset(self._token) @@ -209,22 +211,18 @@ def end(self): ) if id(self) in DEPRECATION_MEMO: DEPRECATION_MEMO.remove(id(self)) - return dispatch_result + + return ( + True + if exc_type is None + else any(issubclass(exc_type, exc_type_) for exc_type_ in self._suppress_exceptions) + ) def addParent(self, context): if self.identifier == ROOT_CONTEXT_ID: raise ValueError("Cannot add parent to root context") self._parents.append(context) - @classmethod - @contextmanager - def context_with_data(cls, identifier, parent=None, span=None, **kwargs): - new_context = cls(identifier, parent=parent, span=span, **kwargs) - try: - yield new_context - finally: - new_context.end() - def get_item(current, data_key: str, default: Optional[Any] = None) -> Any: # NB mimic the behavior of `ddtrace.internal._context` by doing lazy inheritance while current is not None: @@ -294,15 +292,18 @@ def _reset_context(): def context_with_data(identifier, parent=None, **kwargs): - return _CONTEXT_CLASS.context_with_data(identifier, parent=(parent or _CURRENT_CONTEXT.get()), **kwargs) + return _CONTEXT_CLASS(identifier, parent=(parent or _CURRENT_CONTEXT.get()), **kwargs) + + +def add_suppress_exception(exc_type: type) -> None: + _CURRENT_CONTEXT.get()._suppress_exceptions.append(exc_type) def get_item(data_key: str, span: Optional["Span"] = None) -> Any: _deprecate_span_kwarg(span) if span is not None and span._local_root is not None: return span._local_root._get_ctx_item(data_key) - else: - return _CURRENT_CONTEXT.get().get_item(data_key) + return _CURRENT_CONTEXT.get().get_item(data_key) def get_local_item(data_key: str, span: Optional["Span"] = None) -> Any: @@ -313,8 +314,7 @@ def get_items(data_keys: List[str], span: Optional["Span"] = None) -> List[Optio _deprecate_span_kwarg(span) if span is not None and span._local_root is not None: return [span._local_root._get_ctx_item(key) for key in data_keys] - else: - return _CURRENT_CONTEXT.get().get_items(data_keys) + return _CURRENT_CONTEXT.get().get_items(data_keys) def set_safe(data_key: str, data_value: Optional[Any]) -> None: diff --git a/tests/appsec/appsec/test_asm_request_context.py b/tests/appsec/appsec/test_asm_request_context.py index 487401f00ed..b5facd7f70d 100644 --- a/tests/appsec/appsec/test_asm_request_context.py +++ b/tests/appsec/appsec/test_asm_request_context.py @@ -2,60 +2,68 @@ from ddtrace.appsec import _asm_request_context from ddtrace.internal._exceptions import BlockingException -from tests.utils import override_global_config +from tests.appsec.utils import asm_context _TEST_IP = "1.2.3.4" _TEST_HEADERS = {"foo": "bar"} +config_asm = {"_asm_enabled": True} + def test_context_set_and_reset(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS, True, lambda: True): - assert _asm_request_context.get_ip() == _TEST_IP - assert _asm_request_context.get_headers() == _TEST_HEADERS - assert _asm_request_context.get_headers_case_sensitive() - assert _asm_request_context.get_value("callbacks", "block") is not None - assert _asm_request_context.get_ip() is None - assert _asm_request_context.get_headers() == {} - assert _asm_request_context.get_value("callbacks", "block") is None + with asm_context( + ip_addr=_TEST_IP, + headers=_TEST_HEADERS, + headers_case_sensitive=True, + block_request_callable=(lambda: True), + config=config_asm, + ): + # with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS, True, lambda: True): + assert _asm_request_context.get_ip() == _TEST_IP + assert _asm_request_context.get_headers() == _TEST_HEADERS + assert _asm_request_context.get_headers_case_sensitive() + assert _asm_request_context.get_value("callbacks", "block") is not None + assert _asm_request_context.get_ip() is None + assert _asm_request_context.get_headers() == {} + assert _asm_request_context.get_value("callbacks", "block") is None + assert not _asm_request_context.get_headers_case_sensitive() + with asm_context( + ip_addr=_TEST_IP, + headers=_TEST_HEADERS, + config=config_asm, + ): assert not _asm_request_context.get_headers_case_sensitive() - with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS): - assert not _asm_request_context.get_headers_case_sensitive() - assert not _asm_request_context.block_request() + assert not _asm_request_context.block_request() def test_set_get_ip(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_ip(_TEST_IP) - assert _asm_request_context.get_ip() == _TEST_IP + with asm_context(config=config_asm): + _asm_request_context.set_ip(_TEST_IP) + assert _asm_request_context.get_ip() == _TEST_IP def test_set_get_headers(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_headers(_TEST_HEADERS) - assert _asm_request_context.get_headers() == _TEST_HEADERS + with asm_context(config=config_asm): + _asm_request_context.set_headers(_TEST_HEADERS) + assert _asm_request_context.get_headers() == _TEST_HEADERS def test_call_block_callable_none(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_block_request_callable(None) - assert not _asm_request_context.block_request() + with asm_context(config=config_asm): + _asm_request_context.set_block_request_callable(None) assert not _asm_request_context.block_request() + assert not _asm_request_context.block_request() def test_call_block_callable_noargs(): def _callable(): return 42 - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_block_request_callable(_callable) - assert _asm_request_context.get_value("callbacks", "block")() == 42 - assert not _asm_request_context.get_value("callbacks", "block") + with asm_context(config=config_asm): + _asm_request_context.set_block_request_callable(_callable) + assert _asm_request_context.get_value("callbacks", "block")() == 42 + assert not _asm_request_context.get_value("callbacks", "block") def test_call_block_callable_curried(): @@ -65,31 +73,34 @@ class TestException(Exception): def _callable(): raise TestException() - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_block_request_callable(_callable) - with pytest.raises(TestException): - assert _asm_request_context.block_request() + with asm_context(config=config_asm): + _asm_request_context.set_block_request_callable(_callable) + with pytest.raises(TestException): + assert _asm_request_context.block_request() def test_set_get_headers_case_sensitive(): # default reset value should be False assert not _asm_request_context.get_headers_case_sensitive() - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_headers_case_sensitive(True) - assert _asm_request_context.get_headers_case_sensitive() - _asm_request_context.set_headers_case_sensitive(False) - assert not _asm_request_context.get_headers_case_sensitive() + with asm_context(config=config_asm): + _asm_request_context.set_headers_case_sensitive(True) + assert _asm_request_context.get_headers_case_sensitive() + _asm_request_context.set_headers_case_sensitive(False) + assert not _asm_request_context.get_headers_case_sensitive() def test_asm_request_context_manager(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS, True, lambda: 42): - assert _asm_request_context.get_ip() == _TEST_IP - assert _asm_request_context.get_headers() == _TEST_HEADERS - assert _asm_request_context.get_headers_case_sensitive() - assert _asm_request_context.get_value("callbacks", "block")() == 42 + with asm_context( + ip_addr=_TEST_IP, + headers=_TEST_HEADERS, + headers_case_sensitive=True, + block_request_callable=(lambda: 42), + config=config_asm, + ): + assert _asm_request_context.get_ip() == _TEST_IP + assert _asm_request_context.get_headers() == _TEST_HEADERS + assert _asm_request_context.get_headers_case_sensitive() + assert _asm_request_context.get_value("callbacks", "block")() == 42 assert _asm_request_context.get_ip() is None assert _asm_request_context.get_headers() == {} @@ -98,16 +109,15 @@ def test_asm_request_context_manager(): def test_blocking_exception_correctly_propagated(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - witness = 0 - with _asm_request_context.asm_request_context_manager(): - witness = 1 - raise BlockingException({}, "rule", "type", "value") - # should be skipped by exception - witness = 3 - # should be also skipped by exception - witness = 4 - # no more exception there - # ensure that the exception was raised and caught at the end of the last context manager - assert witness == 1 + with asm_context(config=config_asm): + witness = 0 + with asm_context(config=config_asm): + witness = 1 + raise BlockingException({}, "rule", "type", "value") + # should be skipped by exception + witness = 3 + # should be also skipped by exception + witness = 4 + # no more exception there + # ensure that the exception was raised and caught at the end of the last context manager + assert witness == 1 diff --git a/tests/appsec/utils.py b/tests/appsec/utils.py index aee58cc3479..b4c010d1123 100644 --- a/tests/appsec/utils.py +++ b/tests/appsec/utils.py @@ -25,7 +25,9 @@ def asm_context( tracer=None, span_name: str = "", ip_addr: typing.Optional[str] = None, + headers_case_sensitive: bool = False, headers: typing.Optional[typing.Dict[str, str]] = None, + block_request_callable: typing.Optional[typing.Callable[[], bool]] = None, service: typing.Optional[str] = None, config=None, ): @@ -40,7 +42,9 @@ def asm_context( with core.context_with_data( "test.asm", remote_addr=ip_addr, + headers_case_sensitive=headers_case_sensitive, headers=headers, + block_request_callable=block_request_callable, service=service, ), tracer.trace(span_name or "test", span_type=SpanTypes.WEB, service=service) as span: yield span