diff --git a/json.go b/json.go index 15d12c8..f8aaa21 100644 --- a/json.go +++ b/json.go @@ -11,16 +11,26 @@ func (l *Logger) jsonFormatter(keyvals ...interface{}) { jw := &jsonWriter{w: &l.b} jw.start() - for i := 0; i < len(keyvals); i += 2 { - l.jsonFormatterKeyVal(jw, keyvals[i], keyvals[i+1]) + i := 0 + for i < len(keyvals) { + switch kv := keyvals[i].(type) { + case slogAttr: + l.jsonFormatterRoot(jw, kv.Key, kv.Value) + i++ + default: + if i+1 < len(keyvals) { + l.jsonFormatterRoot(jw, keyvals[i], keyvals[i+1]) + } + i += 2 + } } jw.end() l.b.WriteRune('\n') } -func (l *Logger) jsonFormatterKeyVal(jw *jsonWriter, anyKey, value any) { - switch anyKey { +func (l *Logger) jsonFormatterRoot(jw *jsonWriter, key, value any) { + switch key { case TimestampKey: if t, ok := value.(time.Time); ok { jw.objectItem(TimestampKey, t.Format(l.timeFormat)) @@ -42,22 +52,43 @@ func (l *Logger) jsonFormatterKeyVal(jw *jsonWriter, anyKey, value any) { jw.objectItem(MessageKey, fmt.Sprint(msg)) } default: - switch k := anyKey.(type) { - case fmt.Stringer: - jw.objectKey(k.String()) - case error: - jw.objectKey(k.Error()) - default: - jw.objectKey(fmt.Sprint(k)) - } - switch v := value.(type) { - case error: - jw.objectValue(v.Error()) - case fmt.Stringer: - jw.objectValue(v.String()) - default: - jw.objectValue(v) + l.jsonFormatterItem(jw, key, value) + } +} + +func (l *Logger) jsonFormatterItem(jw *jsonWriter, key, value any) { + switch k := key.(type) { + case fmt.Stringer: + jw.objectKey(k.String()) + case error: + jw.objectKey(k.Error()) + default: + jw.objectKey(fmt.Sprint(k)) + } + switch v := value.(type) { + case error: + jw.objectValue(v.Error()) + case slogLogValuer: + l.writeSlogValue(jw, v.LogValue()) + case slogValue: + l.writeSlogValue(jw, v.Resolve()) + case fmt.Stringer: + jw.objectValue(v.String()) + default: + jw.objectValue(v) + } +} + +func (l *Logger) writeSlogValue(jw *jsonWriter, v slogValue) { + switch v.Kind() { + case slogKindGroup: + jw.start() + for _, attr := range v.Group() { + l.jsonFormatterItem(jw, attr.Key, attr.Value) } + jw.end() + default: + jw.objectValue(v.Any()) } } diff --git a/logger_121.go b/logger_121.go index cd49a09..73bee9c 100644 --- a/logger_121.go +++ b/logger_121.go @@ -10,6 +10,15 @@ import ( "sync/atomic" ) +// type aliases for slog. +type ( + slogAttr = slog.Attr + slogValue = slog.Value + slogLogValuer = slog.LogValuer +) + +const slogKindGroup = slog.KindGroup + // Enabled reports whether the logger is enabled for the given level. // // Implements slog.Handler. @@ -27,7 +36,7 @@ func (l *Logger) Handle(ctx context.Context, record slog.Record) error { fields := make([]interface{}, 0, record.NumAttrs()*2) record.Attrs(func(a slog.Attr) bool { - fields = append(fields, a.Key, a.Value.String()) + fields = append(fields, a.Key, a.Value) return true }) // Get the caller frame using the record's PC. diff --git a/logger_121_test.go b/logger_121_test.go index d893ebc..1090816 100644 --- a/logger_121_test.go +++ b/logger_121_test.go @@ -6,11 +6,10 @@ package log import ( "bytes" "context" + "log/slog" "testing" "time" - "log/slog" - "github.com/stretchr/testify/assert" ) @@ -183,3 +182,95 @@ func TestSlogCustomLevel(t *testing.T) { }) } } + +type testLogValue struct { + v slog.Value +} + +func (v testLogValue) LogValue() slog.Value { + return v.v +} + +func TestSlogAttr(t *testing.T) { + cases := []struct { + name string + expected string + kvs []interface{} + }{ + { + name: "any", + expected: `{"level":"info","msg":"message","any":42}` + "\n", + kvs: []any{"any", slog.AnyValue(42)}, + }, + { + name: "bool", + expected: `{"level":"info","msg":"message","bool":false}` + "\n", + kvs: []any{"bool", slog.BoolValue(false)}, + }, + { + name: "duration", + expected: `{"level":"info","msg":"message","duration":10800000000000}` + "\n", + kvs: []any{"duration", slog.DurationValue(3 * time.Hour)}, + }, + { + name: "float64", + expected: `{"level":"info","msg":"message","float64":123}` + "\n", + kvs: []any{"float64", slog.Float64Value(123)}, + }, + { + name: "string", + expected: `{"level":"info","msg":"message","string":"hello"}` + "\n", + kvs: []any{"string", slog.StringValue("hello")}, + }, + { + name: "time", + expected: `{"level":"info","msg":"message","_time":"1970-01-01T00:00:00Z"}` + "\n", + kvs: []any{"_time", slog.TimeValue(time.Unix(0, 0).UTC())}, + }, + { + name: "uint64", + expected: `{"level":"info","msg":"message","uint64":42}` + "\n", + kvs: []any{"uint64", slog.Uint64Value(42)}, + }, + { + name: "group", + expected: `{"level":"info","msg":"message","g":{"b":true}}` + "\n", + kvs: []any{slog.Group("g", slog.Bool("b", true))}, + }, + { + name: "log valuer", + expected: `{"level":"info","msg":"message","lv":42}` + "\n", + kvs: []any{ + "lv", testLogValue{slog.AnyValue(42)}, + }, + }, + { + name: "log valuer", + expected: `{"level":"info","msg":"message","lv":{"first":"hello","last":"world"}}` + "\n", + kvs: []any{ + "lv", testLogValue{slog.GroupValue( + slog.String("first", "hello"), + slog.String("last", "world"), + )}, + }, + }, + } + + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + // expect same output from slog and log + var buf bytes.Buffer + l := NewWithOptions(&buf, Options{Formatter: JSONFormatter}) + l.Info("message", c.kvs...) + assert.Equal(t, c.expected, buf.String()) + + buf.Truncate(0) + sl := slog.New(l) + sl.Info("message", c.kvs...) + assert.Equal(t, c.expected, buf.String()) + }) + } +} diff --git a/logger_no121.go b/logger_no121.go index bf2eaf9..bf28cf4 100644 --- a/logger_no121.go +++ b/logger_no121.go @@ -11,6 +11,15 @@ import ( "golang.org/x/exp/slog" ) +// type alises for slog. +type ( + slogAttr = slog.Attr + slogValue = slog.Value + slogLogValuer = slog.LogValuer +) + +const slogKindGroup = slog.KindGroup + // Enabled reports whether the logger is enabled for the given level. // // Implements slog.Handler. @@ -24,7 +33,7 @@ func (l *Logger) Enabled(_ context.Context, level slog.Level) bool { func (l *Logger) Handle(_ context.Context, record slog.Record) error { fields := make([]interface{}, 0, record.NumAttrs()*2) record.Attrs(func(a slog.Attr) bool { - fields = append(fields, a.Key, a.Value.String()) + fields = append(fields, a.Key, a.Value) return true }) // Get the caller frame using the record's PC.