From 88da778ba284dd5901d091e861681da03aa44d64 Mon Sep 17 00:00:00 2001 From: xiehuc Date: Sun, 3 Dec 2023 00:58:20 +0800 Subject: [PATCH] improve tracestate performance (#4722) * improve tracestate performance * use string.Builder to directly construct the result * reduce the redundant copying during Insert * avoid using regex * fix lint * revert changelog * update comment * refine code * fix lint * fix unittest * Update trace/tracestate.go Co-authored-by: Tyler Yahn --------- Co-authored-by: Chester Cheung Co-authored-by: Tyler Yahn --- CHANGELOG.md | 1 + trace/tracestate.go | 197 +++++++++++++++++++++------- trace/tracestate_benchkmark_test.go | 58 ++++++++ trace/tracestate_test.go | 154 +++++++++++----------- 4 files changed, 289 insertions(+), 121 deletions(-) create mode 100644 trace/tracestate_benchkmark_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index f6a05cca4a0..32b9d6b9636 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Changed +- Improve `go.opentelemetry.io/otel/trace.TraceState`'s performance. (#4722) - Improve `go.opentelemetry.io/otel/propagation.TraceContext`'s performance. (#4721) ### Added diff --git a/trace/tracestate.go b/trace/tracestate.go index d1e47ca2faa..db936ba5b73 100644 --- a/trace/tracestate.go +++ b/trace/tracestate.go @@ -17,20 +17,14 @@ package trace // import "go.opentelemetry.io/otel/trace" import ( "encoding/json" "fmt" - "regexp" "strings" ) const ( maxListMembers = 32 - listDelimiter = "," - - // based on the W3C Trace Context specification, see - // https://www.w3.org/TR/trace-context-1/#tracestate-header - noTenantKeyFormat = `[a-z][_0-9a-z\-\*\/]*` - withTenantKeyFormat = `[a-z0-9][_0-9a-z\-\*\/]*@[a-z][_0-9a-z\-\*\/]*` - valueFormat = `[\x20-\x2b\x2d-\x3c\x3e-\x7e]*[\x21-\x2b\x2d-\x3c\x3e-\x7e]` + listDelimiters = "," + memberDelimiter = "=" errInvalidKey errorConst = "invalid tracestate key" errInvalidValue errorConst = "invalid tracestate value" @@ -39,43 +33,128 @@ const ( errDuplicate errorConst = "duplicate list-member in tracestate" ) -var ( - noTenantKeyRe = regexp.MustCompile(`^` + noTenantKeyFormat + `$`) - withTenantKeyRe = regexp.MustCompile(`^` + withTenantKeyFormat + `$`) - valueRe = regexp.MustCompile(`^` + valueFormat + `$`) - memberRe = regexp.MustCompile(`^\s*((?:` + noTenantKeyFormat + `)|(?:` + withTenantKeyFormat + `))=(` + valueFormat + `)\s*$`) -) - type member struct { Key string Value string } -func newMember(key, value string) (member, error) { - if len(key) > 256 { - return member{}, fmt.Errorf("%w: %s", errInvalidKey, key) +// according to (chr = %x20 / (nblk-char = %x21-2B / %x2D-3C / %x3E-7E) ) +// means (chr = %x20-2B / %x2D-3C / %x3E-7E) . +func checkValueChar(v byte) bool { + return v >= '\x20' && v <= '\x7e' && v != '\x2c' && v != '\x3d' +} + +// according to (nblk-chr = %x21-2B / %x2D-3C / %x3E-7E) . +func checkValueLast(v byte) bool { + return v >= '\x21' && v <= '\x7e' && v != '\x2c' && v != '\x3d' +} + +// based on the W3C Trace Context specification +// +// value = (0*255(chr)) nblk-chr +// nblk-chr = %x21-2B / %x2D-3C / %x3E-7E +// chr = %x20 / nblk-chr +// +// see https://www.w3.org/TR/trace-context-1/#value +func checkValue(val string) bool { + n := len(val) + if n == 0 || n > 256 { + return false + } + for i := 0; i < n-1; i++ { + if !checkValueChar(val[i]) { + return false + } } - if !noTenantKeyRe.MatchString(key) { - if !withTenantKeyRe.MatchString(key) { - return member{}, fmt.Errorf("%w: %s", errInvalidKey, key) + return checkValueLast(val[n-1]) +} + +func checkKeyRemain(key string) bool { + // ( lcalpha / DIGIT / "_" / "-"/ "*" / "/" ) + for _, v := range key { + if isAlphaNum(byte(v)) { + continue } - atIndex := strings.LastIndex(key, "@") - if atIndex > 241 || len(key)-1-atIndex > 14 { - return member{}, fmt.Errorf("%w: %s", errInvalidKey, key) + switch v { + case '_', '-', '*', '/': + continue } + return false + } + return true +} + +// according to +// +// simple-key = lcalpha (0*255( lcalpha / DIGIT / "_" / "-"/ "*" / "/" )) +// system-id = lcalpha (0*13( lcalpha / DIGIT / "_" / "-"/ "*" / "/" )) +// +// param n is remain part length, should be 255 in simple-key or 13 in system-id. +func checkKeyPart(key string, n int) bool { + if len(key) == 0 { + return false + } + first := key[0] // key's first char + ret := len(key[1:]) <= n + ret = ret && first >= 'a' && first <= 'z' + return ret && checkKeyRemain(key[1:]) +} + +func isAlphaNum(c byte) bool { + if c >= 'a' && c <= 'z' { + return true } - if len(value) > 256 || !valueRe.MatchString(value) { - return member{}, fmt.Errorf("%w: %s", errInvalidValue, value) + return c >= '0' && c <= '9' +} + +// according to +// +// tenant-id = ( lcalpha / DIGIT ) 0*240( lcalpha / DIGIT / "_" / "-"/ "*" / "/" ) +// +// param n is remain part length, should be 240 exactly. +func checkKeyTenant(key string, n int) bool { + if len(key) == 0 { + return false + } + return isAlphaNum(key[0]) && len(key[1:]) <= n && checkKeyRemain(key[1:]) +} + +// based on the W3C Trace Context specification +// +// key = simple-key / multi-tenant-key +// simple-key = lcalpha (0*255( lcalpha / DIGIT / "_" / "-"/ "*" / "/" )) +// multi-tenant-key = tenant-id "@" system-id +// tenant-id = ( lcalpha / DIGIT ) (0*240( lcalpha / DIGIT / "_" / "-"/ "*" / "/" )) +// system-id = lcalpha (0*13( lcalpha / DIGIT / "_" / "-"/ "*" / "/" )) +// lcalpha = %x61-7A ; a-z +// +// see https://www.w3.org/TR/trace-context-1/#tracestate-header. +func checkKey(key string) bool { + tenant, system, ok := strings.Cut(key, "@") + if !ok { + return checkKeyPart(key, 255) + } + return checkKeyTenant(tenant, 240) && checkKeyPart(system, 13) +} + +func newMember(key, value string) (member, error) { + if !checkKey(key) { + return member{}, errInvalidKey + } + if !checkValue(value) { + return member{}, errInvalidValue } return member{Key: key, Value: value}, nil } func parseMember(m string) (member, error) { - matches := memberRe.FindStringSubmatch(m) - if len(matches) != 3 { + key, val, ok := strings.Cut(m, memberDelimiter) + if !ok { return member{}, fmt.Errorf("%w: %s", errInvalidMember, m) } - result, e := newMember(matches[1], matches[2]) + key = strings.TrimLeft(key, " \t") + val = strings.TrimRight(val, " \t") + result, e := newMember(key, val) if e != nil { return member{}, fmt.Errorf("%w: %s", errInvalidMember, m) } @@ -85,7 +164,7 @@ func parseMember(m string) (member, error) { // String encodes member into a string compliant with the W3C Trace Context // specification. func (m member) String() string { - return fmt.Sprintf("%s=%s", m.Key, m.Value) + return m.Key + "=" + m.Value } // TraceState provides additional vendor-specific trace identification @@ -109,8 +188,8 @@ var _ json.Marshaler = TraceState{} // ParseTraceState attempts to decode a TraceState from the passed // string. It returns an error if the input is invalid according to the W3C // Trace Context specification. -func ParseTraceState(tracestate string) (TraceState, error) { - if tracestate == "" { +func ParseTraceState(ts string) (TraceState, error) { + if ts == "" { return TraceState{}, nil } @@ -120,7 +199,9 @@ func ParseTraceState(tracestate string) (TraceState, error) { var members []member found := make(map[string]struct{}) - for _, memberStr := range strings.Split(tracestate, listDelimiter) { + for ts != "" { + var memberStr string + memberStr, ts, _ = strings.Cut(ts, listDelimiters) if len(memberStr) == 0 { continue } @@ -153,11 +234,29 @@ func (ts TraceState) MarshalJSON() ([]byte, error) { // Trace Context specification. The returned string will be invalid if the // TraceState contains any invalid members. func (ts TraceState) String() string { - members := make([]string, len(ts.list)) - for i, m := range ts.list { - members[i] = m.String() + if len(ts.list) == 0 { + return "" + } + var n int + n += len(ts.list) // member delimiters: '=' + n += len(ts.list) - 1 // list delimiters: ',' + for _, mem := range ts.list { + n += len(mem.Key) + n += len(mem.Value) } - return strings.Join(members, listDelimiter) + + var sb strings.Builder + sb.Grow(n) + _, _ = sb.WriteString(ts.list[0].Key) + _ = sb.WriteByte('=') + _, _ = sb.WriteString(ts.list[0].Value) + for i := 1; i < len(ts.list); i++ { + _ = sb.WriteByte(listDelimiters[0]) + _, _ = sb.WriteString(ts.list[i].Key) + _ = sb.WriteByte('=') + _, _ = sb.WriteString(ts.list[i].Value) + } + return sb.String() } // Get returns the value paired with key from the corresponding TraceState @@ -189,15 +288,25 @@ func (ts TraceState) Insert(key, value string) (TraceState, error) { if err != nil { return ts, err } - - cTS := ts.Delete(key) - if cTS.Len()+1 <= maxListMembers { - cTS.list = append(cTS.list, member{}) + n := len(ts.list) + found := n + for i := range ts.list { + if ts.list[i].Key == key { + found = i + } + } + cTS := TraceState{} + if found == n && n < maxListMembers { + cTS.list = make([]member, n+1) + } else { + cTS.list = make([]member, n) } - // When the number of members exceeds capacity, drop the "right-most". - copy(cTS.list[1:], cTS.list) cTS.list[0] = m - + // When the number of members exceeds capacity, drop the "right-most". + copy(cTS.list[1:], ts.list[0:found]) + if found < n { + copy(cTS.list[1+found:], ts.list[found+1:]) + } return cTS, nil } diff --git a/trace/tracestate_benchkmark_test.go b/trace/tracestate_benchkmark_test.go new file mode 100644 index 00000000000..171e09f00f8 --- /dev/null +++ b/trace/tracestate_benchkmark_test.go @@ -0,0 +1,58 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trace + +import ( + "testing" +) + +func BenchmarkTraceStateParse(b *testing.B) { + for _, test := range testcases { + b.Run(test.name, func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = ParseTraceState(test.in) + } + }) + } +} + +func BenchmarkTraceStateString(b *testing.B) { + for _, test := range testcases { + if len(test.tracestate.list) == 0 { + continue + } + b.Run(test.name, func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = test.tracestate.String() + } + }) + } +} + +func BenchmarkTraceStateInsert(b *testing.B) { + for _, test := range insertTestcase { + b.Run(test.name, func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = test.tracestate.Insert(test.key, test.value) + } + }) + } +} diff --git a/trace/tracestate_test.go b/trace/tracestate_test.go index c2bccd3b7c6..4cd0fbc45f3 100644 --- a/trace/tracestate_test.go +++ b/trace/tracestate_test.go @@ -420,85 +420,85 @@ func TestTraceStateDelete(t *testing.T) { } } -func TestTraceStateInsert(t *testing.T) { - ts := TraceState{list: []member{ - {Key: "key1", Value: "val1"}, - {Key: "key2", Value: "val2"}, - {Key: "key3", Value: "val3"}, - }} +var insertTS = TraceState{list: []member{ + {Key: "key1", Value: "val1"}, + {Key: "key2", Value: "val2"}, + {Key: "key3", Value: "val3"}, +}} - testCases := []struct { - name string - tracestate TraceState - key, value string - expected TraceState - err error - }{ - { - name: "add new", - tracestate: ts, - key: "key4@vendor", - value: "val4", - expected: TraceState{list: []member{ - {Key: "key4@vendor", Value: "val4"}, - {Key: "key1", Value: "val1"}, - {Key: "key2", Value: "val2"}, - {Key: "key3", Value: "val3"}, - }}, - }, - { - name: "replace", - tracestate: ts, - key: "key2", - value: "valX", - expected: TraceState{list: []member{ - {Key: "key2", Value: "valX"}, - {Key: "key1", Value: "val1"}, - {Key: "key3", Value: "val3"}, - }}, - }, - { - name: "invalid key", - tracestate: ts, - key: "key!", - value: "val", - expected: ts, - err: errInvalidKey, - }, - { - name: "invalid value", - tracestate: ts, - key: "key", - value: "v=l", - expected: ts, - err: errInvalidValue, - }, - { - name: "invalid key/value", - tracestate: ts, - key: "key!", - value: "v=l", - expected: ts, - err: errInvalidKey, - }, - { - name: "drop the right-most member(oldest) in queue", - tracestate: maxMembers, - key: "keyx", - value: "valx", - expected: func() TraceState { - // Prepend the new element and remove the oldest one, which is over capacity. - return TraceState{ - list: append( - []member{{Key: "keyx", Value: "valx"}}, - maxMembers.list[:len(maxMembers.list)-1]..., - ), - } - }(), - }, - } +var insertTestcase = []struct { + name string + tracestate TraceState + key, value string + expected TraceState + err error +}{ + { + name: "add new", + tracestate: insertTS, + key: "key4@vendor", + value: "val4", + expected: TraceState{list: []member{ + {Key: "key4@vendor", Value: "val4"}, + {Key: "key1", Value: "val1"}, + {Key: "key2", Value: "val2"}, + {Key: "key3", Value: "val3"}, + }}, + }, + { + name: "replace", + tracestate: insertTS, + key: "key2", + value: "valX", + expected: TraceState{list: []member{ + {Key: "key2", Value: "valX"}, + {Key: "key1", Value: "val1"}, + {Key: "key3", Value: "val3"}, + }}, + }, + { + name: "invalid key", + tracestate: insertTS, + key: "key!", + value: "val", + expected: insertTS, + err: errInvalidKey, + }, + { + name: "invalid value", + tracestate: insertTS, + key: "key", + value: "v=l", + expected: insertTS, + err: errInvalidValue, + }, + { + name: "invalid key/value", + tracestate: insertTS, + key: "key!", + value: "v=l", + expected: insertTS, + err: errInvalidKey, + }, + { + name: "drop the right-most member(oldest) in queue", + tracestate: maxMembers, + key: "keyx", + value: "valx", + expected: func() TraceState { + // Prepend the new element and remove the oldest one, which is over capacity. + return TraceState{ + list: append( + []member{{Key: "keyx", Value: "valx"}}, + maxMembers.list[:len(maxMembers.list)-1]..., + ), + } + }(), + }, +} - for _, tc := range testCases { +func TestTraceStateInsert(t *testing.T) { + for _, tc := range insertTestcase { t.Run(tc.name, func(t *testing.T) { actual, err := tc.tracestate.Insert(tc.key, tc.value) assert.ErrorIs(t, err, tc.err, tc.name)