Skip to content

Commit

Permalink
feat(llmobs): add retrieval and embedding spans (#9134)
Browse files Browse the repository at this point in the history
This PR adds support for submitting embedding and retrieval type spans
for LLM Observability, both via `LLMObs.{retrieval/embedding}` and
`@ddtrace.llmobs.decorators.{retrieval/embedding}`.
Additionally, this PR adds a public helper class
`ddtrace.llmobs.utils.Documents` for users to create SDK-compatible
input/output annotation objects for Embedding/Retrieval spans.

Embedding spans require a model name to be set, and also optionally
accept a model provider value (defaults to `custom`). Embedding
spans can be annotated with input strings, dictionaries, or a list of
dictionaries, which will be cast as `Documents` when submitted to
LLMObs. Embedding spans can be annotated with output strings or any JSON
serializable value.

Retrieval spans can be annotated with input strings or any JSON
serializable value. Retrieval spans can also be annotated with output
strings, dictionaries, or a list of dictionaries, which will be cast as
`Documents` when submitted to LLMObs.

This PR also introduces the helper class
`ddtrace.llmobs.utils.Documents`, which can be used to convert arguments
to be tagged as input/output documents. Each resulting `Document`
TypedDict object can contain the following fields:
- `name`: str
- `id`: str
- `text`: str
- `score`: int/float

## Checklist

- [x] Change(s) are motivated and described in the PR description
- [x] Testing strategy is described if automated tests are not included
in the PR
- [x] Risks are described (performance impact, potential for breakage,
maintainability)
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] [Library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
are followed or label `changelog/no-changelog` is set
- [x] Documentation is included (in-code, generated user docs, [public
corp docs](https://github.com/DataDog/documentation/))
- [x] Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))
- [x] If this PR changes the public interface, I've notified
`@DataDog/apm-tees`.

## Reviewer Checklist

- [x] Title is accurate
- [x] All changes are related to the pull request's stated goal
- [x] Description motivates each change
- [x] Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- [x] Testing strategy adequately addresses listed risks
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] Release note makes sense to a user of the library
- [x] Author has acknowledged and discussed the performance implications
of this PR as reported in the benchmarks PR comment
- [x] Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)

(cherry picked from commit d10e081)
  • Loading branch information
Yun-Kim authored and github-actions[bot] committed May 22, 2024
1 parent af1c3e2 commit 8f8291f
Show file tree
Hide file tree
Showing 8 changed files with 512 additions and 43 deletions.
2 changes: 2 additions & 0 deletions ddtrace/llmobs/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
MODEL_NAME = "_ml_obs.meta.model_name"
MODEL_PROVIDER = "_ml_obs.meta.model_provider"

INPUT_DOCUMENTS = "_ml_obs.meta.input.documents"
INPUT_MESSAGES = "_ml_obs.meta.input.messages"
INPUT_VALUE = "_ml_obs.meta.input.value"
INPUT_PARAMETERS = "_ml_obs.meta.input.parameters"

OUTPUT_DOCUMENTS = "_ml_obs.meta.output.documents"
OUTPUT_MESSAGES = "_ml_obs.meta.output.messages"
OUTPUT_VALUE = "_ml_obs.meta.output.value"
114 changes: 114 additions & 0 deletions ddtrace/llmobs/_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ddtrace.internal import atexit
from ddtrace.internal.logger import get_logger
from ddtrace.internal.service import Service
from ddtrace.llmobs._constants import INPUT_DOCUMENTS
from ddtrace.llmobs._constants import INPUT_MESSAGES
from ddtrace.llmobs._constants import INPUT_PARAMETERS
from ddtrace.llmobs._constants import INPUT_VALUE
Expand All @@ -20,6 +21,7 @@
from ddtrace.llmobs._constants import ML_APP
from ddtrace.llmobs._constants import MODEL_NAME
from ddtrace.llmobs._constants import MODEL_PROVIDER
from ddtrace.llmobs._constants import OUTPUT_DOCUMENTS
from ddtrace.llmobs._constants import OUTPUT_MESSAGES
from ddtrace.llmobs._constants import OUTPUT_VALUE
from ddtrace.llmobs._constants import SESSION_ID
Expand All @@ -30,6 +32,7 @@
from ddtrace.llmobs._utils import _get_session_id
from ddtrace.llmobs._writer import LLMObsEvalMetricWriter
from ddtrace.llmobs._writer import LLMObsSpanWriter
from ddtrace.llmobs.utils import Documents
from ddtrace.llmobs.utils import ExportedLLMObsSpan
from ddtrace.llmobs.utils import Messages

