Skip to content

Commit

Permalink
Merge branch 'main' into brettlangdon/simplify.flask_simple
Browse files Browse the repository at this point in the history
  • Loading branch information
brettlangdon authored May 3, 2024
2 parents bae7000 + d10e081 commit faea270
Show file tree
Hide file tree
Showing 45 changed files with 1,052 additions and 187 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ ddtrace/appsec/ @DataDog/asm-python
ddtrace/settings/asm.py @DataDog/asm-python
ddtrace/contrib/subprocess/ @DataDog/asm-python
ddtrace/contrib/flask_login/ @DataDog/asm-python
ddtrace/internal/_exceptions.py @DataDog/asm-python
tests/appsec/ @DataDog/asm-python
tests/contrib/dbapi/test_dbapi_appsec.py @DataDog/asm-python
tests/contrib/subprocess @DataDog/asm-python
Expand Down
45 changes: 39 additions & 6 deletions ddtrace/_trace/trace_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from typing import List
from typing import Optional

from ddtrace import config
from ddtrace._trace.span import Span
from ddtrace._trace.utils import extract_DD_context_from_messages
from ddtrace._trace.utils import set_botocore_patched_api_call_span_tags as set_patched_api_call_span_tags
from ddtrace._trace.utils import set_botocore_response_metadata_tags
from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY
from ddtrace.constants import SPAN_KIND
from ddtrace.constants import SPAN_MEASURED_KEY
Expand Down Expand Up @@ -107,6 +108,9 @@ def _start_span(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -
trace_utils.activate_distributed_headers(
tracer, int_config=distributed_headers_config, request_headers=ctx["distributed_headers"]
)
distributed_context = ctx.get_item("distributed_context", traverse=True)
if distributed_context and not call_trace:
span_kwargs["child_of"] = distributed_context
span_kwargs.update(kwargs)
span = (tracer.trace if call_trace else tracer.start_span)(ctx["span_name"], **span_kwargs)
for tk, tv in ctx.get_item("tags", dict()).items():
Expand Down Expand Up @@ -569,20 +573,20 @@ def _on_botocore_patched_api_call_started(ctx):
span.start_ns = start_ns


def _on_botocore_patched_api_call_exception(ctx, response, exception_type, set_response_metadata_tags):
def _on_botocore_patched_api_call_exception(ctx, response, exception_type, is_error_code_fn):
span = ctx.get_item(ctx.get_item("call_key"))
# `ClientError.response` contains the result, so we can still grab response metadata
set_response_metadata_tags(span, response)
set_botocore_response_metadata_tags(span, response, is_error_code_fn=is_error_code_fn)

# If we have a status code, and the status code is not an error,
# then ignore the exception being raised
status_code = span.get_tag(http.STATUS_CODE)
if status_code and not config.botocore.operations[span.resource].is_error_code(int(status_code)):
if status_code and not is_error_code_fn(int(status_code)):
span._ignore_exception(exception_type)


def _on_botocore_patched_api_call_success(ctx, response, set_response_metadata_tags):
set_response_metadata_tags(ctx.get_item(ctx.get_item("call_key")), response)
def _on_botocore_patched_api_call_success(ctx, response):
set_botocore_response_metadata_tags(ctx.get_item(ctx.get_item("call_key")), response)


def _on_botocore_trace_context_injection_prepared(
Expand Down Expand Up @@ -682,6 +686,31 @@ def _on_botocore_bedrock_process_response(
span.finish()


def _on_botocore_sqs_recvmessage_post(
ctx: core.ExecutionContext, _, result: Dict, propagate: bool, message_parser: Callable
) -> None:
if result is not None and "Messages" in result and len(result["Messages"]) >= 1:
ctx.set_item("message_received", True)
if propagate:
ctx.set_safe("distributed_context", extract_DD_context_from_messages(result["Messages"], message_parser))


def _on_botocore_kinesis_getrecords_post(
ctx: core.ExecutionContext,
_,
__,
___,
____,
result,
propagate: bool,
message_parser: Callable,
):
if result is not None and "Records" in result and len(result["Records"]) >= 1:
ctx.set_item("message_received", True)
if propagate:
ctx.set_item("distributed_context", extract_DD_context_from_messages(result["Records"], message_parser))


def _on_redis_async_command_post(span, rowcount):
if rowcount is not None:
span.set_metric(db.ROWCOUNT, rowcount)
Expand Down Expand Up @@ -727,10 +756,14 @@ def listen():
core.on("botocore.patched_stepfunctions_api_call.started", _on_botocore_patched_api_call_started)
core.on("botocore.patched_stepfunctions_api_call.exception", _on_botocore_patched_api_call_exception)
core.on("botocore.stepfunctions.update_messages", _on_botocore_update_messages)
core.on("botocore.eventbridge.update_messages", _on_botocore_update_messages)
core.on("botocore.client_context.update_messages", _on_botocore_update_messages)
core.on("botocore.patched_bedrock_api_call.started", _on_botocore_patched_bedrock_api_call_started)
core.on("botocore.patched_bedrock_api_call.exception", _on_botocore_patched_bedrock_api_call_exception)
core.on("botocore.patched_bedrock_api_call.success", _on_botocore_patched_bedrock_api_call_success)
core.on("botocore.bedrock.process_response", _on_botocore_bedrock_process_response)
core.on("botocore.sqs.ReceiveMessage.post", _on_botocore_sqs_recvmessage_post)
core.on("botocore.kinesis.GetRecords.post", _on_botocore_kinesis_getrecords_post)
core.on("redis.async_command.post", _on_redis_async_command_post)

for context_name in (
Expand Down
41 changes: 41 additions & 0 deletions ddtrace/_trace/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional

from ddtrace import Span
from ddtrace import config
from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY
from ddtrace.constants import SPAN_KIND
from ddtrace.constants import SPAN_MEASURED_KEY
from ddtrace.ext import SpanKind
from ddtrace.ext import aws
from ddtrace.ext import http
from ddtrace.internal.constants import COMPONENT
from ddtrace.internal.utils.formats import deep_getattr
from ddtrace.propagation.http import HTTPPropagator


def set_botocore_patched_api_call_span_tags(span: Span, instance, args, params, endpoint_name, operation):
Expand Down Expand Up @@ -39,3 +46,37 @@ def set_botocore_patched_api_call_span_tags(span: Span, instance, args, params,

# set analytics sample rate
span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.botocore.get_analytics_sample_rate())


def set_botocore_response_metadata_tags(
span: Span, result: Dict[str, Any], is_error_code_fn: Optional[Callable] = None
) -> None:
if not result or not result.get("ResponseMetadata"):
return
response_meta = result["ResponseMetadata"]

if "HTTPStatusCode" in response_meta:
status_code = response_meta["HTTPStatusCode"]
span.set_tag(http.STATUS_CODE, status_code)

# Mark this span as an error if requested
if is_error_code_fn is not None and is_error_code_fn(int(status_code)):
span.error = 1

if "RetryAttempts" in response_meta:
span.set_tag("retry_attempts", response_meta["RetryAttempts"])

if "RequestId" in response_meta:
span.set_tag_str("aws.requestid", response_meta["RequestId"])


def extract_DD_context_from_messages(messages, extract_from_message: Callable):
ctx = None
if len(messages) >= 1:
message = messages[0]
context_json = extract_from_message(message)
if context_json is not None:
child_of = HTTPPropagator.extract(context_json)
if child_of.trace_id is not None:
ctx = child_of
return ctx
8 changes: 8 additions & 0 deletions ddtrace/appsec/_asm_request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ddtrace.appsec._iast._utils import _is_iast_enabled
from ddtrace.appsec._utils import get_triggers
from ddtrace.internal import core
from ddtrace.internal._exceptions import BlockingException
from ddtrace.internal.constants import REQUEST_PATH_PARAMS
from ddtrace.internal.logger import get_logger
from ddtrace.settings.asm import config as asm_config
Expand Down Expand Up @@ -140,6 +141,7 @@ def __init__(self):
env = ASM_Environment(True)

self._id = _DataHandler.main_id
self._root = not in_context()
self.active = True
self.execution_context = core.ExecutionContext(__name__, **{"asm_env": env})

Expand Down Expand Up @@ -393,6 +395,12 @@ def asm_request_context_manager(
if resources is not None:
try:
yield resources
except BlockingException as e:
# ensure that the BlockingRequest that is never raised outside a context
# is also never propagated outside the context
core.set_item(WAF_CONTEXT_NAMES.BLOCKED, e.args[0])
if not resources._root:
raise
finally:
_end_context(resources)
else:
Expand Down
16 changes: 13 additions & 3 deletions ddtrace/appsec/_common_module_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from typing import Callable
from typing import Dict

from ddtrace.appsec._constants import WAF_CONTEXT_NAMES
from ddtrace.internal import core
from ddtrace.internal._exceptions import BlockingException
from ddtrace.internal.logger import get_logger
from ddtrace.settings.asm import config as asm_config
from ddtrace.vendor.wrapt import FunctionWrapper
Expand Down Expand Up @@ -49,6 +51,7 @@ def wrapped_open_CFDDB7ABBA9081B6(original_open_callable, instance, args, kwargs
try:
from ddtrace.appsec._asm_request_context import call_waf_callback
from ddtrace.appsec._asm_request_context import in_context
from ddtrace.appsec._asm_request_context import is_blocked
from ddtrace.appsec._constants import EXPLOIT_PREVENTION
except ImportError:
# open is used during module initialization
Expand All @@ -66,7 +69,9 @@ def wrapped_open_CFDDB7ABBA9081B6(original_open_callable, instance, args, kwargs
crop_trace="wrapped_open_CFDDB7ABBA9081B6",
rule_type=EXPLOIT_PREVENTION.TYPE.LFI,
)
# DEV: Next part of the exploit prevention feature: add block here
if is_blocked():
raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "lfi", filename)

return original_open_callable(*args, **kwargs)


Expand All @@ -82,6 +87,7 @@ def wrapped_open_ED4CF71136E15EBF(original_open_callable, instance, args, kwargs
try:
from ddtrace.appsec._asm_request_context import call_waf_callback
from ddtrace.appsec._asm_request_context import in_context
from ddtrace.appsec._asm_request_context import is_blocked
from ddtrace.appsec._constants import EXPLOIT_PREVENTION
except ImportError:
# open is used during module initialization
Expand All @@ -98,7 +104,8 @@ def wrapped_open_ED4CF71136E15EBF(original_open_callable, instance, args, kwargs
crop_trace="wrapped_open_ED4CF71136E15EBF",
rule_type=EXPLOIT_PREVENTION.TYPE.SSRF,
)
# DEV: Next part of the exploit prevention feature: add block here
if is_blocked():
raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "ssrf", url)
return original_open_callable(*args, **kwargs)


Expand All @@ -115,6 +122,7 @@ def wrapped_request_D8CB81E472AF98A2(original_request_callable, instance, args,
try:
from ddtrace.appsec._asm_request_context import call_waf_callback
from ddtrace.appsec._asm_request_context import in_context
from ddtrace.appsec._asm_request_context import is_blocked
from ddtrace.appsec._constants import EXPLOIT_PREVENTION
except ImportError:
# open is used during module initialization
Expand All @@ -129,7 +137,9 @@ def wrapped_request_D8CB81E472AF98A2(original_request_callable, instance, args,
crop_trace="wrapped_request_D8CB81E472AF98A2",
rule_type=EXPLOIT_PREVENTION.TYPE.SSRF,
)
# DEV: Next part of the exploit prevention feature: add block here
if is_blocked():
raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "ssrf", url)

return original_request_callable(*args, **kwargs)


Expand Down
4 changes: 4 additions & 0 deletions ddtrace/contrib/asgi/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ddtrace.ext import SpanKind
from ddtrace.ext import SpanTypes
from ddtrace.ext import http
from ddtrace.internal._exceptions import BlockingException
from ddtrace.internal.compat import is_valid_ip
from ddtrace.internal.constants import COMPONENT
from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED
Expand Down Expand Up @@ -288,6 +289,9 @@ async def wrapped_blocked_send(message):
try:
core.dispatch("asgi.start_request", ("asgi",))
return await self.app(scope, receive, wrapped_send)
except BlockingException as e:
core.set_item(HTTP_REQUEST_BLOCKED, e.args[0])
return await _blocked_asgi_app(scope, receive, wrapped_blocked_send)
except trace_utils.InterruptException:
return await _blocked_asgi_app(scope, receive, wrapped_blocked_send)
except Exception as exc:
Expand Down
18 changes: 11 additions & 7 deletions ddtrace/contrib/botocore/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@
from .services.sqs import update_messages as inject_trace_to_sqs_or_sns_message
from .services.stepfunctions import patched_stepfunction_api_call
from .services.stepfunctions import update_stepfunction_input
from .utils import inject_trace_to_client_context
from .utils import inject_trace_to_eventbridge_detail
from .utils import set_response_metadata_tags
from .utils import update_client_context
from .utils import update_eventbridge_detail


_PATCHED_SUBMODULES = set() # type: Set[str]
Expand Down Expand Up @@ -175,11 +174,11 @@ def prep_context_injection(ctx, endpoint_name, operation, trace_operation, param
schematization_function = schematize_cloud_messaging_operation

if endpoint_name == "lambda" and operation == "Invoke":
injection_function = inject_trace_to_client_context
injection_function = update_client_context
schematization_function = schematize_cloud_faas_operation
cloud_service = "lambda"
if endpoint_name == "events" and operation == "PutEvents":
injection_function = inject_trace_to_eventbridge_detail
injection_function = update_eventbridge_detail
cloud_service = "events"
if endpoint_name == "sns" and "Publish" in operation:
injection_function = inject_trace_to_sqs_or_sns_message
Expand Down Expand Up @@ -224,9 +223,14 @@ def patched_api_call_fallback(original_func, instance, args, kwargs, function_va
except botocore.exceptions.ClientError as e:
core.dispatch(
"botocore.patched_api_call.exception",
[ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags],
[
ctx,
e.response,
botocore.exceptions.ClientError,
config.botocore.operations[ctx["instrumented_api_call"].resource].is_error_code,
],
)
raise
else:
core.dispatch("botocore.patched_api_call.success", [ctx, result, set_response_metadata_tags])
core.dispatch("botocore.patched_api_call.success", [ctx, result])
return result
Loading

0 comments on commit faea270

Please sign in to comment.