Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pymongo): use bytecode wrapping to trace pymongo clients #10516

Merged
merged 23 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions ddtrace/contrib/internal/pymongo/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# stdlib
import contextlib
import functools
import json
from typing import Iterable

Expand Down Expand Up @@ -45,6 +46,31 @@
_DEFAULT_SERVICE = schematize_service_name("pymongo")


def trace_mongo_client_init(func, args, kwargs):
# Call MongoClient.__init__
mabdinur marked this conversation as resolved.
Show resolved Hide resolved
func(*args, **kwargs)
client = get_argument_value(args, kwargs, 0, "self")
# The MongoClient attempts to trace all of the network
# calls in the trace library. This is good because it measures the
# actual network time. It's bad because it uses a private API which
# could change. We'll see how this goes.
if not isinstance(client._topology, TracedTopology):
client._topology = TracedTopology(client._topology)

def __setddpin__(client, pin):
pin.onto(client._topology)

def __getddpin__(client):
return ddtrace.Pin.get_from(client._topology)

client.__setddpin__ = functools.partial(__setddpin__, client)
client.__getddpin__ = functools.partial(__getddpin__, client)

# Default Pin
ddtrace.Pin(service=_DEFAULT_SERVICE).onto(client)


# TODO: Remove TracedMongoClient when ddtrace.contrib.pymongo.client is removed from the public API.
mabdinur marked this conversation as resolved.
Show resolved Hide resolved
mabdinur marked this conversation as resolved.
Show resolved Hide resolved
class TracedMongoClient(ObjectProxy):
def __init__(self, client=None, *args, **kwargs):
# To support the former trace_mongo_client interface, we have to keep this old interface
Expand Down Expand Up @@ -88,6 +114,9 @@ def __getddpin__(self):

@contextlib.contextmanager
def wrapped_validate_session(wrapped, instance, args, kwargs):
# The function is exposed in the public API, but it is not used in the codebase.
# TODO: Remove this function when ddtrace.contrib.pymongo.client is removed.
mabdinur marked this conversation as resolved.
Show resolved Hide resolved
mabdinur marked this conversation as resolved.
Show resolved Hide resolved
mabdinur marked this conversation as resolved.
Show resolved Hide resolved
mabdinur marked this conversation as resolved.
Show resolved Hide resolved

# We do this to handle a validation `A is B` in pymongo that
# relies on IDs being equal. Since we are proxying objects, we need
# to ensure we're compare proxy with proxy or wrapped with wrapped
Expand Down
50 changes: 23 additions & 27 deletions ddtrace/contrib/internal/pymongo/patch.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import contextlib

import pymongo
from wrapt import wrap_function_wrapper as _w

from ddtrace import Pin
from ddtrace import config
from ddtrace.constants import SPAN_KIND
from ddtrace.constants import SPAN_MEASURED_KEY
from ddtrace.contrib import trace_utils
from ddtrace.contrib.trace_utils import unwrap as _u
from ddtrace.ext import SpanKind
from ddtrace.ext import SpanTypes
from ddtrace.ext import db
from ddtrace.ext import mongo
from ddtrace.internal.constants import COMPONENT
from ddtrace.internal.utils import get_argument_value
from ddtrace.internal.wrapping import unwrap as _u
from ddtrace.internal.wrapping import wrap as _w

from .client import TracedMongoClient
# keep TracedMongoClient import to maintain bakcwards compatibility
mabdinur marked this conversation as resolved.
Show resolved Hide resolved
from .client import TracedMongoClient # noqa: F401
from .client import set_address_tags
from .client import wrapped_validate_session
from .client import trace_mongo_client_init


config._add(
Expand All @@ -31,53 +33,47 @@ def get_version():
return getattr(pymongo, "__version__", "")


# Original Client class
_MongoClient = pymongo.MongoClient

_VERSION = pymongo.version_tuple
_CHECKOUT_FN_NAME = "get_socket" if _VERSION < (4, 5) else "checkout"
_VERIFY_VERSION_CLASS = pymongo.pool.SocketInfo if _VERSION < (4, 5) else pymongo.pool.Connection


def patch():
if getattr(pymongo, "_datadog_patch", False):
return
patch_pymongo_module()
# We should progressively get rid of TracedMongoClient. We now try to
# wrap methods individually. cf #1501
pymongo.MongoClient = TracedMongoClient
_w(pymongo.MongoClient.__init__, trace_mongo_client_init)
pymongo._datadog_patch = True


def unpatch():
if not getattr(pymongo, "_datadog_patch", False):
return
unpatch_pymongo_module()
pymongo.MongoClient = _MongoClient
_u(pymongo.MongoClient, pymongo.MongoClient.__init__)
pymongo._datadog_patch = False


def patch_pymongo_module():
if getattr(pymongo, "_datadog_patch", False):
return
pymongo._datadog_patch = True
Pin().onto(pymongo.server.Server)

# Whenever a pymongo command is invoked, the lib either:
# - Creates a new socket & performs a TCP handshake
# - Grabs a socket already initialized before
_w("pymongo.server", "Server.%s" % _CHECKOUT_FN_NAME, traced_get_socket)
_w("pymongo.pool", f"{_VERIFY_VERSION_CLASS.__name__}.validate_session", wrapped_validate_session)
checkout_fn = getattr(pymongo.server.Server, _CHECKOUT_FN_NAME)
_w(checkout_fn, traced_get_socket)


def unpatch_pymongo_module():
if not getattr(pymongo, "_datadog_patch", False):
return
pymongo._datadog_patch = False

_u(pymongo.server.Server, _CHECKOUT_FN_NAME)
_u(_VERIFY_VERSION_CLASS, "validate_session")
checkout_fn = getattr(pymongo.server.Server, _CHECKOUT_FN_NAME)
_u(checkout_fn, traced_get_socket)


@contextlib.contextmanager
def traced_get_socket(wrapped, instance, args, kwargs):
pin = Pin._find(wrapped, instance)
def traced_get_socket(func, args, kwargs):
instance = get_argument_value(args, kwargs, 0, "self")
pin = Pin._find(func, instance)
if not pin or not pin.enabled():
with wrapped(*args, **kwargs) as sock_info:
with func(*args, **kwargs) as sock_info:
yield sock_info
return

Expand All @@ -92,7 +88,7 @@ def traced_get_socket(wrapped, instance, args, kwargs):
# set span.kind tag equal to type of operation being performed
span.set_tag_str(SPAN_KIND, SpanKind.CLIENT)

with wrapped(*args, **kwargs) as sock_info:
with func(*args, **kwargs) as sock_info:
set_address_tags(span, sock_info.address)
span.set_tag(SPAN_MEASURED_KEY)
yield sock_info
2 changes: 1 addition & 1 deletion tests/contrib/pymongo/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def setUp(self):
# and choose from to perform classic operations. For the sake of our tests,
# let's limit this number to 1
self.client = pymongo.MongoClient(port=MONGO_CONFIG["port"], maxPoolSize=1)
# Override TracedMongoClient's pin's tracer with our dummy tracer
# Override MongoClient's pin's tracer with our dummy tracer
Pin.override(self.client, tracer=self.tracer, service="testdb")

def tearDown(self):
Expand Down
Loading