Expand Down Expand Up @@ -270,6 +273,64 @@ def workflow(
return None
return cls._instance._start_span("workflow", name=name, session_id=session_id, ml_app=ml_app)

@classmethod
def embedding(
cls,
model_name: str,
name: Optional[str] = None,
model_provider: Optional[str] = None,
session_id: Optional[str] = None,
ml_app: Optional[str] = None,
) -> Optional[Span]:
"""
Trace a call to an embedding model or function to create an embedding.
:param str model_name: The name of the invoked embedding model.
:param str name: The name of the traced operation. If not provided, a default value of "embedding" will be set.
:param str model_provider: The name of the invoked LLM provider (ex: openai, bedrock).
If not provided, a default value of "custom" will be set.
:param str session_id: The ID of the underlying user session. Required for tracking sessions.
:param str ml_app: The name of the ML application that the agent is orchestrating. If not provided, the default
value DD_LLMOBS_APP_NAME will be set.
:returns: The Span object representing the traced operation.
"""
if cls.enabled is False or cls._instance is None:
log.warning("LLMObs.embedding() cannot be used while LLMObs is disabled.")
return None
if not model_name:
log.warning("model_name must be the specified name of the invoked model.")
return None
if model_provider is None:
model_provider = "custom"
return cls._instance._start_span(
"embedding",
name,
model_name=model_name,
model_provider=model_provider,
session_id=session_id,
ml_app=ml_app,
)

@classmethod
def retrieval(
cls, name: Optional[str] = None, session_id: Optional[str] = None, ml_app: Optional[str] = None
) -> Optional[Span]:
"""
Trace a vector search operation involving a list of documents being returned from an external knowledge base.
:param str name: The name of the traced operation. If not provided, a default value of "workflow" will be set.
:param str session_id: The ID of the underlying user session. Required for tracking sessions.
:param str ml_app: The name of the ML application that the agent is orchestrating. If not provided, the default
value DD_LLMOBS_APP_NAME will be set.
:returns: The Span object representing the traced operation.
"""
if cls.enabled is False or cls._instance is None:
log.warning("LLMObs.retrieval() cannot be used while LLMObs is disabled.")
return None
return cls._instance._start_span("retrieval", name=name, session_id=session_id, ml_app=ml_app)

@classmethod
def annotate(
cls,
Expand All @@ -290,10 +351,15 @@ def annotate(
:param input_data: A single input string, dictionary, or a list of dictionaries based on the span kind:
- llm spans: accepts a string, or a dictionary of form {"content": "...", "role": "..."},
or a list of dictionaries with the same signature.
- embedding spans: accepts a string, list of strings, or a dictionary of form
{"text": "...", ...} or a list of dictionaries with the same signature.
- other: any JSON serializable type.
:param output_data: A single output string, dictionary, or a list of dictionaries based on the span kind:
- llm spans: accepts a string, or a dictionary of form {"content": "...", "role": "..."},
or a list of dictionaries with the same signature.
- retrieval spans: a dictionary containing any of the key value pairs
{"name": str, "id": str, "text": str, "score": float},
or a list of dictionaries with the same signature.
- other: any JSON serializable type.
:param parameters: (DEPRECATED) Dictionary of JSON serializable key-value pairs to set as input parameters.
:param metadata: Dictionary of JSON serializable key-value metadata pairs relevant to the input/output operation
Expand Down Expand Up @@ -327,6 +393,10 @@ def annotate(
if input_data or output_data:
if span_kind == "llm":
cls._tag_llm_io(span, input_messages=input_data, output_messages=output_data)
elif span_kind == "embedding":
cls._tag_embedding_io(span, input_documents=input_data, output_text=output_data)
elif span_kind == "retrieval":
cls._tag_retrieval_io(span, input_text=input_data, output_documents=output_data)
else:
cls._tag_text_io(span, input_value=input_data, output_value=output_data)
if metadata is not None:
Expand Down Expand Up @@ -371,6 +441,50 @@ def _tag_llm_io(cls, span, input_messages=None, output_messages=None):
except (TypeError, AttributeError):
log.warning("Failed to parse output messages.", exc_info=True)

@classmethod
def _tag_embedding_io(cls, span, input_documents=None, output_text=None):
"""Tags input documents and output text for embedding-kind spans.
Will be mapped to span's `meta.{input,output}.text` fields.
"""
if input_documents is not None:
try:
if not isinstance(input_documents, Documents):
input_documents = Documents(input_documents)
if input_documents.documents:
span.set_tag_str(INPUT_DOCUMENTS, json.dumps(input_documents.documents))
except (TypeError, AttributeError):
log.warning("Failed to parse input documents.", exc_info=True)
if output_text is not None:
if isinstance(output_text, str):
span.set_tag_str(OUTPUT_VALUE, output_text)
else:
try:
span.set_tag_str(OUTPUT_VALUE, json.dumps(output_text))
except TypeError:
log.warning("Failed to parse output text. Output text must be JSON serializable.")

@classmethod
def _tag_retrieval_io(cls, span, input_text=None, output_documents=None):
"""Tags input text and output documents for retrieval-kind spans.
Will be mapped to span's `meta.{input,output}.text` fields.
"""
if input_text is not None:
if isinstance(input_text, str):
span.set_tag_str(INPUT_VALUE, input_text)
else:
try:
span.set_tag_str(INPUT_VALUE, json.dumps(input_text))
except TypeError:
log.warning("Failed to parse input text. Input text must be JSON serializable.")
if output_documents is not None:
try:
if not isinstance(output_documents, Documents):
output_documents = Documents(output_documents)
if output_documents.documents:
span.set_tag_str(OUTPUT_DOCUMENTS, json.dumps(output_documents.documents))
except (TypeError, AttributeError):
log.warning("Failed to parse output documents.", exc_info=True)

@classmethod
def _tag_text_io(cls, span, input_value=None, output_value=None):
"""Tags input/output values for non-LLM kind spans.
Expand Down
8 changes: 7 additions & 1 deletion ddtrace/llmobs/_trace_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ddtrace.ext import SpanTypes
from ddtrace.internal.logger import get_logger
from ddtrace.internal.utils.formats import asbool
from ddtrace.llmobs._constants import INPUT_DOCUMENTS
from ddtrace.llmobs._constants import INPUT_MESSAGES
from ddtrace.llmobs._constants import INPUT_PARAMETERS
from ddtrace.llmobs._constants import INPUT_VALUE
Expand All @@ -23,6 +24,7 @@
from ddtrace.llmobs._constants import ML_APP
from ddtrace.llmobs._constants import MODEL_NAME
from ddtrace.llmobs._constants import MODEL_PROVIDER
from ddtrace.llmobs._constants import OUTPUT_DOCUMENTS
from ddtrace.llmobs._constants import OUTPUT_MESSAGES
from ddtrace.llmobs._constants import OUTPUT_VALUE
from ddtrace.llmobs._constants import SESSION_ID
Expand Down Expand Up @@ -65,7 +67,7 @@ def _llmobs_span_event(self, span: Span) -> Dict[str, Any]:
"""Span event object structure."""
span_kind = span._meta.pop(SPAN_KIND)
meta: Dict[str, Any] = {"span.kind": span_kind, "input": {}, "output": {}}
if span_kind == "llm" and span.get_tag(MODEL_NAME) is not None:
if span_kind in ("llm", "embedding") and span.get_tag(MODEL_NAME) is not None:
meta["model_name"] = span._meta.pop(MODEL_NAME)
meta["model_provider"] = span._meta.pop(MODEL_PROVIDER, "custom").lower()
if span.get_tag(METADATA) is not None:
Expand All @@ -78,8 +80,12 @@ def _llmobs_span_event(self, span: Span) -> Dict[str, Any]:
meta["input"]["value"] = span._meta.pop(INPUT_VALUE)
if span_kind == "llm" and span.get_tag(OUTPUT_MESSAGES) is not None:
meta["output"]["messages"] = json.loads(span._meta.pop(OUTPUT_MESSAGES))
if span_kind == "embedding" and span.get_tag(INPUT_DOCUMENTS) is not None:
meta["input"]["documents"] = json.loads(span._meta.pop(INPUT_DOCUMENTS))
if span.get_tag(OUTPUT_VALUE) is not None:
meta["output"]["value"] = span._meta.pop(OUTPUT_VALUE)
if span_kind == "retrieval" and span.get_tag(OUTPUT_DOCUMENTS) is not None:
meta["output"]["documents"] = json.loads(span._meta.pop(OUTPUT_DOCUMENTS))
if span.error:
meta[ERROR_MSG] = span.get_tag(ERROR_MSG)
meta[ERROR_STACK] = span.get_tag(ERROR_STACK)
Expand Down
65 changes: 38 additions & 27 deletions ddtrace/llmobs/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,42 @@
log = get_logger(__name__)


def llm(
model_name: str,
model_provider: Optional[str] = None,
name: Optional[str] = None,
session_id: Optional[str] = None,
ml_app: Optional[str] = None,
):
def inner(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not LLMObs.enabled or LLMObs._instance is None:
log.warning("LLMObs.llm() cannot be used while LLMObs is disabled.")
return func(*args, **kwargs)
span_name = name
if span_name is None:
span_name = func.__name__
with LLMObs.llm(
model_name=model_name,
model_provider=model_provider,
name=span_name,
session_id=session_id,
ml_app=ml_app,
):
return func(*args, **kwargs)
def _model_decorator(operation_kind):
def decorator(
model_name: str,
original_func: Optional[Callable] = None,
model_provider: Optional[str] = None,
name: Optional[str] = None,
session_id: Optional[str] = None,
ml_app: Optional[str] = None,
):
def inner(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not LLMObs.enabled or LLMObs._instance is None:
log.warning("LLMObs.%s() cannot be used while LLMObs is disabled.", operation_kind)
return func(*args, **kwargs)
traced_model_name = model_name
if traced_model_name is None:
raise TypeError("model_name is required for LLMObs.{}()".format(operation_kind))
span_name = name
if span_name is None:
span_name = func.__name__
traced_operation = getattr(LLMObs, operation_kind, "llm")
with traced_operation(
model_name=model_name,
model_provider=model_provider,
name=span_name,
session_id=session_id,
ml_app=ml_app,
):
return func(*args, **kwargs)

return wrapper
return wrapper

return inner

return inner
return decorator


def _llmobs_decorator(operation_kind):
Expand All @@ -50,7 +58,7 @@ def inner(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not LLMObs.enabled or LLMObs._instance is None:
log.warning("LLMObs.{}() cannot be used while LLMObs is disabled.", operation_kind)
log.warning("LLMObs.%s() cannot be used while LLMObs is disabled.", operation_kind)
return func(*args, **kwargs)
span_name = name
if span_name is None:
Expand All @@ -68,7 +76,10 @@ def wrapper(*args, **kwargs):
return decorator


llm = _model_decorator("llm")
embedding = _model_decorator("embedding")
workflow = _llmobs_decorator("workflow")
task = _llmobs_decorator("task")
tool = _llmobs_decorator("tool")
retrieval = _llmobs_decorator("retrieval")
agent = _llmobs_decorator("agent")
34 changes: 34 additions & 0 deletions ddtrace/llmobs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


ExportedLLMObsSpan = TypedDict("ExportedLLMObsSpan", {"span_id": str, "trace_id": str})
Document = TypedDict("Document", {"name": str, "id": str, "text": str, "score": float}, total=False)
Message = TypedDict("Message", {"content": str, "role": str}, total=False)


Expand All @@ -40,3 +41,36 @@ def __init__(self, messages: Union[List[Dict[str, str]], Dict[str, str], str]):
if not isinstance(role, str):
raise TypeError("Message role must be a string, and one of .")
self.messages.append(Message(content=content, role=role))


class Documents:
def __init__(self, documents: Union[List[Dict[str, str]], Dict[str, str], str]):
self.documents = []
if not isinstance(documents, list):
documents = [documents] # type: ignore[list-item]
for document in documents:
if isinstance(document, str):
self.documents.append(Document(text=document))
continue
elif not isinstance(document, dict):
raise TypeError("documents must be a string, dictionary, or list of dictionaries.")
document_text = document.get("text")
document_name = document.get("name")
document_id = document.get("id")
document_score = document.get("score")
if not isinstance(document_text, str):
raise TypeError("Document text must be a string.")
formatted_document = Document(text=document_text)
if document_name:
if not isinstance(document_name, str):
raise TypeError("document name must be a string.")
formatted_document["name"] = document_name
if document_id:
if not isinstance(document_id, str):
raise TypeError("document id must be a string.")
formatted_document["id"] = document_id
if document_score:
if not isinstance(document_score, (int, float)):
raise TypeError("document score must be an integer or float.")
formatted_document["score"] = document_score
self.documents.append(formatted_document)
Loading

0 comments on commit 8f8291f

Please sign in to comment.