diff --git a/ddtrace/llmobs/__init__.py b/ddtrace/llmobs/__init__.py index 2a8ad0cb41c..11100d3ed66 100644 --- a/ddtrace/llmobs/__init__.py +++ b/ddtrace/llmobs/__init__.py @@ -5,6 +5,7 @@ from ddtrace.llmobs import LLMObs LLMObs.enable() """ + from ._llmobs import LLMObs diff --git a/ddtrace/llmobs/_llmobs.py b/ddtrace/llmobs/_llmobs.py index 08dc7b83968..262e99a72e6 100644 --- a/ddtrace/llmobs/_llmobs.py +++ b/ddtrace/llmobs/_llmobs.py @@ -11,6 +11,7 @@ from ddtrace import patch from ddtrace.ext import SpanTypes from ddtrace.internal import atexit +from ddtrace.internal import forksafe from ddtrace.internal import telemetry from ddtrace.internal.compat import ensure_text from ddtrace.internal.logger import get_logger @@ -82,11 +83,22 @@ def __init__(self, tracer=None): interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)), timeout=float(os.getenv("_DD_LLMOBS_WRITER_TIMEOUT", 5.0)), ) + self._trace_processor = LLMObsTraceProcessor(self._llmobs_span_writer) + forksafe.register(self._child_after_fork) + + def _child_after_fork(self): + self._llmobs_span_writer = self._llmobs_span_writer.recreate() + self._trace_processor._span_writer = self._llmobs_span_writer + self.tracer.configure(settings={"FILTERS": [self._trace_processor]}) + try: + self._llmobs_span_writer.start() + except ServiceStatusError: + log.debug("Error starting LLMObs span writer after fork") def _start_service(self) -> None: tracer_filters = self.tracer._filters if not any(isinstance(tracer_filter, LLMObsTraceProcessor) for tracer_filter in tracer_filters): - tracer_filters += [LLMObsTraceProcessor(self._llmobs_span_writer)] + tracer_filters += [self._trace_processor] self.tracer.configure(settings={"FILTERS": tracer_filters}) try: self._llmobs_span_writer.start() @@ -102,6 +114,7 @@ def _stop_service(self) -> None: log.debug("Error stopping LLMObs writers") try: + forksafe.unregister(self._child_after_fork) self.tracer.shutdown() except Exception: log.warning("Failed to shutdown tracer", exc_info=True) diff --git a/ddtrace/llmobs/_writer.py b/ddtrace/llmobs/_writer.py index bed17d5d19e..abcf933a1f0 100644 --- a/ddtrace/llmobs/_writer.py +++ b/ddtrace/llmobs/_writer.py @@ -292,6 +292,7 @@ def recreate(self): return self.__class__( interval=self._interval, timeout=self._timeout, + is_agentless=config._llmobs_agentless_enabled, ) diff --git a/releasenotes/notes/fix-llmobs-forked-writer-257b993bcf131af8.yaml b/releasenotes/notes/fix-llmobs-forked-writer-257b993bcf131af8.yaml new file mode 100644 index 00000000000..4d5ce7dc305 --- /dev/null +++ b/releasenotes/notes/fix-llmobs-forked-writer-257b993bcf131af8.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + LLM Observability: This fix resolves an issue where LLM Observability spans were not being submitted in forked processes, + such as when using ``celery`` or ``gunicorn`` workers. The LLM Observability writer thread now automatically restarts + when a forked process is detected. \ No newline at end of file diff --git a/tests/llmobs/test_llmobs_service.py b/tests/llmobs/test_llmobs_service.py index 74011e1c2e0..8f302fb1c05 100644 --- a/tests/llmobs/test_llmobs_service.py +++ b/tests/llmobs/test_llmobs_service.py @@ -1,4 +1,5 @@ import json +import os import mock import pytest @@ -7,6 +8,7 @@ from ddtrace._trace.context import Context from ddtrace._trace.span import Span from ddtrace.ext import SpanTypes +from ddtrace.internal.service import ServiceStatus from ddtrace.llmobs import LLMObs as llmobs_service from ddtrace.llmobs._constants import INPUT_DOCUMENTS from ddtrace.llmobs._constants import INPUT_MESSAGES @@ -1342,3 +1344,70 @@ def test_activate_distributed_headers_activates_context(LLMObs, mock_logs): LLMObs.activate_distributed_headers({}) assert mock_extract.call_count == 1 mock_activate.assert_called_once_with(dummy_context) + + +def _task(llmobs_service, errors, original_pid, original_span_writer_id): + """Task in test_llmobs_fork which asserts that LLMObs in a forked process correctly recreates the writer.""" + try: + with llmobs_service.workflow(): + with llmobs_service.task(): + assert llmobs_service._instance.tracer._pid != original_pid + assert id(llmobs_service._instance._llmobs_span_writer) != original_span_writer_id + assert llmobs_service._instance._llmobs_span_writer.enqueue.call_count == 2 + assert llmobs_service._instance._llmobs_span_writer._encoder.encode.call_count == 2 + except AssertionError as e: + errors.put(e) + + +def test_llmobs_fork_recreates_and_restarts_writer(): + """Test that forking a process correctly recreates and restarts the LLMObsSpanWriter.""" + with mock.patch("ddtrace.internal.writer.HTTPWriter._send_payload"): + llmobs_service.enable(_tracer=DummyTracer(), ml_app="test_app") + original_pid = llmobs_service._instance.tracer._pid + original_span_writer = llmobs_service._instance._llmobs_span_writer + pid = os.fork() + if pid: # parent + assert llmobs_service._instance.tracer._pid == original_pid + assert llmobs_service._instance._llmobs_span_writer == original_span_writer + assert ( + llmobs_service._instance._trace_processor._span_writer == llmobs_service._instance._llmobs_span_writer + ) + assert llmobs_service._instance._llmobs_span_writer.status == ServiceStatus.RUNNING + else: # child + assert llmobs_service._instance.tracer._pid != original_pid + assert llmobs_service._instance._llmobs_span_writer != original_span_writer + assert ( + llmobs_service._instance._trace_processor._span_writer == llmobs_service._instance._llmobs_span_writer + ) + assert llmobs_service._instance._llmobs_span_writer.status == ServiceStatus.RUNNING + llmobs_service.disable() + os._exit(12) + + _, status = os.waitpid(pid, 0) + exit_code = os.WEXITSTATUS(status) + assert exit_code == 12 + llmobs_service.disable() + + +def test_llmobs_fork_create_span(monkeypatch): + """Test that forking a process correctly encodes new spans created in each process.""" + monkeypatch.setenv("_DD_LLMOBS_WRITER_INTERVAL", 5.0) + with mock.patch("ddtrace.internal.writer.HTTPWriter._send_payload"): + llmobs_service.enable(_tracer=DummyTracer(), ml_app="test_app") + pid = os.fork() + if pid: # parent + with llmobs_service.task(): + pass + assert len(llmobs_service._instance._llmobs_span_writer._encoder) == 1 + else: # child + with llmobs_service.workflow(): + with llmobs_service.task(): + pass + assert len(llmobs_service._instance._llmobs_span_writer._encoder) == 2 + llmobs_service.disable() + os._exit(12) + + _, status = os.waitpid(pid, 0) + exit_code = os.WEXITSTATUS(status) + assert exit_code == 12 + llmobs_service.disable()