Skip to content

Commit

Permalink
refactor: add visitor pattern for tracing fields
Browse files Browse the repository at this point in the history
By adding a visitor pattern, we can further customize the output of
each event. For example, the `message` field was never actually
being formatted with the provided `message_name` field.

Also, this fixes an issue that occurred when more than one span was nested:
fields from any previous (outer) spans were being ignored and not formatted
in the final output.
  • Loading branch information
cmackenzie1 committed Oct 9, 2023
1 parent 2806036 commit ae8b401
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 38 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ serde = "1.0.188"
serde_json = "1.0.107"
thiserror = "1.0.49"
tracing-core = "0.1.31"
tracing-serde = "0.1.3"
tracing-subscriber = "0.3.17"

[dev-dependencies]
Expand Down
105 changes: 72 additions & 33 deletions src/formatter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@ use std::fmt;
use serde::{ser::SerializeMap, Serializer};

use tracing_core::{Event, Subscriber};
use tracing_serde::fields::AsMap;
use tracing_subscriber::{
fmt::{format, FmtContext, FormatEvent, FormatFields, FormattedFields},
registry::LookupSpan,
};

use crate::Error;
use crate::{visitor, Error};

const DEFAULT_TIMESTAMP_FORMAT: crate::TimestampFormat = crate::TimestampFormat::Rfc3339;
const DEFAULT_LEVEL_NAME: &str = "level";
const DEFAULT_MESSAGE_NAME: &str = "message";
const DEFAULT_TARGET_NAME: &str = "target";
const DEFAULT_TIMESTAMP_NAME: &str = "timestamp";
const DEFAULT_FLATTEN_FIELDS: bool = true;

/// A JSON formatter for tracing events.
/// This is used to format the event field in the JSON output.
Expand All @@ -25,12 +31,12 @@ pub struct JsonEventFormatter {
impl Default for JsonEventFormatter {
fn default() -> Self {
Self {
level_name: "level",
message_name: "message",
target_name: "target",
timestamp_name: "timestamp",
timestamp_format: crate::TimestampFormat::Rfc3339,
flatten_fields: true,
level_name: DEFAULT_LEVEL_NAME,
message_name: DEFAULT_MESSAGE_NAME,
target_name: DEFAULT_TARGET_NAME,
timestamp_name: DEFAULT_TIMESTAMP_NAME,
timestamp_format: DEFAULT_TIMESTAMP_FORMAT,
flatten_fields: DEFAULT_FLATTEN_FIELDS,
}
}
}
Expand Down Expand Up @@ -84,8 +90,7 @@ where
) -> fmt::Result {
let now = chrono::Utc::now();

let mut buffer = Vec::new();
let mut binding = serde_json::Serializer::new(&mut buffer);
let mut binding = serde_json::Serializer::new(Vec::new());
let mut serializer = binding.serialize_map(None).map_err(Error::Serde)?;

serializer
Expand Down Expand Up @@ -118,42 +123,74 @@ where
.serialize_entry(self.target_name, event.metadata().target())
.map_err(Error::Serde)?;

let msg_name = if self.message_name != DEFAULT_MESSAGE_NAME {
Some(self.message_name)
} else {
None
};

if self.flatten_fields {
let mut visitor = tracing_serde::SerdeMapVisitor::new(serializer);
// record fields in the top-level map
let mut visitor = visitor::Visitor::new(&mut serializer, msg_name);
event.record(&mut visitor);

serializer = visitor.take_serializer().map_err(|_| Error::Unknown)?;
visitor.finish().map_err(Error::Serde)?;
} else {
serializer
.serialize_entry("fields", &event.field_map())
.map_err(Error::Serde)?;
};
// record fields in a nested map under the key "fields"
let mut binding = serde_json::Serializer::pretty(Vec::new());
let mut field_serializer = binding.serialize_map(None).map_err(Error::Serde)?;
let mut visitor = visitor::Visitor::new(&mut field_serializer, msg_name);
event.record(&mut visitor);
visitor.finish().map_err(Error::Serde)?;
field_serializer.end().map_err(Error::Serde)?;

// Add the new map to the top-level map
let obj: Option<serde_json::Value> = serde_json::from_str(
std::str::from_utf8(&binding.into_inner()).map_err(Error::Utf8)?,
)
.ok();
if matches!(obj, Some(serde_json::Value::Object(_))) {
let obj = obj.expect("matched object");
serializer
.serialize_entry("fields", &obj)
.map_err(Error::Serde)?;
}
}

