Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(llmobs): recreate writer on fork #10249

Merged
merged 4 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ddtrace/llmobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ddtrace.llmobs import LLMObs
LLMObs.enable()
"""

from ._llmobs import LLMObs


Expand Down
15 changes: 14 additions & 1 deletion ddtrace/llmobs/_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ddtrace import patch
from ddtrace.ext import SpanTypes
from ddtrace.internal import atexit
from ddtrace.internal import forksafe
from ddtrace.internal import telemetry
from ddtrace.internal.compat import ensure_text
from ddtrace.internal.logger import get_logger
Expand Down Expand Up @@ -82,11 +83,22 @@ def __init__(self, tracer=None):
interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)),
timeout=float(os.getenv("_DD_LLMOBS_WRITER_TIMEOUT", 5.0)),
)
self._trace_processor = LLMObsTraceProcessor(self._llmobs_span_writer)
forksafe.register(self._child_after_fork)

def _child_after_fork(self):
self._llmobs_span_writer = self._llmobs_span_writer.recreate()
self._trace_processor._span_writer = self._llmobs_span_writer
self.tracer.configure(settings={"FILTERS": [self._trace_processor]})
try:
self._llmobs_span_writer.start()
except ServiceStatusError:
log.debug("Error starting LLMObs span writer after fork")

def _start_service(self) -> None:
tracer_filters = self.tracer._filters
if not any(isinstance(tracer_filter, LLMObsTraceProcessor) for tracer_filter in tracer_filters):
tracer_filters += [LLMObsTraceProcessor(self._llmobs_span_writer)]
tracer_filters += [self._trace_processor]
self.tracer.configure(settings={"FILTERS": tracer_filters})
try:
self._llmobs_span_writer.start()
Expand All @@ -102,6 +114,7 @@ def _stop_service(self) -> None:
log.debug("Error stopping LLMObs writers")

try:
forksafe.unregister(self._child_after_fork)
self.tracer.shutdown()
except Exception:
log.warning("Failed to shutdown tracer", exc_info=True)
Expand Down
1 change: 1 addition & 0 deletions ddtrace/llmobs/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def recreate(self):
return self.__class__(
interval=self._interval,
timeout=self._timeout,
is_agentless=config._llmobs_agentless_enabled,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
LLM Observability: This fix resolves an issue where LLM Observability spans were not being submitted in forked processes,
such as when using ``celery`` or ``gunicorn`` workers. The LLM Observability writer thread now automatically restarts
when a forked process is detected.
69 changes: 69 additions & 0 deletions tests/llmobs/test_llmobs_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os

import mock
import pytest
Expand All @@ -7,6 +8,7 @@
from ddtrace._trace.context import Context
from ddtrace._trace.span import Span
from ddtrace.ext import SpanTypes
from ddtrace.internal.service import ServiceStatus
from ddtrace.llmobs import LLMObs as llmobs_service
from ddtrace.llmobs._constants import INPUT_DOCUMENTS
from ddtrace.llmobs._constants import INPUT_MESSAGES
Expand Down Expand Up @@ -1342,3 +1344,70 @@ def test_activate_distributed_headers_activates_context(LLMObs, mock_logs):
LLMObs.activate_distributed_headers({})
assert mock_extract.call_count == 1
mock_activate.assert_called_once_with(dummy_context)


def _task(llmobs_service, errors, original_pid, original_span_writer_id):
"""Task in test_llmobs_fork which asserts that LLMObs in a forked process correctly recreates the writer."""
try:
with llmobs_service.workflow():
with llmobs_service.task():
assert llmobs_service._instance.tracer._pid != original_pid
assert id(llmobs_service._instance._llmobs_span_writer) != original_span_writer_id
assert llmobs_service._instance._llmobs_span_writer.enqueue.call_count == 2
assert llmobs_service._instance._llmobs_span_writer._encoder.encode.call_count == 2
except AssertionError as e:
errors.put(e)


def test_llmobs_fork_recreates_and_restarts_writer():
"""Test that forking a process correctly recreates and restarts the LLMObsSpanWriter."""
with mock.patch("ddtrace.internal.writer.HTTPWriter._send_payload"):
llmobs_service.enable(_tracer=DummyTracer(), ml_app="test_app")
original_pid = llmobs_service._instance.tracer._pid
original_span_writer = llmobs_service._instance._llmobs_span_writer
pid = os.fork()
if pid: # parent
assert llmobs_service._instance.tracer._pid == original_pid
assert llmobs_service._instance._llmobs_span_writer == original_span_writer
assert (
llmobs_service._instance._trace_processor._span_writer == llmobs_service._instance._llmobs_span_writer
)
assert llmobs_service._instance._llmobs_span_writer.status == ServiceStatus.RUNNING
else: # child
assert llmobs_service._instance.tracer._pid != original_pid
assert llmobs_service._instance._llmobs_span_writer != original_span_writer
assert (
llmobs_service._instance._trace_processor._span_writer == llmobs_service._instance._llmobs_span_writer
)
assert llmobs_service._instance._llmobs_span_writer.status == ServiceStatus.RUNNING
llmobs_service.disable()
os._exit(12)

_, status = os.waitpid(pid, 0)
exit_code = os.WEXITSTATUS(status)
assert exit_code == 12
llmobs_service.disable()


def test_llmobs_fork_create_span(monkeypatch):
"""Test that forking a process correctly encodes new spans created in each process."""
monkeypatch.setenv("_DD_LLMOBS_WRITER_INTERVAL", 5.0)
with mock.patch("ddtrace.internal.writer.HTTPWriter._send_payload"):
llmobs_service.enable(_tracer=DummyTracer(), ml_app="test_app")
pid = os.fork()
if pid: # parent
with llmobs_service.task():
pass
assert len(llmobs_service._instance._llmobs_span_writer._encoder) == 1
else: # child
with llmobs_service.workflow():
with llmobs_service.task():
pass
assert len(llmobs_service._instance._llmobs_span_writer._encoder) == 2
llmobs_service.disable()
os._exit(12)

_, status = os.waitpid(pid, 0)
exit_code = os.WEXITSTATUS(status)
assert exit_code == 12
llmobs_service.disable()
Loading