Skip to content

Commit

Permalink
make core context real context and add exception catching. fix test_a…
Browse files Browse the repository at this point in the history
…sm_request_context
  • Loading branch information
christophe-papazian committed Oct 7, 2024
1 parent 4203c4f commit 7d7f3c5
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 80 deletions.
3 changes: 3 additions & 0 deletions ddtrace/appsec/_asm_request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ddtrace.appsec._iast._utils import _is_iast_enabled
from ddtrace.appsec._utils import get_triggers
from ddtrace.internal import core
from ddtrace.internal._exceptions import BlockingException

# from ddtrace.internal._exceptions import BlockingException
from ddtrace.internal.constants import REQUEST_PATH_PARAMS
Expand Down Expand Up @@ -60,6 +61,8 @@ class ASM_Environment:

def __init__(self, span: Optional[Span] = None):
self.root = not in_context()
if self.root:
core.add_suppress_exception(BlockingException)
if span is None:
self.span: Span = core.get_item("call")
else:
Expand Down
40 changes: 20 additions & 20 deletions ddtrace/internal/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _on_jsonify_context_started_flask(ctx):
The names of these events follow the pattern ``context.[started|ended].<context_name>``.
"""

from contextlib import contextmanager
from contextlib import AbstractContextManager
import logging
import sys
from typing import TYPE_CHECKING # noqa:F401
Expand Down Expand Up @@ -163,19 +163,21 @@ def _deprecate_span_kwarg(span):
)


class ExecutionContext:
__slots__ = ["identifier", "_data", "_parents", "_span", "_token"]
class ExecutionContext(AbstractContextManager):
__slots__ = ["identifier", "_data", "_parents", "_span", "_token", "_suppress_exceptions"]

def __init__(self, identifier, parent=None, span=None, **kwargs):
_deprecate_span_kwarg(span)
self.identifier = identifier
self._data = {}
self._parents = []
self._span = span
self._suppress_exceptions = []
if parent is not None:
self.addParent(parent)
self._data.update(kwargs)

def __enter__(self):
if self._span is None and "_CURRENT_CONTEXT" in globals():
self._token = _CURRENT_CONTEXT.set(self)
dispatch("context.started.%s" % self.identifier, (self,))
Expand All @@ -192,8 +194,8 @@ def parents(self):
def parent(self):
return self._parents[0] if self._parents else None

def end(self):
dispatch_result = dispatch_with_results("context.ended.%s" % self.identifier, (self,))
def __exit__(self, exc_type, exc_value, traceback):
dispatch("context.ended.%s" % self.identifier, (self,))
if self._span is None:
try:
_CURRENT_CONTEXT.reset(self._token)
Expand All @@ -209,22 +211,18 @@ def end(self):
)
if id(self) in DEPRECATION_MEMO:
DEPRECATION_MEMO.remove(id(self))
return dispatch_result

return (
True
if exc_type is None
else any(issubclass(exc_type, exc_type_) for exc_type_ in self._suppress_exceptions)
)

def addParent(self, context):
if self.identifier == ROOT_CONTEXT_ID:
raise ValueError("Cannot add parent to root context")
self._parents.append(context)

@classmethod
@contextmanager
def context_with_data(cls, identifier, parent=None, span=None, **kwargs):
new_context = cls(identifier, parent=parent, span=span, **kwargs)
try:
yield new_context
finally:
new_context.end()

def get_item(current, data_key: str, default: Optional[Any] = None) -> Any:
# NB mimic the behavior of `ddtrace.internal._context` by doing lazy inheritance
while current is not None:
Expand Down Expand Up @@ -294,15 +292,18 @@ def _reset_context():


def context_with_data(identifier, parent=None, **kwargs):
return _CONTEXT_CLASS.context_with_data(identifier, parent=(parent or _CURRENT_CONTEXT.get()), **kwargs)
return _CONTEXT_CLASS(identifier, parent=(parent or _CURRENT_CONTEXT.get()), **kwargs)


def add_suppress_exception(exc_type: type) -> None:
_CURRENT_CONTEXT.get()._suppress_exceptions.append(exc_type)


def get_item(data_key: str, span: Optional["Span"] = None) -> Any:
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
return span._local_root._get_ctx_item(data_key)
else:
return _CURRENT_CONTEXT.get().get_item(data_key)
return _CURRENT_CONTEXT.get().get_item(data_key)


def get_local_item(data_key: str, span: Optional["Span"] = None) -> Any:
Expand All @@ -313,8 +314,7 @@ def get_items(data_keys: List[str], span: Optional["Span"] = None) -> List[Optio
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
return [span._local_root._get_ctx_item(key) for key in data_keys]
else:
return _CURRENT_CONTEXT.get().get_items(data_keys)
return _CURRENT_CONTEXT.get().get_items(data_keys)


def set_safe(data_key: str, data_value: Optional[Any]) -> None:
Expand Down
130 changes: 70 additions & 60 deletions tests/appsec/appsec/test_asm_request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,68 @@

from ddtrace.appsec import _asm_request_context
from ddtrace.internal._exceptions import BlockingException
from tests.utils import override_global_config
from tests.appsec.utils import asm_context


_TEST_IP = "1.2.3.4"
_TEST_HEADERS = {"foo": "bar"}

config_asm = {"_asm_enabled": True}


def test_context_set_and_reset():
with override_global_config({"_asm_enabled": True}):
with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS, True, lambda: True):
assert _asm_request_context.get_ip() == _TEST_IP
assert _asm_request_context.get_headers() == _TEST_HEADERS
assert _asm_request_context.get_headers_case_sensitive()
assert _asm_request_context.get_value("callbacks", "block") is not None
assert _asm_request_context.get_ip() is None
assert _asm_request_context.get_headers() == {}
assert _asm_request_context.get_value("callbacks", "block") is None
with asm_context(
ip_addr=_TEST_IP,
headers=_TEST_HEADERS,
headers_case_sensitive=True,
block_request_callable=(lambda: True),
config=config_asm,
):
# with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS, True, lambda: True):
assert _asm_request_context.get_ip() == _TEST_IP
assert _asm_request_context.get_headers() == _TEST_HEADERS
assert _asm_request_context.get_headers_case_sensitive()
assert _asm_request_context.get_value("callbacks", "block") is not None
assert _asm_request_context.get_ip() is None
assert _asm_request_context.get_headers() == {}
assert _asm_request_context.get_value("callbacks", "block") is None
assert not _asm_request_context.get_headers_case_sensitive()
with asm_context(
ip_addr=_TEST_IP,
headers=_TEST_HEADERS,
config=config_asm,
):
assert not _asm_request_context.get_headers_case_sensitive()
with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS):
assert not _asm_request_context.get_headers_case_sensitive()
assert not _asm_request_context.block_request()
assert not _asm_request_context.block_request()


def test_set_get_ip():
with override_global_config({"_asm_enabled": True}):
with _asm_request_context.asm_request_context_manager():
_asm_request_context.set_ip(_TEST_IP)
assert _asm_request_context.get_ip() == _TEST_IP
with asm_context(config=config_asm):
_asm_request_context.set_ip(_TEST_IP)
assert _asm_request_context.get_ip() == _TEST_IP


def test_set_get_headers():
with override_global_config({"_asm_enabled": True}):
with _asm_request_context.asm_request_context_manager():
_asm_request_context.set_headers(_TEST_HEADERS)
assert _asm_request_context.get_headers() == _TEST_HEADERS
with asm_context(config=config_asm):
_asm_request_context.set_headers(_TEST_HEADERS)
assert _asm_request_context.get_headers() == _TEST_HEADERS


def test_call_block_callable_none():
with override_global_config({"_asm_enabled": True}):
with _asm_request_context.asm_request_context_manager():
_asm_request_context.set_block_request_callable(None)
assert not _asm_request_context.block_request()
with asm_context(config=config_asm):
_asm_request_context.set_block_request_callable(None)
assert not _asm_request_context.block_request()
assert not _asm_request_context.block_request()


def test_call_block_callable_noargs():
def _callable():
return 42

with override_global_config({"_asm_enabled": True}):
with _asm_request_context.asm_request_context_manager():
_asm_request_context.set_block_request_callable(_callable)
assert _asm_request_context.get_value("callbacks", "block")() == 42
assert not _asm_request_context.get_value("callbacks", "block")
with asm_context(config=config_asm):
_asm_request_context.set_block_request_callable(_callable)
assert _asm_request_context.get_value("callbacks", "block")() == 42
assert not _asm_request_context.get_value("callbacks", "block")


def test_call_block_callable_curried():
Expand All @@ -65,31 +73,34 @@ class TestException(Exception):
def _callable():
raise TestException()

with override_global_config({"_asm_enabled": True}):
with _asm_request_context.asm_request_context_manager():
_asm_request_context.set_block_request_callable(_callable)
with pytest.raises(TestException):
assert _asm_request_context.block_request()
with asm_context(config=config_asm):
_asm_request_context.set_block_request_callable(_callable)
with pytest.raises(TestException):
assert _asm_request_context.block_request()


def test_set_get_headers_case_sensitive():
# default reset value should be False
assert not _asm_request_context.get_headers_case_sensitive()
with override_global_config({"_asm_enabled": True}):
with _asm_request_context.asm_request_context_manager():
_asm_request_context.set_headers_case_sensitive(True)
assert _asm_request_context.get_headers_case_sensitive()
_asm_request_context.set_headers_case_sensitive(False)
assert not _asm_request_context.get_headers_case_sensitive()
with asm_context(config=config_asm):
_asm_request_context.set_headers_case_sensitive(True)
assert _asm_request_context.get_headers_case_sensitive()
_asm_request_context.set_headers_case_sensitive(False)
assert not _asm_request_context.get_headers_case_sensitive()


def test_asm_request_context_manager():
with override_global_config({"_asm_enabled": True}):
with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS, True, lambda: 42):
assert _asm_request_context.get_ip() == _TEST_IP
assert _asm_request_context.get_headers() == _TEST_HEADERS
assert _asm_request_context.get_headers_case_sensitive()
assert _asm_request_context.get_value("callbacks", "block")() == 42
with asm_context(
ip_addr=_TEST_IP,
headers=_TEST_HEADERS,
headers_case_sensitive=True,
block_request_callable=(lambda: 42),
config=config_asm,
):
assert _asm_request_context.get_ip() == _TEST_IP
assert _asm_request_context.get_headers() == _TEST_HEADERS
assert _asm_request_context.get_headers_case_sensitive()
assert _asm_request_context.get_value("callbacks", "block")() == 42

assert _asm_request_context.get_ip() is None
assert _asm_request_context.get_headers() == {}
Expand All @@ -98,16 +109,15 @@ def test_asm_request_context_manager():


def test_blocking_exception_correctly_propagated():
with override_global_config({"_asm_enabled": True}):
with _asm_request_context.asm_request_context_manager():
witness = 0
with _asm_request_context.asm_request_context_manager():
witness = 1
raise BlockingException({}, "rule", "type", "value")
# should be skipped by exception
witness = 3
# should be also skipped by exception
witness = 4
# no more exception there
# ensure that the exception was raised and caught at the end of the last context manager
assert witness == 1
with asm_context(config=config_asm):
witness = 0
with asm_context(config=config_asm):
witness = 1
raise BlockingException({}, "rule", "type", "value")
# should be skipped by exception
witness = 3
# should be also skipped by exception
witness = 4
# no more exception there
# ensure that the exception was raised and caught at the end of the last context manager
assert witness == 1
4 changes: 4 additions & 0 deletions tests/appsec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def asm_context(
tracer=None,
span_name: str = "",
ip_addr: typing.Optional[str] = None,
headers_case_sensitive: bool = False,
headers: typing.Optional[typing.Dict[str, str]] = None,
block_request_callable: typing.Optional[typing.Callable[[], bool]] = None,
service: typing.Optional[str] = None,
config=None,
):
Expand All @@ -40,7 +42,9 @@ def asm_context(
with core.context_with_data(
"test.asm",
remote_addr=ip_addr,
headers_case_sensitive=headers_case_sensitive,
headers=headers,
block_request_callable=block_request_callable,
service=service,
), tracer.trace(span_name or "test", span_type=SpanTypes.WEB, service=service) as span:
yield span

0 comments on commit 7d7f3c5

Please sign in to comment.