// Write all fields from spans
if let Some(leaf_span) = ctx.lookup_current() {
for span in leaf_span.scope().from_root() {
let ext = span.extensions();
let data = ext
let formatted_fields = ext
.get::<FormattedFields<N>>()
.expect("Unable to find FormattedFields in extensions; this is a bug");

if !data.is_empty() {
let obj: Option<serde_json::Value> = serde_json::from_str(data.as_str()).ok();
if matches!(obj, Some(serde_json::Value::Object(_))) {
let obj = obj.expect("matched object");
for (key, value) in obj.as_object().unwrap() {
serializer
.serialize_entry(key, value)
.map_err(Error::Serde)?;
// formatted_fields actually contains multiple ndjson objects, one for every time a span's fields are formatted.
// Re-parse these into JSON for serialization into the final map. Any fields redefined in a subsequent span
// will overwrite the previous value.
// TODO(cmackenzie1): There has to be a better way to do this.
for data in formatted_fields.split('\n') {
if !data.is_empty() {
let obj: Option<serde_json::Value> = serde_json::from_str(data).ok();
if matches!(obj, Some(serde_json::Value::Object(_))) {
let obj = obj.expect("matched object");
for (key, value) in obj.as_object().unwrap() {
serializer
.serialize_entry(key, value)
.map_err(Error::Serde)?;
}
}
}
}
}
}

serializer.end().map_err(Error::Serde)?;
writer.write_str(std::str::from_utf8(&buffer).map_err(Error::Utf8)?)?;
writer.write_char('\n')?;

writeln!(
writer,
"{}",
std::str::from_utf8(&binding.into_inner()).map_err(Error::Utf8)?
)?;

Ok(())
}
Expand All @@ -178,16 +215,18 @@ impl<'writer> FormatFields<'writer> for FieldsFormatter {
where
R: tracing_subscriber::field::RecordFields,
{
let mut buffer = Vec::new();
let mut binding = serde_json::Serializer::new(&mut buffer);
let mut binding = serde_json::Serializer::new(Vec::new());
let mut serializer = binding.serialize_map(None).map_err(Error::Serde)?;
let mut visitor = tracing_serde::SerdeMapVisitor::new(serializer);
let mut visitor = visitor::Visitor::new(&mut serializer, None);

fields.record(&mut visitor);

serializer = visitor.take_serializer().map_err(|_| Error::Unknown)?;
serializer.end().map_err(Error::Serde)?;
writer.write_str(std::str::from_utf8(&buffer).map_err(Error::Utf8)?)?;
writeln!(
writer,
"{}",
std::str::from_utf8(&binding.into_inner()).map_err(Error::Utf8)?
)?;

Ok(())
}
Expand Down
43 changes: 39 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
//!
//! Licensed under MIT license [LICENSE](./LICENSE)
mod formatter;
mod visitor;

use tracing_core::Subscriber;
use tracing_subscriber::fmt::{Layer, SubscriberBuilder};
Expand Down Expand Up @@ -97,8 +98,6 @@ enum Error {
Serde(#[from] serde_json::Error),
#[error("utf8 error: {0}")]
Utf8(#[from] std::str::Utf8Error),
#[error("unknown error")]
Unknown,
}

impl From<Error> for std::fmt::Error {
Expand Down Expand Up @@ -235,9 +234,17 @@ mod tests {

use super::*;

use tracing::{debug, error, info, info_span, trace, warn};
use tracing::{debug, error, info, info_span, instrument, trace, warn};
use tracing_subscriber::prelude::*;

#[instrument]
fn some_function(a: u32, b: u32) {
    // Enter an explicit `some_span` carrying both arguments as fields and emit
    // a single info event from inside it.
    info_span!("some_span", a = a, b = b).in_scope(|| {
        info!("some message from inside a span");
    });
}

#[test]
fn test_json_event_formatter() {
let subscriber = tracing_subscriber::registry().with(builder().layer());
Expand Down Expand Up @@ -270,9 +277,10 @@ mod tests {
let subscriber = tracing_subscriber::registry().with(
builder()
.with_level_name("severity")
.with_message_name("message")
.with_message_name("msg")
.with_timestamp_name("ts")
.with_timestamp_format(TimestampFormat::Unix)
.with_flatten_fields(false)
.layer(),
);

Expand Down Expand Up @@ -301,4 +309,31 @@ mod tests {
});
});
}

// Exercises the formatter with events emitted from one and two levels of
// explicit span nesting, plus a call into the `#[instrument]`ed `some_function`.
// NOTE(review): there are no assertions here; as written the test appears to
// only check that formatting runs to completion without panicking. Confirm
// whether the output should be captured and checked.
#[test]
fn test_nested_spans() {
let subscriber = tracing_subscriber::registry().with(builder().layer());

tracing::subscriber::with_default(subscriber, || {
// Outer span: two dotted field names plus `later`, declared with an empty value.
let span = info_span!(
"test_span",
person.firstname = "cole",
person.lastname = "mackenzie",
later = tracing::field::Empty,
);
span.in_scope(|| {
// Event emitted with only the outer span entered.
info!("some message from inside a info_span");
let inner = info_span!("inner_span", a = "b", c = "d", inner_span = true);
inner.in_scope(|| {
// Event emitted with both spans entered; it carries its own `later`
// field, the same name the outer span declared as empty.
info!(
inner_span_field = true,
later = "populated from inside a span",
"some message from inside a info_span",
);
});
});

// Called after `span.in_scope` has returned, so outside the spans created above.
some_function(1, 2);
});
}
}
107 changes: 107 additions & 0 deletions src/visitor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use tracing_core::field::Visit;

/// A [`Visit`] implementation that writes every recorded field as an entry of
/// a borrowed serde map serializer.
///
/// The `Visit` callbacks return nothing, so a serialization failure cannot be
/// propagated directly; it is parked in `state` and handed back by `finish`.
pub struct Visitor<'a, W: serde::ser::SerializeMap> {
    /// Map serializer that receives one entry per recorded field.
    serializer: &'a mut W,
    /// `Ok(())` until a `serialize_entry` call fails; afterwards holds that error.
    state: Result<(), W::Error>,
    /// When set, the key written in place of `"message"` for the message field.
    overwrite_message_name: Option<&'static str>,
}

impl<'a, W> Visitor<'a, W>
where
W: serde::ser::SerializeMap,
{
pub fn new(serializer: &'a mut W, overwrite_message_name: Option<&'static str>) -> Self {
Self {
serializer,
state: Ok(()),
overwrite_message_name,
}
}

pub fn finish(self) -> Result<(), W::Error> {
self.state
}
}

// Private helpers shared by the `Visit` callbacks below.
impl<'a, W> Visitor<'a, W>
where
    W: serde::ser::SerializeMap,
{
    /// Returns the map key for `field`: the configured override when one is
    /// set and the field is named `"message"`, otherwise the field's own name.
    fn entry_key(&self, field: &tracing_core::Field) -> &'static str {
        match self.overwrite_message_name {
            Some(name) if field.name() == "message" => name,
            _ => field.name(),
        }
    }

    /// Writes `key: value` into the map unless an earlier entry already failed.
    ///
    /// Once an error is stored in `state`, every later entry is skipped so
    /// that `finish` reports the first failure.
    fn record_entry<V>(&mut self, key: &str, value: &V)
    where
        V: serde::Serialize + ?Sized,
    {
        if self.state.is_ok() {
            self.state = self.serializer.serialize_entry(key, value);
        }
    }
}

/// Records each visited field as one entry of the wrapped map serializer.
///
/// Numeric and boolean values are serialized as-is under the field's name.
/// String and debug-formatted values additionally honour the message-name
/// override; errors are serialized via their `Display` output under the
/// field's own name.
impl<'a, W> Visit for Visitor<'a, W>
where
    W: serde::ser::SerializeMap,
{
    fn record_f64(&mut self, field: &tracing_core::Field, value: f64) {
        self.record_entry(field.name(), &value);
    }

    fn record_i64(&mut self, field: &tracing_core::Field, value: i64) {
        self.record_entry(field.name(), &value);
    }

    fn record_u64(&mut self, field: &tracing_core::Field, value: u64) {
        self.record_entry(field.name(), &value);
    }

    fn record_i128(&mut self, field: &tracing_core::Field, value: i128) {
        self.record_entry(field.name(), &value);
    }

    fn record_u128(&mut self, field: &tracing_core::Field, value: u128) {
        self.record_entry(field.name(), &value);
    }

    fn record_bool(&mut self, field: &tracing_core::Field, value: bool) {
        self.record_entry(field.name(), &value);
    }

    fn record_str(&mut self, field: &tracing_core::Field, value: &str) {
        let key = self.entry_key(field);
        self.record_entry(key, value);
    }

    fn record_error(
        &mut self,
        field: &tracing_core::Field,
        value: &(dyn std::error::Error + 'static),
    ) {
        // `fmt::Arguments` implements `Serialize` (as a string), so the
        // `Display` output is written without building a `String` first.
        self.record_entry(field.name(), &format_args!("{}", value));
    }

    fn record_debug(&mut self, field: &tracing_core::Field, value: &dyn std::fmt::Debug) {
        let key = self.entry_key(field);
        // Formatting is lazy: `value`'s `Debug` impl only runs if the entry is
        // actually serialized, i.e. when no earlier error has been stored.
        self.record_entry(key, &format_args!("{:?}", value));
    }
}

0 comments on commit ae8b401

Please sign in to comment.