diff --git a/modules/frontend/config.go b/modules/frontend/config.go index 3d802466004..44f4e0a9e6b 100644 --- a/modules/frontend/config.go +++ b/modules/frontend/config.go @@ -2,13 +2,12 @@ package frontend import ( "flag" - "net/http" "time" "github.com/go-kit/log" "github.com/prometheus/client_golang/prometheus" - "github.com/grafana/tempo/modules/frontend/transport" + "github.com/grafana/tempo/modules/frontend/pipeline" v1 "github.com/grafana/tempo/modules/frontend/v1" "github.com/grafana/tempo/pkg/usagestats" ) @@ -16,14 +15,14 @@ import ( var statVersion = usagestats.NewString("frontend_version") type Config struct { - Config v1.Config `yaml:",inline"` - MaxRetries int `yaml:"max_retries,omitempty"` - Search SearchConfig `yaml:"search"` - TraceByID TraceByIDConfig `yaml:"trace_by_id"` - Metrics MetricsConfig `yaml:"metrics"` - MultiTenantQueriesEnabled bool `yaml:"multi_tenant_queries_enabled"` - ResponseConsumers int `yaml:"response_consumers"` - + Config v1.Config `yaml:",inline"` + MaxRetries int `yaml:"max_retries,omitempty"` + Search SearchConfig `yaml:"search"` + TraceByID TraceByIDConfig `yaml:"trace_by_id"` + Metrics MetricsConfig `yaml:"metrics"` + MultiTenantQueriesEnabled bool `yaml:"multi_tenant_queries_enabled"` + ResponseConsumers int `yaml:"response_consumers"` + Weights pipeline.WeightsConfig `yaml:"weights"` // the maximum time limit that tempo will work on an api request. this includes both // grpc and http requests and applies to all "api" frontend query endpoints such as // traceql, tag search, tag value search, trace by id and all streaming gRPC endpoints. @@ -32,6 +31,9 @@ type Config struct { // A list of regexes for black listing requests, these will apply for every request regardless the endpoint URLDenyList []string `yaml:"url_deny_list,omitempty"` + + RequestWithWeights bool `yaml:"request_with_weights,omitempty"` + RetryWithWeights bool `yaml:"retry_with_weights,omitempty"` } type SearchConfig struct { @@ -95,6 +97,12 @@ func (cfg *Config) RegisterFlagsAndApplyDefaults(string, *flag.FlagSet) { }, SLO: slo, } + cfg.Weights = pipeline.WeightsConfig{ + RequestWithWeights: true, + RetryWithWeights: true, + MaxRegexConditions: 1, + MaxTraceQLConditions: 4, + } // enable multi tenant queries by default cfg.MultiTenantQueriesEnabled = true @@ -107,12 +115,12 @@ type CortexNoQuerierLimits struct{} // Returned RoundTripper can be wrapped in more round-tripper middlewares, and then eventually registered // into HTTP server using the Handler from this package. Returned RoundTripper is always non-nil // (if there are no errors), and it uses the returned frontend (if any). -func InitFrontend(cfg v1.Config, log log.Logger, reg prometheus.Registerer) (http.RoundTripper, *v1.Frontend, error) { +func InitFrontend(cfg v1.Config, log log.Logger, reg prometheus.Registerer) (pipeline.RoundTripper, *v1.Frontend, error) { statVersion.Set("v1") // No scheduler = use original frontend. fr, err := v1.New(cfg, log, reg) if err != nil { return nil, nil, err } - return transport.AdaptGrpcRoundTripperToHTTPRoundTripper(fr), fr, nil + return fr, fr, nil } diff --git a/modules/frontend/frontend.go b/modules/frontend/frontend.go index 1c5938ee4c6..3976c768b55 100644 --- a/modules/frontend/frontend.go +++ b/modules/frontend/frontend.go @@ -59,7 +59,7 @@ type QueryFrontend struct { var tracer = otel.Tracer("modules/frontend") // New returns a new QueryFrontend -func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempodb.Reader, cacheProvider cache.Provider, apiPrefix string, logger log.Logger, registerer prometheus.Registerer) (*QueryFrontend, error) { +func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader tempodb.Reader, cacheProvider cache.Provider, apiPrefix string, logger log.Logger, registerer prometheus.Registerer) (*QueryFrontend, error) { level.Info(logger).Log("msg", "creating middleware in query frontend") if cfg.TraceByID.QueryShards < minQueryShards || cfg.TraceByID.QueryShards > maxQueryShards { @@ -90,8 +90,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo return nil, fmt.Errorf("frontend metrics interval should be greater than 0") } - retryWare := pipeline.NewRetryWare(cfg.MaxRetries, registerer) - + retryWare := pipeline.NewRetryWare(cfg.MaxRetries, cfg.Weights.RetryWithWeights, registerer) cacheWare := pipeline.NewCachingWare(cacheProvider, cache.RoleFrontendSearch, logger) statusCodeWare := pipeline.NewStatusCodeAdjustWare() traceIDStatusCodeWare := pipeline.NewStatusCodeAdjustWareWithAllowedCode(http.StatusNotFound) @@ -101,6 +100,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo tracePipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ urlDenyListWare, + pipeline.NewWeightRequestWare(pipeline.TraceByID, cfg.Weights), multiTenantMiddleware(cfg, logger), newAsyncTraceIDSharder(&cfg.TraceByID, logger), }, @@ -111,6 +111,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ urlDenyListWare, queryValidatorWare, + pipeline.NewWeightRequestWare(pipeline.TraceQLSearch, cfg.Weights), multiTenantMiddleware(cfg, logger), newAsyncSearchSharder(reader, o, cfg.Search.Sharder, logger), }, @@ -120,6 +121,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo searchTagsPipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ urlDenyListWare, + pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights), multiTenantMiddleware(cfg, logger), newAsyncTagSharder(reader, o, cfg.Search.Sharder, parseTagsRequest, logger), }, @@ -129,6 +131,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo searchTagValuesPipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ urlDenyListWare, + pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights), multiTenantMiddleware(cfg, logger), newAsyncTagSharder(reader, o, cfg.Search.Sharder, parseTagValuesRequest, logger), }, @@ -140,6 +143,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ urlDenyListWare, queryValidatorWare, + pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights), multiTenantUnsupportedMiddleware(cfg, logger), }, []pipeline.Middleware{statusCodeWare, retryWare}, @@ -150,6 +154,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ urlDenyListWare, queryValidatorWare, + pipeline.NewWeightRequestWare(pipeline.TraceQLMetrics, cfg.Weights), multiTenantMiddleware(cfg, logger), newAsyncQueryRangeSharder(reader, o, cfg.Metrics.Sharder, logger), }, diff --git a/modules/frontend/metrics_query_range_sharder.go b/modules/frontend/metrics_query_range_sharder.go index 6ffb54dd6af..eb157c5356f 100644 --- a/modules/frontend/metrics_query_range_sharder.go +++ b/modules/frontend/metrics_query_range_sharder.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "math" - "net/http" "time" "github.com/go-kit/log" //nolint:all deprecated @@ -69,7 +68,7 @@ func (s queryRangeSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline return pipeline.NewBadRequest(err), nil } - expr, _, _, _, err := traceql.NewEngine().Compile(req.Query) + expr, _, _, _, err := traceql.Compile(req.Query) if err != nil { return pipeline.NewBadRequest(err), nil } @@ -89,7 +88,7 @@ func (s queryRangeSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline // Note: this is checked after alignment for consistency. maxDuration := s.maxDuration(tenantID) if maxDuration != 0 && time.Duration(req.End-req.Start)*time.Nanosecond > maxDuration { - err = fmt.Errorf(fmt.Sprintf("range specified by start and end (%s) exceeds %s. received start=%d end=%d", time.Duration(req.End-req.Start), maxDuration, req.Start, req.End)) + err = fmt.Errorf("range specified by start and end (%s) exceeds %s. received start=%d end=%d", time.Duration(req.End-req.Start), maxDuration, req.Start, req.End) return pipeline.NewBadRequest(err), nil } @@ -99,14 +98,14 @@ func (s queryRangeSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline cutoff = time.Now().Add(-s.cfg.QueryBackendAfter) ) - generatorReq := s.generatorRequest(*req, r, tenantID, cutoff) + generatorReq := s.generatorRequest(ctx, tenantID, pipelineRequest, *req, cutoff) reqCh := make(chan pipeline.Request, 2) // buffer of 2 allows us to insert generatorReq and metrics if generatorReq != nil { - reqCh <- pipeline.NewHTTPRequest(generatorReq) + reqCh <- generatorReq } - totalJobs, totalBlocks, totalBlockBytes := s.backendRequests(ctx, tenantID, r, *req, cutoff, targetBytesPerRequest, reqCh) + totalJobs, totalBlocks, totalBlockBytes := s.backendRequests(ctx, tenantID, pipelineRequest, *req, cutoff, targetBytesPerRequest, reqCh) span.SetAttributes(attribute.Int64("totalJobs", int64(totalJobs))) span.SetAttributes(attribute.Int64("totalBlocks", int64(totalBlocks))) @@ -158,7 +157,7 @@ func (s *queryRangeSharder) exemplarsPerShard(total uint32) uint32 { return uint32(math.Ceil(float64(s.cfg.MaxExemplars)*1.2)) / total } -func (s *queryRangeSharder) backendRequests(ctx context.Context, tenantID string, parent *http.Request, searchReq tempopb.QueryRangeRequest, cutoff time.Time, targetBytesPerRequest int, reqCh chan pipeline.Request) (totalJobs, totalBlocks uint32, totalBlockBytes uint64) { +func (s *queryRangeSharder) backendRequests(ctx context.Context, tenantID string, parent pipeline.Request, searchReq tempopb.QueryRangeRequest, cutoff time.Time, targetBytesPerRequest int, reqCh chan pipeline.Request) (totalJobs, totalBlocks uint32, totalBlockBytes uint64) { // request without start or end, search only in generator if searchReq.Start == 0 || searchReq.End == 0 { close(reqCh) @@ -204,7 +203,7 @@ func (s *queryRangeSharder) backendRequests(ctx context.Context, tenantID string return } -func (s *queryRangeSharder) buildBackendRequests(ctx context.Context, tenantID string, parent *http.Request, searchReq tempopb.QueryRangeRequest, metas []*backend.BlockMeta, targetBytesPerRequest int, reqCh chan<- pipeline.Request) { +func (s *queryRangeSharder) buildBackendRequests(ctx context.Context, tenantID string, parent pipeline.Request, searchReq tempopb.QueryRangeRequest, metas []*backend.BlockMeta, targetBytesPerRequest int, reqCh chan<- pipeline.Request) { defer close(reqCh) queryHash := hashForQueryRangeRequest(&searchReq) @@ -230,7 +229,7 @@ func (s *queryRangeSharder) buildBackendRequests(ctx context.Context, tenantID s } for startPage := 0; startPage < int(m.TotalRecords); startPage += pages { - subR := parent.Clone(ctx) + subR := parent.HTTPRequest().Clone(ctx) dedColsJSON, err := colsToJSON.JSONForDedicatedColumns(m.DedicatedColumns) if err != nil { @@ -268,7 +267,7 @@ func (s *queryRangeSharder) buildBackendRequests(ctx context.Context, tenantID s subR = api.BuildQueryRangeRequest(subR, queryRangeReq, dedColsJSON) prepareRequestForQueriers(subR, tenantID) - pipelineR := pipeline.NewHTTPRequest(subR) + pipelineR := parent.CloneFromHTTPRequest(subR) // TODO: Handle sampling rate key := queryRangeCacheKey(tenantID, queryHash, int64(queryRangeReq.Start), int64(queryRangeReq.End), m, int(queryRangeReq.StartPage), int(queryRangeReq.PagesToSearch)) @@ -292,9 +291,8 @@ func max(a, b uint32) uint32 { return b } -func (s *queryRangeSharder) generatorRequest(searchReq tempopb.QueryRangeRequest, parent *http.Request, tenantID string, cutoff time.Time) *http.Request { +func (s *queryRangeSharder) generatorRequest(ctx context.Context, tenantID string, parent pipeline.Request, searchReq tempopb.QueryRangeRequest, cutoff time.Time) *pipeline.HTTPRequest { traceql.TrimToAfter(&searchReq, cutoff) - // if start == end then we don't need to query it if searchReq.Start == searchReq.End { return nil @@ -303,12 +301,12 @@ func (s *queryRangeSharder) generatorRequest(searchReq tempopb.QueryRangeRequest searchReq.QueryMode = querier.QueryModeRecent searchReq.Exemplars = uint32(s.cfg.MaxExemplars) // TODO: Review this - subR := parent.Clone(parent.Context()) + subR := parent.HTTPRequest().Clone(ctx) subR = api.BuildQueryRangeRequest(subR, &searchReq, "") // dedicated cols are never passed to the generators prepareRequestForQueriers(subR, tenantID) - return subR + return parent.CloneFromHTTPRequest(subR) } // maxDuration returns the max search duration allowed for this tenant. diff --git a/modules/frontend/pipeline/async_weight_middleware.go b/modules/frontend/pipeline/async_weight_middleware.go new file mode 100644 index 00000000000..7ba18cde2c6 --- /dev/null +++ b/modules/frontend/pipeline/async_weight_middleware.go @@ -0,0 +1,125 @@ +package pipeline + +import ( + "github.com/grafana/tempo/modules/frontend/combiner" + "github.com/grafana/tempo/pkg/traceql" +) + +type RequestType int + +type WeightRequest interface { + SetWeight(int) + Weight() int +} + +type WeightsConfig struct { + RequestWithWeights bool `yaml:"request_with_weights,omitempty"` + RetryWithWeights bool `yaml:"retry_with_weights,omitempty"` + MaxTraceQLConditions int `yaml:"max_traceql_conditions,omitempty"` + MaxRegexConditions int `yaml:"max_regex_conditions,omitempty"` +} + +type Weights struct { + DefaultWeight int + TraceQLSearchWeight int + TraceByIDWeight int + MaxTraceQLConditions int + MaxRegexConditions int +} + +const ( + Default RequestType = iota + TraceByID + TraceQLSearch + TraceQLMetrics +) + +type weightRequestWare struct { + requestType RequestType + enabled bool + next AsyncRoundTripper[combiner.PipelineResponse] + + weights Weights +} + +// It increments the weight of a retriyed request +func IncrementRetriedRequestWeight(r WeightRequest) { + r.SetWeight(r.Weight() + 1) +} + +// It returns a new weight request middleware +func NewWeightRequestWare(rt RequestType, cfg WeightsConfig) AsyncMiddleware[combiner.PipelineResponse] { + weights := Weights{ + DefaultWeight: 1, + TraceQLSearchWeight: 1, + TraceByIDWeight: 2, + MaxTraceQLConditions: cfg.MaxTraceQLConditions, + MaxRegexConditions: cfg.MaxRegexConditions, + } + return AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] { + return &weightRequestWare{ + requestType: rt, + enabled: cfg.RequestWithWeights, + weights: weights, + next: next, + } + }) +} + +func (c weightRequestWare) RoundTrip(req Request) (Responses[combiner.PipelineResponse], error) { + c.setWeight(req) + return c.next.RoundTrip(req) +} + +func (c weightRequestWare) setWeight(req Request) { + if !c.enabled { + req.SetWeight(c.weights.DefaultWeight) + return + } + switch c.requestType { + case TraceByID: + req.SetWeight(c.weights.TraceByIDWeight) + case TraceQLSearch, TraceQLMetrics: + c.setTraceQLWeight(req) + default: + req.SetWeight(c.weights.DefaultWeight) + } +} + +func (c weightRequestWare) setTraceQLWeight(req Request) { + var traceQLQuery string + query := req.HTTPRequest().URL.Query() + if query.Has("q") { + traceQLQuery = query.Get("q") + } + if query.Has("query") { + traceQLQuery = query.Get("query") + } + + req.SetWeight(c.weights.TraceQLSearchWeight) + + if traceQLQuery == "" { + return + } + + _, _, _, spanRequest, err := traceql.Compile(traceQLQuery) + if err != nil || spanRequest == nil { + return + } + + conditions := 0 + regexConditions := 0 + + for _, c := range spanRequest.Conditions { + if c.Op != traceql.OpNone { + conditions++ + } + if c.Op == traceql.OpRegex || c.Op == traceql.OpNotRegex { + regexConditions++ + } + } + complexQuery := regexConditions >= c.weights.MaxRegexConditions || conditions >= c.weights.MaxTraceQLConditions + if complexQuery { + req.SetWeight(c.weights.TraceQLSearchWeight + 1) + } +} diff --git a/modules/frontend/pipeline/async_weight_middleware_test.go b/modules/frontend/pipeline/async_weight_middleware_test.go new file mode 100644 index 00000000000..ae4f63f8b5a --- /dev/null +++ b/modules/frontend/pipeline/async_weight_middleware_test.go @@ -0,0 +1,115 @@ +package pipeline + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/grafana/tempo/modules/frontend/combiner" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var nextRequest = AsyncRoundTripperFunc[combiner.PipelineResponse](func(_ Request) (Responses[combiner.PipelineResponse], error) { + return NewHTTPToAsyncResponse(&http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader([]byte{})), + }), nil +}) + +const ( + DefaultWeight int = 1 + TraceQLSearchWeight int = 1 + TraceByIDWeight int = 2 +) + +func TestWeightMiddlewareForTraceByIDRequest(t *testing.T) { + config := WeightsConfig{ + RequestWithWeights: true, + } + roundTrip := NewWeightRequestWare(TraceByID, config).Wrap(nextRequest) + req := DoWeightedRequest(t, "http://localhost:8080/api/v2/traces/123345", roundTrip) + + assert.Equal(t, TraceByIDWeight, req.Weight()) +} + +func TestDisabledWeightMiddlewareForTraceByIDRequest(t *testing.T) { + config := WeightsConfig{ + RequestWithWeights: false, + } + roundTrip := NewWeightRequestWare(TraceByID, config).Wrap(nextRequest) + req := DoWeightedRequest(t, "http://localhost:8080/api/v2/traces/123345", roundTrip) + + assert.Equal(t, DefaultWeight, req.Weight()) +} + +func TestWeightMiddlewareForDefaultRequest(t *testing.T) { + config := WeightsConfig{ + RequestWithWeights: true, + } + roundTrip := NewWeightRequestWare(Default, config).Wrap(nextRequest) + req := DoWeightedRequest(t, "http://localhost:8080/api/v2/search/tags", roundTrip) + + assert.Equal(t, DefaultWeight, req.Weight()) +} + +func TestWeightMiddlewareForTraceQLRequest(t *testing.T) { + config := WeightsConfig{ + RequestWithWeights: true, + MaxTraceQLConditions: 4, + MaxRegexConditions: 1, + } + roundTrip := NewWeightRequestWare(TraceQLSearch, config).Wrap(nextRequest) + cases := []struct { + req string + expected int + }{ + { + // Wrong query, this will be catched by the validator middlware + req: "http://localhost:3200/api/search?q={ span.http.status_code }", + expected: TraceQLSearchWeight, + }, + { + // Simple query + req: "http://localhost:3200/api/search?q={ span.http.status_code >= 200 }", + expected: TraceQLSearchWeight, + }, + { + // Simple query + req: "http://localhost:3200/api/search?q={ span.http.status_code >= 200 || span.http.status_code < 300 }", + expected: TraceQLSearchWeight, + }, + { + // Regex, complex query + req: "http://localhost:8080/api/search?query={span.a =~ \"postgresql|mysql\"}", + expected: TraceQLSearchWeight + 1, + }, + { + // Regex, complex query + req: "http://localhost:8080/api/search?query={span.a !~ \"postgresql|mysql\"}", + expected: TraceQLSearchWeight + 1, + }, + { + // 4 conditions, complex query + req: "http://localhost:8080/api/search?query={span.http.method = \"DELETE\" || status != ok || span.http.status_code >= 200 || span.http.status_code < 300 }", + expected: TraceQLSearchWeight + 1, + }, + } + for _, c := range cases { + actual := DoWeightedRequest(t, c.req, roundTrip) + if actual.Weight() != c.expected { + t.Errorf("expected %d, got %d", c.expected, actual.Weight()) + } + } +} + +func DoWeightedRequest(t *testing.T, url string, rt AsyncRoundTripper[combiner.PipelineResponse]) *HTTPRequest { + req, _ := http.NewRequest(http.MethodGet, url, nil) + request := NewHTTPRequest(req) + resp, _ := rt.RoundTrip(request) + _, _, err := resp.Next(context.Background()) + require.NoError(t, err) + return request +} diff --git a/modules/frontend/pipeline/pipeline.go b/modules/frontend/pipeline/pipeline.go index 0d3647b223a..920cd4eb203 100644 --- a/modules/frontend/pipeline/pipeline.go +++ b/modules/frontend/pipeline/pipeline.go @@ -15,11 +15,15 @@ type Request interface { Context() context.Context WithContext(context.Context) + Weight() int + SetWeight(int) + SetCacheKey(string) CacheKey() string SetResponseData(any) // add data that will be sent back with this requests response ResponseData() any + CloneFromHTTPRequest(request *http.Request) *HTTPRequest } type HTTPRequest struct { @@ -27,6 +31,7 @@ type HTTPRequest struct { cacheKey string responseData any + weight int } func NewHTTPRequest(req *http.Request) *HTTPRequest { @@ -65,6 +70,18 @@ func (r *HTTPRequest) ResponseData() any { return r.responseData } +func (r *HTTPRequest) Weight() int { + return r.weight +} + +func (r *HTTPRequest) SetWeight(w int) { + r.weight = w +} + +func (r *HTTPRequest) CloneFromHTTPRequest(request *http.Request) *HTTPRequest { + return &HTTPRequest{req: request, weight: r.weight} +} + // // Async Pipeline // @@ -125,7 +142,7 @@ func (f MiddlewareFunc) Wrap(w RoundTripper) RoundTripper { // // Build takes a slice of async, sync middleware and a http.RoundTripper and builds a request pipeline -func Build(asyncMW []AsyncMiddleware[combiner.PipelineResponse], mw []Middleware, next http.RoundTripper) AsyncRoundTripper[combiner.PipelineResponse] { +func Build(asyncMW []AsyncMiddleware[combiner.PipelineResponse], mw []Middleware, next RoundTripper) AsyncRoundTripper[combiner.PipelineResponse] { asyncPipeline := AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] { for i := len(asyncMW) - 1; i >= 0; i-- { next = asyncMW[i].Wrap(next) @@ -143,7 +160,7 @@ func Build(asyncMW []AsyncMiddleware[combiner.PipelineResponse], mw []Middleware // bridge the two pipelines bridge := &pipelineBridge{ next: syncPipeline.Wrap(RoundTripperFunc(func(req Request) (*http.Response, error) { - return next.RoundTrip(req.HTTPRequest()) + return next.RoundTrip(req) })), convert: NewHTTPToAsyncResponse, } diff --git a/modules/frontend/pipeline/sync_handler_retry.go b/modules/frontend/pipeline/sync_handler_retry.go index 535eed2ce13..b8ad87fed5a 100644 --- a/modules/frontend/pipeline/sync_handler_retry.go +++ b/modules/frontend/pipeline/sync_handler_retry.go @@ -16,7 +16,7 @@ import ( "go.opentelemetry.io/otel/trace" ) -func NewRetryWare(maxRetries int, registerer prometheus.Registerer) Middleware { +func NewRetryWare(maxRetries int, incrementRetriedRequestWeight bool, registerer prometheus.Registerer) Middleware { retriesCount := promauto.With(registerer).NewHistogram(prometheus.HistogramOpts{ Namespace: "tempo", Name: "query_frontend_retries", @@ -29,17 +29,19 @@ func NewRetryWare(maxRetries int, registerer prometheus.Registerer) Middleware { return MiddlewareFunc(func(next RoundTripper) RoundTripper { return retryWare{ - next: next, - maxRetries: maxRetries, - retriesCount: retriesCount, + next: next, + maxRetries: maxRetries, + retriesCount: retriesCount, + incrementRetriedRequestWeight: incrementRetriedRequestWeight, } }) } type retryWare struct { - next RoundTripper - maxRetries int - retriesCount prometheus.Histogram + next RoundTripper + maxRetries int + incrementRetriedRequestWeight bool + retriesCount prometheus.Histogram } // RoundTrip implements http.RoundTripper @@ -61,6 +63,10 @@ func (r retryWare) RoundTrip(req Request) (*http.Response, error) { resp, err := r.next.RoundTrip(req) + if ctx.Err() != nil { + return nil, ctx.Err() + } + if r.maxRetries == 0 { return resp, err } @@ -96,6 +102,12 @@ func (r retryWare) RoundTrip(req Request) (*http.Response, error) { return resp, err } + // retries have their weight bumped. a common retry reason is the request was simply too large to process + // bumping weights should help spread the load + if r.incrementRetriedRequestWeight { + IncrementRetriedRequestWeight(req) + } + statusCode := 0 if resp != nil { statusCode = resp.StatusCode diff --git a/modules/frontend/pipeline/sync_handler_retry_test.go b/modules/frontend/pipeline/sync_handler_retry_test.go index 5d57db7c6d1..6f004c57a33 100644 --- a/modules/frontend/pipeline/sync_handler_retry_test.go +++ b/modules/frontend/pipeline/sync_handler_retry_test.go @@ -109,7 +109,7 @@ func TestRetry(t *testing.T) { t.Run(tc.name, func(t *testing.T) { try.Store(0) - retryWare := NewRetryWare(tc.maxRetries, prometheus.NewRegistry()) + retryWare := NewRetryWare(tc.maxRetries, true, prometheus.NewRegistry()) handler := retryWare.Wrap(tc.handler) req := httptest.NewRequest("GET", "http://example.com", nil) @@ -133,7 +133,7 @@ func TestRetry_CancelledRequest(t *testing.T) { req, err := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil) require.NoError(t, err) - _, err = NewRetryWare(5, prometheus.NewRegistry()). + _, err = NewRetryWare(5, false, prometheus.NewRegistry()). Wrap(RoundTripperFunc(func(_ Request) (*http.Response, error) { try.Inc() return nil, ctx.Err() @@ -148,7 +148,7 @@ func TestRetry_CancelledRequest(t *testing.T) { req, err = http.NewRequestWithContext(ctx, "GET", "http://example.com", nil) require.NoError(t, err) - _, err = NewRetryWare(5, prometheus.NewRegistry()). + _, err = NewRetryWare(5, false, prometheus.NewRegistry()). Wrap(RoundTripperFunc(func(_ Request) (*http.Response, error) { try.Inc() cancel() diff --git a/modules/frontend/queue/queue.go b/modules/frontend/queue/queue.go index 75ef6db0bfc..22b0f4d7475 100644 --- a/modules/frontend/queue/queue.go +++ b/modules/frontend/queue/queue.go @@ -39,7 +39,9 @@ func FirstUser() UserIndex { } // Request stored into the queue. -type Request interface{} +type Request interface { + Weight() int +} // RequestQueue holds incoming requests in per-user queues. type RequestQueue struct { @@ -160,18 +162,7 @@ FindQueue: last.last = idx if queue != nil { // this is all threadsafe b/c all users queues are blocked by q.mtx - if len(queue) < requestedCount { - requestedCount = len(queue) - } - - // Pick next requests from the queue. - batchBuffer = batchBuffer[:requestedCount] - for i := 0; i < requestedCount; i++ { - batchBuffer[i] = <-queue - } - - q.queueLength.WithLabelValues(userID).Set(float64(len(queue))) - + batchBuffer := q.getBatchBuffer(batchBuffer, userID, queue) return batchBuffer, last, nil } @@ -181,6 +172,31 @@ FindQueue: goto FindQueue } +func (q *RequestQueue) getBatchBuffer(batchBuffer []Request, userID string, queue chan Request) []Request { + requestedCount := len(batchBuffer) + guaranteedInQueue := requestedCount + + if len(queue) < requestedCount { + guaranteedInQueue = len(queue) + } + + totalWeight := 0 + actuallyInBatch := 0 + for i := 0; i < guaranteedInQueue; i++ { + batchBuffer[i] = <-queue + actuallyInBatch++ + totalWeight += batchBuffer[i].Weight() + + if totalWeight >= requestedCount { + break + } + } + batchBuffer = batchBuffer[:actuallyInBatch] + + q.queueLength.WithLabelValues(userID).Set(float64(len(queue))) + return batchBuffer +} + func (q *RequestQueue) cleanupQueues(_ context.Context) error { q.mtx.Lock() defer q.mtx.Unlock() diff --git a/modules/frontend/queue/queue_test.go b/modules/frontend/queue/queue_test.go index 80d2ea2367d..dc877dc0905 100644 --- a/modules/frontend/queue/queue_test.go +++ b/modules/frontend/queue/queue_test.go @@ -16,9 +16,17 @@ import ( const messages = 50_000 -type mockRequest struct{} +type mockRequest struct { + weight int +} func (r *mockRequest) Invalid() bool { return false } +func (r *mockRequest) Weight() int { + if r.weight > 0 { + return r.weight + } + return 1 +} func TestGetNextForQuerierOneUser(t *testing.T) { messages := 10 @@ -28,7 +36,7 @@ func TestGetNextForQuerierOneUser(t *testing.T) { stop := make(chan struct{}) requestsPulled := atomic.NewInt32(0) - q, start := queueWithListeners(ctx, 100, 1, func(r []Request) { + q, start := queueWithListeners(ctx, 100, 1, func(_ []Request) { i := requestsPulled.Inc() if i == int32(messages) { close(stop) @@ -57,7 +65,7 @@ func TestGetNextForQuerierRandomUsers(t *testing.T) { stop := make(chan struct{}) requestsPulled := atomic.NewInt32(0) - q, start := queueWithListeners(ctx, 100, 1, func(r []Request) { + q, start := queueWithListeners(ctx, 100, 1, func(_ []Request) { if requestsPulled.Inc() == int32(messages) { close(stop) } @@ -125,7 +133,7 @@ func benchmarkGetNextForQuerier(b *testing.B, listeners int, messages int) { stop := make(chan struct{}) requestsPulled := atomic.NewInt32(0) - q, start := queueWithListeners(ctx, listeners, 1, func(r []Request) { + q, start := queueWithListeners(ctx, listeners, 1, func(_ []Request) { if requestsPulled.Inc() == int32(messages) { stop <- struct{}{} } @@ -323,6 +331,66 @@ func TestContextCond(t *testing.T) { }) } +func TestGetBatchBuffer(t *testing.T) { + tests := []struct { + name string + queueContents []Request + requestedCount int + expectedCount int + }{ + { + name: "exactly requested count", + queueContents: []Request{&mockRequest{}, &mockRequest{}, &mockRequest{}}, + requestedCount: 3, + expectedCount: 3, + }, + { + name: "less than requested count", + queueContents: []Request{&mockRequest{}, &mockRequest{}}, + requestedCount: 3, + expectedCount: 2, + }, + { + name: "more than requested count", + queueContents: []Request{&mockRequest{}, &mockRequest{}, &mockRequest{}, &mockRequest{}}, + requestedCount: 3, + expectedCount: 3, + }, + { + name: "less than requested count due to biggest weight", + queueContents: []Request{&mockRequest{10}}, + requestedCount: 3, + expectedCount: 1, + }, + { + name: "empty queue", + queueContents: []Request{}, + requestedCount: 3, + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + queue := make(chan Request, len(tt.queueContents)) + for _, req := range tt.queueContents { + queue <- req + } + + q := &RequestQueue{ + queueLength: prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "test_len", + }, []string{"user"}), + } + + batchBuffer := make([]Request, tt.requestedCount) + result := q.getBatchBuffer(batchBuffer, "user", queue) + + assert.Equal(t, tt.expectedCount, len(result)) + }) + } +} + func assertChanReceived(t *testing.T, c chan struct{}, timeout time.Duration, msg string) { t.Helper() diff --git a/modules/frontend/search_handlers_test.go b/modules/frontend/search_handlers_test.go index b81ff9f40d6..05766eb78ba 100644 --- a/modules/frontend/search_handlers_test.go +++ b/modules/frontend/search_handlers_test.go @@ -25,6 +25,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "github.com/grafana/tempo/modules/frontend/pipeline" "github.com/grafana/tempo/modules/overrides" "github.com/grafana/tempo/pkg/api" "github.com/grafana/tempo/pkg/cache" @@ -37,7 +38,7 @@ import ( "github.com/grafana/tempo/tempodb/backend" ) -var _ http.RoundTripper = &mockRoundTripper{} +var _ pipeline.RoundTripper = &mockRoundTripper{} type mockRoundTripper struct { err error @@ -48,7 +49,7 @@ type mockRoundTripper struct { responseFn func() proto.Message } -func (s *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) { +func (s *mockRoundTripper) RoundTrip(_ pipeline.Request) (*http.Response, error) { // only return errors once, then do a good response to make sure that the combiner is handling the error correctly var err error var errResponse *http.Response @@ -708,7 +709,7 @@ func BenchmarkSearchPipeline(b *testing.B) { // frontendWithSettings returns a new frontend with the given settings. any nil options // are given "happy path" defaults -func frontendWithSettings(t require.TestingT, next http.RoundTripper, rdr tempodb.Reader, cfg *Config, cacheProvider cache.Provider, +func frontendWithSettings(t require.TestingT, next pipeline.RoundTripper, rdr tempodb.Reader, cfg *Config, cacheProvider cache.Provider, opts ...func(*Config), ) *QueryFrontend { if next == nil { diff --git a/modules/frontend/search_sharder.go b/modules/frontend/search_sharder.go index 9799a18d242..a8e1df35ec8 100644 --- a/modules/frontend/search_sharder.go +++ b/modules/frontend/search_sharder.go @@ -3,7 +3,6 @@ package frontend import ( "context" "fmt" - "net/http" "time" "github.com/go-kit/log" //nolint:all deprecated @@ -96,7 +95,7 @@ func (s asyncSearchSharder) RoundTrip(pipelineRequest pipeline.Request) (pipelin // build request to search ingesters based on query_ingesters_until config and time range // pass subCtx in requests so we can cancel and exit early - err = s.ingesterRequests(ctx, tenantID, r, *searchReq, reqCh) + err = s.ingesterRequests(ctx, tenantID, pipelineRequest, *searchReq, reqCh) if err != nil { return nil, err } @@ -106,7 +105,7 @@ func (s asyncSearchSharder) RoundTrip(pipelineRequest pipeline.Request) (pipelin ingesterJobs := len(reqCh) // pass subCtx in requests so we can cancel and exit early - totalJobs, totalBlocks, totalBlockBytes := s.backendRequests(ctx, tenantID, r, searchReq, reqCh, func(err error) { + totalJobs, totalBlocks, totalBlockBytes := s.backendRequests(ctx, tenantID, pipelineRequest, searchReq, reqCh, func(err error) { // todo: actually find a way to return this error to the user s.logger.Log("msg", "search: failed to build backend requests", "err", err) }) @@ -154,7 +153,7 @@ func (s *asyncSearchSharder) blockMetas(start, end int64, tenantID string) []*ba // backendRequest builds backend requests to search backend blocks. backendRequest takes ownership of reqCh and closes it. // it returns 3 int values: totalBlocks, totalBlockBytes, and estimated jobs -func (s *asyncSearchSharder) backendRequests(ctx context.Context, tenantID string, parent *http.Request, searchReq *tempopb.SearchRequest, reqCh chan<- pipeline.Request, errFn func(error)) (totalJobs, totalBlocks int, totalBlockBytes uint64) { +func (s *asyncSearchSharder) backendRequests(ctx context.Context, tenantID string, parent pipeline.Request, searchReq *tempopb.SearchRequest, reqCh chan<- pipeline.Request, errFn func(error)) (totalJobs, totalBlocks int, totalBlockBytes uint64) { var blocks []*backend.BlockMeta // request without start or end, search only in ingester @@ -200,7 +199,7 @@ func (s *asyncSearchSharder) backendRequests(ctx context.Context, tenantID strin // that covers the ingesters. If nil is returned for the http.Request then there is no ingesters query. // since this function modifies searchReq.Start and End we are taking a value instead of a pointer to prevent it from // unexpectedly changing the passed searchReq. -func (s *asyncSearchSharder) ingesterRequests(ctx context.Context, tenantID string, parent *http.Request, searchReq tempopb.SearchRequest, reqCh chan pipeline.Request) error { +func (s *asyncSearchSharder) ingesterRequests(ctx context.Context, tenantID string, parent pipeline.Request, searchReq tempopb.SearchRequest, reqCh chan pipeline.Request) error { // request without start or end, search only in ingester if searchReq.Start == 0 || searchReq.End == 0 { return buildIngesterRequest(ctx, tenantID, parent, &searchReq, reqCh) @@ -298,7 +297,7 @@ func backendRange(start, end uint32, queryBackendAfter time.Duration) (uint32, u // buildBackendRequests returns a slice of requests that cover all blocks in the store // that are covered by start/end. -func buildBackendRequests(ctx context.Context, tenantID string, parent *http.Request, searchReq *tempopb.SearchRequest, metas []*backend.BlockMeta, bytesPerRequest int, reqCh chan<- pipeline.Request, errFn func(error)) { +func buildBackendRequests(ctx context.Context, tenantID string, parent pipeline.Request, searchReq *tempopb.SearchRequest, metas []*backend.BlockMeta, bytesPerRequest int, reqCh chan<- pipeline.Request, errFn func(error)) { defer close(reqCh) queryHash := hashForSearchRequest(searchReq) @@ -312,7 +311,7 @@ func buildBackendRequests(ctx context.Context, tenantID string, parent *http.Req blockID := m.BlockID.String() for startPage := 0; startPage < int(m.TotalRecords); startPage += pages { - subR := parent.Clone(ctx) + subR := parent.HTTPRequest().Clone(ctx) dedColsJSON, err := colsToJSON.JSONForDedicatedColumns(m.DedicatedColumns) if err != nil { @@ -340,7 +339,7 @@ func buildBackendRequests(ctx context.Context, tenantID string, parent *http.Req prepareRequestForQueriers(subR, tenantID) key := searchJobCacheKey(tenantID, queryHash, int64(searchReq.Start), int64(searchReq.End), m, startPage, pages) - pipelineR := pipeline.NewHTTPRequest(subR) + pipelineR := parent.CloneFromHTTPRequest(subR) pipelineR.SetCacheKey(key) select { @@ -397,14 +396,14 @@ func pagesPerRequest(m *backend.BlockMeta, bytesPerRequest int) int { return pagesPerQuery } -func buildIngesterRequest(ctx context.Context, tenantID string, parent *http.Request, searchReq *tempopb.SearchRequest, reqCh chan pipeline.Request) error { - subR := parent.Clone(ctx) +func buildIngesterRequest(ctx context.Context, tenantID string, parent pipeline.Request, searchReq *tempopb.SearchRequest, reqCh chan pipeline.Request) error { + subR := parent.HTTPRequest().Clone(ctx) subR, err := api.BuildSearchRequest(subR, searchReq) if err != nil { return err } prepareRequestForQueriers(subR, tenantID) - reqCh <- pipeline.NewHTTPRequest(subR) + reqCh <- parent.CloneFromHTTPRequest(subR) return nil } diff --git a/modules/frontend/search_sharder_test.go b/modules/frontend/search_sharder_test.go index e38e587a7ce..64c121655d9 100644 --- a/modules/frontend/search_sharder_test.go +++ b/modules/frontend/search_sharder_test.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "reflect" "strconv" "strings" "testing" @@ -227,7 +226,7 @@ func TestBuildBackendRequests(t *testing.T) { reqCh := make(chan pipeline.Request) go func() { - buildBackendRequests(ctx, "test", req, searchReq, tc.metas, tc.targetBytesPerRequest, reqCh, cancelCause) + buildBackendRequests(ctx, "test", pipeline.NewHTTPRequest(req), searchReq, tc.metas, tc.targetBytesPerRequest, reqCh, cancelCause) }() actualURIs := []string{} @@ -317,8 +316,8 @@ func TestBackendRequests(t *testing.T) { reqCh := make(chan pipeline.Request) ctx, cancelCause := context.WithCancelCause(context.Background()) - - jobs, blocks, blockBytes := s.backendRequests(ctx, "test", r, searchReq, reqCh, cancelCause) + pipelineRequest := pipeline.NewHTTPRequest(r) + jobs, blocks, blockBytes := s.backendRequests(ctx, "test", pipelineRequest, searchReq, reqCh, cancelCause) require.Equal(t, tc.expectedJobs, jobs) require.Equal(t, tc.expectedBlocks, blocks) require.Equal(t, tc.expectedBlockBytes, blockBytes) @@ -493,8 +492,9 @@ func TestIngesterRequests(t *testing.T) { reqChan := make(chan pipeline.Request, tc.ingesterShards) defer close(reqChan) - copyReq := searchReq - err = s.ingesterRequests(context.Background(), "test", req, *searchReq, reqChan) + pr := pipeline.NewHTTPRequest(req) + pr.SetWeight(2) + err = s.ingesterRequests(context.Background(), "test", pr, *searchReq, reqChan) if tc.expectedError != nil { assert.Equal(t, tc.expectedError, err) continue @@ -541,13 +541,8 @@ func TestIngesterRequests(t *testing.T) { } require.Equal(t, v, values[k]) + require.Equal(t, 2, req.Weight()) } - - /* require.Equal(t, expectedURI, req.RequestURI) */ - - // it may seem odd to test that the searchReq is not modified, but this is to prevent an issue that - // occurs if the ingesterRequest method is changed to take a searchReq pointer - require.True(t, reflect.DeepEqual(copyReq, searchReq)) } } } diff --git a/modules/frontend/tag_sharder.go b/modules/frontend/tag_sharder.go index 8328ad2b58d..47b04af2095 100644 --- a/modules/frontend/tag_sharder.go +++ b/modules/frontend/tag_sharder.go @@ -213,14 +213,14 @@ func (s searchTagSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline. // build request to search ingester based on query_ingesters_until config and time range // pass subCtx in requests, so we can cancel and exit early - ingesterReq, err := s.ingesterRequest(ctx, tenantID, r, searchReq) + ingesterReq, err := s.ingesterRequest(ctx, tenantID, pipelineRequest, searchReq) if err != nil { return nil, err } reqCh := make(chan pipeline.Request, 1) // buffer of 1 allows us to insert ingestReq if it exists if ingesterReq != nil { - reqCh <- pipeline.NewHTTPRequest(ingesterReq) + reqCh <- ingesterReq } s.backendRequests(ctx, tenantID, r, searchReq, reqCh, func(err error) { @@ -318,7 +318,7 @@ func (s searchTagSharder) buildBackendRequests(ctx context.Context, tenantID str // that covers the ingesters. If nil is returned for the http.Request then there is no ingesters query. // we should do a copy of the searchReq before use this function, as it is an interface, we cannot guaranteed be passed // by value. -func (s searchTagSharder) ingesterRequest(ctx context.Context, tenantID string, parent *http.Request, searchReq tagSearchReq) (*http.Request, error) { +func (s searchTagSharder) ingesterRequest(ctx context.Context, tenantID string, parent pipeline.Request, searchReq tagSearchReq) (*pipeline.HTTPRequest, error) { // request without start or end, search only in ingester if searchReq.start() == 0 || searchReq.end() == 0 { return s.buildIngesterRequest(ctx, tenantID, parent, searchReq) @@ -349,14 +349,14 @@ func (s searchTagSharder) ingesterRequest(ctx context.Context, tenantID string, return s.buildIngesterRequest(ctx, tenantID, parent, newSearchReq) } -func (s searchTagSharder) buildIngesterRequest(ctx context.Context, tenantID string, parent *http.Request, searchReq tagSearchReq) (*http.Request, error) { - subR := parent.Clone(ctx) +func (s searchTagSharder) buildIngesterRequest(ctx context.Context, tenantID string, parent pipeline.Request, searchReq tagSearchReq) (*pipeline.HTTPRequest, error) { + subR := parent.HTTPRequest().Clone(ctx) subR, err := searchReq.buildSearchTagRequest(subR) if err != nil { return nil, err } prepareRequestForQueriers(subR, tenantID) - return subR, nil + return parent.CloneFromHTTPRequest(subR), nil } // maxDuration returns the max search duration allowed for this tenant. diff --git a/modules/frontend/tag_sharder_test.go b/modules/frontend/tag_sharder_test.go index dfc1460c000..150cd5e0dc2 100644 --- a/modules/frontend/tag_sharder_test.go +++ b/modules/frontend/tag_sharder_test.go @@ -257,6 +257,7 @@ func TestTagsIngesterRequest(t *testing.T) { } req := httptest.NewRequest("GET", tc.request, nil) + pipelineReq := pipeline.NewHTTPRequest(req) searchReq := fakeReq{ startValue: uint32(tc.start), @@ -264,7 +265,7 @@ func TestTagsIngesterRequest(t *testing.T) { } copyReq := searchReq - actualReq, err := s.ingesterRequest(context.Background(), "test", req, &searchReq) + actualReq, err := s.ingesterRequest(context.Background(), "test", pipelineReq, &searchReq) if tc.expectedError != nil { assert.Equal(t, tc.expectedError, err) continue @@ -273,7 +274,7 @@ func TestTagsIngesterRequest(t *testing.T) { if tc.expectedURI == "" { assert.Nil(t, actualReq) } else { - assert.Equal(t, tc.expectedURI, actualReq.RequestURI) + assert.Equal(t, tc.expectedURI, actualReq.HTTPRequest().RequestURI) } // it may seem odd to test that the searchReq is not modified, but this is to prevent an issue that diff --git a/modules/frontend/traceid_handlers_test.go b/modules/frontend/traceid_handlers_test.go index c2fc4371750..69e16abe036 100644 --- a/modules/frontend/traceid_handlers_test.go +++ b/modules/frontend/traceid_handlers_test.go @@ -14,6 +14,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/gorilla/mux" "github.com/grafana/dskit/user" + "github.com/grafana/tempo/modules/frontend/pipeline" "github.com/grafana/tempo/pkg/model/trace" "github.com/grafana/tempo/pkg/tempopb" "github.com/grafana/tempo/pkg/util/test" @@ -159,11 +160,11 @@ func TestTraceIDHandler(t *testing.T) { for _, tc := range tests { tc := tc // copy the test case to prevent race on the loop variable t.Run(tc.name, func(t *testing.T) { - next := RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + next := pipeline.RoundTripperFunc(func(r pipeline.Request) (*http.Response, error) { var testTrace *tempopb.Trace var statusCode int var err error - if r.RequestURI == "/querier/api/traces/1234?mode=ingesters" { + if r.HTTPRequest().RequestURI == "/querier/api/traces/1234?mode=ingesters" { testTrace = tc.trace1 statusCode = tc.status1 err = tc.err1 @@ -236,7 +237,7 @@ func TestTraceIDHandler(t *testing.T) { } func TestTraceIDHandlerForJSONResponse(t *testing.T) { - next := RoundTripperFunc(func(_ *http.Request) (*http.Response, error) { + next := pipeline.RoundTripperFunc(func(_ pipeline.Request) (*http.Response, error) { testTrace := test.MakeTrace(2, []byte{0x01, 0x02}) resBytes, _ := proto.Marshal(&tempopb.TraceByIDResponse{ Trace: testTrace, @@ -354,11 +355,11 @@ func TestTraceIDHandlerV2(t *testing.T) { for _, tc := range tests { tc := tc // copy the test case to prevent race on the loop variable t.Run(tc.name, func(t *testing.T) { - next := RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + next := pipeline.RoundTripperFunc(func(r pipeline.Request) (*http.Response, error) { + var err error var testTrace *tempopb.Trace var statusCode int - var err error - if r.RequestURI == "/querier/api/v2/traces/1234?mode=ingesters" { + if r.HTTPRequest().RequestURI == "/querier/api/v2/traces/1234?mode=ingesters" { testTrace = tc.trace1 statusCode = tc.status1 err = tc.err1 @@ -447,7 +448,7 @@ func TestTraceIDHandlerV2WithJSONResponse(t *testing.T) { } } - next := RoundTripperFunc(func(_ *http.Request) (*http.Response, error) { + next := pipeline.RoundTripperFunc(func(_ pipeline.Request) (*http.Response, error) { var err error resBytes, err := proto.Marshal(&tempopb.TraceByIDResponse{ Trace: splitTrace, diff --git a/modules/frontend/traceid_sharder.go b/modules/frontend/traceid_sharder.go index 458b4bd7c31..21bdfa47d24 100644 --- a/modules/frontend/traceid_sharder.go +++ b/modules/frontend/traceid_sharder.go @@ -66,7 +66,8 @@ func (s asyncTraceSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline } return pipeline.NewAsyncSharderFunc(ctx, int(concurrentShards), len(reqs), func(i int) pipeline.Request { - return pipeline.NewHTTPRequest(reqs[i]) + pipelineReq := pipelineRequest.CloneFromHTTPRequest(reqs[i]) + return pipelineReq }, s.next), nil } diff --git a/modules/frontend/transport/roundtripper.go b/modules/frontend/transport/roundtripper.go deleted file mode 100644 index c8fe2db1489..00000000000 --- a/modules/frontend/transport/roundtripper.go +++ /dev/null @@ -1,56 +0,0 @@ -package transport - -import ( - "bytes" - "context" - "io" - "net/http" - - "github.com/grafana/dskit/httpgrpc" -) - -// GrpcRoundTripper is similar to http.RoundTripper, but works with HTTP requests converted to protobuf messages. -type GrpcRoundTripper interface { - RoundTripGRPC(context.Context, *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) -} - -func AdaptGrpcRoundTripperToHTTPRoundTripper(r GrpcRoundTripper) http.RoundTripper { - return &grpcRoundTripperAdapter{roundTripper: r} -} - -// This adapter wraps GrpcRoundTripper and converted it into http.RoundTripper -type grpcRoundTripperAdapter struct { - roundTripper GrpcRoundTripper -} - -type buffer struct { - buff []byte - io.ReadCloser -} - -func (b *buffer) Bytes() []byte { - return b.buff -} - -func (a *grpcRoundTripperAdapter) RoundTrip(r *http.Request) (*http.Response, error) { - req, err := httpgrpc.FromHTTPRequest(r) - if err != nil { - return nil, err - } - - resp, err := a.roundTripper.RoundTripGRPC(r.Context(), req) - if err != nil { - return nil, err - } - - httpResp := &http.Response{ - StatusCode: int(resp.Code), - Body: &buffer{buff: resp.Body, ReadCloser: io.NopCloser(bytes.NewReader(resp.Body))}, - Header: http.Header{}, - ContentLength: int64(len(resp.Body)), - } - for _, h := range resp.Headers { - httpResp.Header[h.Key] = h.Values - } - return httpResp, nil -} diff --git a/modules/frontend/v1/frontend.go b/modules/frontend/v1/frontend.go index 7f67bd486ea..7a83d95d6c8 100644 --- a/modules/frontend/v1/frontend.go +++ b/modules/frontend/v1/frontend.go @@ -5,6 +5,7 @@ import ( "errors" "flag" "fmt" + "net/http" "sync/atomic" "time" @@ -17,10 +18,10 @@ import ( "github.com/grafana/dskit/httpgrpc" "github.com/grafana/dskit/services" "github.com/grafana/dskit/tenant" - "github.com/grafana/tempo/pkg/util/httpgrpcutil" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/grafana/tempo/modules/frontend/pipeline" "github.com/grafana/tempo/modules/frontend/queue" "github.com/grafana/tempo/modules/frontend/v1/frontendv1pb" "github.com/grafana/tempo/pkg/util" @@ -69,11 +70,18 @@ type Frontend struct { type request struct { enqueueTime time.Time queueSpan trace.Span - originalCtx context.Context - request *httpgrpc.HTTPRequest + request pipeline.Request err chan error - response chan *httpgrpc.HTTPResponse + response chan *http.Response +} + +func (r *request) Weight() int { + return r.request.Weight() +} + +func (r *request) OriginalContext() context.Context { + return r.request.Context() } // New creates a new frontend. Frontend implements service, and must be started and stopped. @@ -163,23 +171,19 @@ func (f *Frontend) cleanupInactiveUserMetrics(user string) { f.discardedRequests.DeleteLabelValues(user) } -// RoundTripGRPC round trips a proto (instead of a HTTP request). -func (f *Frontend) RoundTripGRPC(ctx context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) { - // Propagate trace context in gRPC too - this will be ignored if using HTTP. - carrier := (*httpgrpcutil.HttpgrpcHeadersCarrier)(req) - otel.GetTextMapPropagator().Inject(ctx, carrier) - +// RoundTrip a HTTP request +func (f *Frontend) RoundTrip(req pipeline.Request) (*http.Response, error) { request := request{ - request: req, - originalCtx: ctx, + request: req, // Buffer of 1 to ensure response can be written by the server side // of the Process stream, even if this goroutine goes away due to // client context cancellation. err: make(chan error, 1), - response: make(chan *httpgrpc.HTTPResponse, 1), + response: make(chan *http.Response, 1), } + ctx := req.Context() if err := f.queueRequest(ctx, &request); err != nil { return nil, err } @@ -229,11 +233,14 @@ func (f *Frontend) Process(server frontendv1pb.Frontend_ProcessServer) error { req.queueSpan.End() // only add if not expired - if req.originalCtx.Err() != nil { + if req.OriginalContext().Err() != nil { continue } - reqBatch.add(req) + err = reqBatch.add(req) + if err != nil { + return fmt.Errorf("unexpected error adding request to batch: %w", err) + } } // if all requests are expired then continue requesting jobs for this user. this nicely diff --git a/modules/frontend/v1/request_batch.go b/modules/frontend/v1/request_batch.go index 6aae745a20b..e67e820322b 100644 --- a/modules/frontend/v1/request_batch.go +++ b/modules/frontend/v1/request_batch.go @@ -1,10 +1,16 @@ package v1 import ( + "bytes" "fmt" + "io" + "net/http" "github.com/grafana/dskit/httpgrpc" "github.com/grafana/dskit/multierror" + "github.com/grafana/tempo/pkg/util/httpgrpcutil" + + "go.opentelemetry.io/otel" ) type requestBatch struct { @@ -14,14 +20,35 @@ type requestBatch struct { wireRequests []*httpgrpc.HTTPRequest } +type buffer struct { + buff []byte + io.ReadCloser +} + +func (b *buffer) Bytes() []byte { + return b.buff +} + func (b *requestBatch) clear() { b.pipelineRequests = b.pipelineRequests[:0] b.wireRequests = b.wireRequests[:0] } -func (b *requestBatch) add(r *request) { +func (b *requestBatch) add(r *request) error { b.pipelineRequests = append(b.pipelineRequests, r) - b.wireRequests = append(b.wireRequests, r.request) + + req, err := httpgrpc.FromHTTPRequest(r.request.HTTPRequest()) + if err != nil { + return err + } + + // Propagate trace context in gRPC too - this will be ignored if using HTTP. + carrier := (*httpgrpcutil.HttpgrpcHeadersCarrier)(req) + otel.GetTextMapPropagator().Inject(r.OriginalContext(), carrier) + + b.wireRequests = append(b.wireRequests, req) + + return nil } func (b *requestBatch) httpGrpcRequests() []*httpgrpc.HTTPRequest { @@ -36,7 +63,7 @@ func (b *requestBatch) contextError() error { multiErr := multierror.New() for _, r := range b.pipelineRequests { - if err := r.originalCtx.Err(); err != nil { + if err := r.OriginalContext().Err(); err != nil { multiErr.Add(err) } } @@ -52,7 +79,7 @@ func (b *requestBatch) contextError() error { // will belong to the same upstream http query. func (b *requestBatch) doneChan(stop <-chan struct{}) <-chan struct{} { if len(b.pipelineRequests) == 1 { - return b.pipelineRequests[0].originalCtx.Done() + return b.pipelineRequests[0].OriginalContext().Done() } done := make(chan struct{}) @@ -63,7 +90,7 @@ func (b *requestBatch) doneChan(stop <-chan struct{}) <-chan struct{} { // if all are done. for _, r := range b.pipelineRequests { select { - case <-r.originalCtx.Done(): + case <-r.OriginalContext().Done(): case <-stop: return } @@ -87,8 +114,23 @@ func (b *requestBatch) reportResultsToPipeline(responses []*httpgrpc.HTTPRespons } for i, r := range b.pipelineRequests { - r.response <- responses[i] + r.response <- httpGRPCResponseToHTTPResponse(responses[i]) } return nil } + +func httpGRPCResponseToHTTPResponse(resp *httpgrpc.HTTPResponse) *http.Response { + // translate back + httpResp := &http.Response{ + StatusCode: int(resp.Code), + Body: &buffer{buff: resp.Body, ReadCloser: io.NopCloser(bytes.NewReader(resp.Body))}, + Header: http.Header{}, + ContentLength: int64(len(resp.Body)), + } + for _, h := range resp.Headers { + httpResp.Header[h.Key] = h.Values + } + + return httpResp +} diff --git a/modules/frontend/v1/request_batch_test.go b/modules/frontend/v1/request_batch_test.go index ebbbc6e53bf..931abae26c3 100644 --- a/modules/frontend/v1/request_batch_test.go +++ b/modules/frontend/v1/request_batch_test.go @@ -1,12 +1,17 @@ package v1 import ( + "bytes" "context" "errors" + "net/http" + "net/http/httptest" "sync" "testing" "github.com/grafana/dskit/httpgrpc" + "github.com/grafana/tempo/modules/frontend/pipeline" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,10 +21,9 @@ func TestRequestBatchBasics(t *testing.T) { const totalRequests = 3 for i := byte(0); i < totalRequests; i++ { - rb.add(&request{ - request: &httpgrpc.HTTPRequest{ - Body: []byte{i}, - }, + req := httptest.NewRequest("GET", "http://example.com", bytes.NewReader([]byte{i})) + _ = rb.add(&request{ + request: pipeline.NewHTTPRequest(req), }) } @@ -43,21 +47,23 @@ func TestRequestBatchBasics(t *testing.T) { func TestRequestBatchContextError(t *testing.T) { rb := &requestBatch{} - + ctx := context.Background() const totalRequests = 3 - ctx := context.Background() + req := httptest.NewRequest("GET", "http://example.com", nil) + prequest := pipeline.NewHTTPRequest(req) + prequest.WithContext(ctx) + for i := 0; i < totalRequests-1; i++ { - rb.add(&request{ - originalCtx: ctx, - }) + _ = rb.add(&request{request: prequest}) } // add a cancel context cancelCtx, cancel := context.WithCancel(ctx) - rb.add(&request{ - originalCtx: cancelCtx, - }) + prequest = pipeline.NewHTTPRequest(req) + prequest.WithContext(cancelCtx) + + _ = rb.add(&request{request: prequest}) // confirm ok require.NoError(t, rb.contextError()) @@ -74,10 +80,13 @@ func TestDoneChanCloses(_ *testing.T) { ctx := context.Background() cancelCtx, cancel := context.WithCancel(ctx) + + req := httptest.NewRequest("GET", "http://example.com", nil) + prequest := pipeline.NewHTTPRequest(req) + prequest.WithContext(cancelCtx) + for i := 0; i < totalRequests-1; i++ { - rb.add(&request{ - originalCtx: cancelCtx, - }) + _ = rb.add(&request{request: prequest}) } wg := &sync.WaitGroup{} @@ -97,11 +106,11 @@ func TestDoneChanClosesOnStop(_ *testing.T) { rb := &requestBatch{} const totalRequests = 3 + req := httptest.NewRequest("GET", "http://example.com", nil) - ctx := context.Background() for i := 0; i < totalRequests-1; i++ { - rb.add(&request{ - originalCtx: ctx, + _ = rb.add(&request{ + request: pipeline.NewHTTPRequest(req), }) } @@ -134,9 +143,10 @@ func TestErrorsPropagateUpstream(t *testing.T) { require.ErrorContains(t, err, "foo") wg.Done() }() - - rb.add(&request{ - err: errChan, + req := httptest.NewRequest("GET", "http://example.com", nil) + _ = rb.add(&request{ + request: pipeline.NewHTTPRequest(req), + err: errChan, }) } @@ -152,16 +162,18 @@ func TestResponsesPropagateUpstream(t *testing.T) { wg := &sync.WaitGroup{} for i := int32(0); i < totalRequests; i++ { - responseChan := make(chan *httpgrpc.HTTPResponse) + responseChan := make(chan *http.Response) wg.Add(1) go func(expectedCode int32) { resp := <-responseChan - require.Equal(t, expectedCode, resp.Code) + assert.Equal(t, expectedCode, int32(resp.StatusCode)) wg.Done() }(i) - rb.add(&request{ + req := httptest.NewRequest("GET", "http://example.com", nil) + _ = rb.add(&request{ + request: pipeline.NewHTTPRequest(req), response: responseChan, }) } diff --git a/pkg/traceql/engine.go b/pkg/traceql/engine.go index e03aa2d6388..e3288d7e64b 100644 --- a/pkg/traceql/engine.go +++ b/pkg/traceql/engine.go @@ -27,7 +27,7 @@ func NewEngine() *Engine { return &Engine{} } -func (e *Engine) Compile(query string) (*RootExpr, SpansetFilterFunc, metricsFirstStageElement, *FetchSpansRequest, error) { +func Compile(query string) (*RootExpr, SpansetFilterFunc, metricsFirstStageElement, *FetchSpansRequest, error) { expr, err := Parse(query) if err != nil { return nil, nil, nil, nil, err @@ -50,12 +50,13 @@ func (e *Engine) ExecuteSearch(ctx context.Context, searchReq *tempopb.SearchReq ctx, span := tracer.Start(ctx, "traceql.Engine.ExecuteSearch") defer span.End() - rootExpr, err := e.parseQuery(searchReq) + rootExpr, _, _, fetchSpansRequest, err := Compile(searchReq.Query) if err != nil { return nil, err } - fetchSpansRequest := e.createFetchSpansRequest(searchReq, rootExpr.Pipeline) + fetchSpansRequest.StartTimeUnixNanos = unixSecToNano(searchReq.Start) + fetchSpansRequest.EndTimeUnixNanos = unixSecToNano(searchReq.End) span.SetAttributes(attribute.String("pipeline", rootExpr.Pipeline.String())) span.SetAttributes(attribute.String("fetchSpansRequest", fmt.Sprint(fetchSpansRequest))) @@ -99,7 +100,7 @@ func (e *Engine) ExecuteSearch(ctx context.Context, searchReq *tempopb.SearchReq return evalSS, nil } - fetchSpansResponse, err := spanSetFetcher.Fetch(ctx, fetchSpansRequest) + fetchSpansResponse, err := spanSetFetcher.Fetch(ctx, *fetchSpansRequest) if err != nil { return nil, err } @@ -204,30 +205,6 @@ func (e *Engine) ExecuteTagNames( return fetcher.Fetch(ctx, autocompleteReq, cb) } -func (e *Engine) parseQuery(searchReq *tempopb.SearchRequest) (*RootExpr, error) { - r, err := Parse(searchReq.Query) - if err != nil { - return nil, err - } - return r, r.validate() -} - -// createFetchSpansRequest will flatten the SpansetFilter in simple conditions the storage layer -// can work with. -func (e *Engine) createFetchSpansRequest(searchReq *tempopb.SearchRequest, pipeline Pipeline) FetchSpansRequest { - // TODO handle SearchRequest.MinDurationMs and MaxDurationMs, this refers to the trace level duration which is not the same as the intrinsic duration - - req := FetchSpansRequest{ - StartTimeUnixNanos: unixSecToNano(searchReq.Start), - EndTimeUnixNanos: unixSecToNano(searchReq.End), - Conditions: nil, - AllConditions: true, - } - - pipeline.extractConditions(&req) - return req -} - func (e *Engine) createAutocompleteRequest(tag Attribute, pipeline Pipeline) FetchTagValuesRequest { req := FetchSpansRequest{ Conditions: nil, diff --git a/pkg/traceql/engine_metrics.go b/pkg/traceql/engine_metrics.go index 7c711fe8fbf..b94eeabc455 100644 --- a/pkg/traceql/engine_metrics.go +++ b/pkg/traceql/engine_metrics.go @@ -772,7 +772,7 @@ func (e *Engine) CompileMetricsQueryRangeNonRaw(req *tempopb.QueryRangeRequest, return nil, fmt.Errorf("step required") } - _, _, metricsPipeline, _, err := e.Compile(req.Query) + _, _, metricsPipeline, _, err := Compile(req.Query) if err != nil { return nil, fmt.Errorf("compiling query: %w", err) } @@ -806,7 +806,7 @@ func (e *Engine) CompileMetricsQueryRange(req *tempopb.QueryRangeRequest, exempl return nil, fmt.Errorf("step required") } - expr, eval, metricsPipeline, storageReq, err := e.Compile(req.Query) + expr, eval, metricsPipeline, storageReq, err := Compile(req.Query) if err != nil { return nil, fmt.Errorf("compiling query: %w", err) } diff --git a/pkg/traceql/engine_test.go b/pkg/traceql/engine_test.go index 2721b87d409..f7f0599b4e4 100644 --- a/pkg/traceql/engine_test.go +++ b/pkg/traceql/engine_test.go @@ -531,31 +531,23 @@ func TestExamplesInEngine(t *testing.T) { err = yaml.Unmarshal(b, queries) require.NoError(t, err) - e := NewEngine() - for _, q := range queries.Valid { t.Run("valid - "+q, func(t *testing.T) { - _, err := e.parseQuery(&tempopb.SearchRequest{ - Query: q, - }) + _, _, _, _, err := Compile(q) require.NoError(t, err) }) } for _, q := range queries.ParseFails { t.Run("parse fails - "+q, func(t *testing.T) { - _, err := e.parseQuery(&tempopb.SearchRequest{ - Query: q, - }) + _, _, _, _, err := Compile(q) require.Error(t, err) }) } for _, q := range queries.ValidateFails { t.Run("validate fails - "+q, func(t *testing.T) { - _, err := e.parseQuery(&tempopb.SearchRequest{ - Query: q, - }) + _, _, _, _, err := Compile(q) require.Error(t, err) var unErr *unsupportedError require.False(t, errors.As(err, &unErr)) @@ -564,9 +556,7 @@ func TestExamplesInEngine(t *testing.T) { for _, q := range queries.Unsupported { t.Run("unsupported - "+q, func(t *testing.T) { - _, err := e.parseQuery(&tempopb.SearchRequest{ - Query: q, - }) + _, _, _, _, err := Compile(q) require.Error(t, err) var unErr *unsupportedError require.True(t, errors.As(err, &unErr)) diff --git a/pkg/traceqlmetrics/metrics.go b/pkg/traceqlmetrics/metrics.go index 3798f6a078a..427fe3a8327 100644 --- a/pkg/traceqlmetrics/metrics.go +++ b/pkg/traceqlmetrics/metrics.go @@ -226,7 +226,7 @@ func GetMetrics(ctx context.Context, query, groupBy string, spanLimit int, start groupByKeys[i] = groupBys[i][0].String() } - _, eval, _, req, err := traceql.NewEngine().Compile(query) + _, eval, _, req, err := traceql.Compile(query) if err != nil { return nil, fmt.Errorf("compiling query: %w", err) } diff --git a/tempodb/encoding/vparquet4/block_traceql_test.go b/tempodb/encoding/vparquet4/block_traceql_test.go index f4a3a1a3bea..8c81e0d217a 100644 --- a/tempodb/encoding/vparquet4/block_traceql_test.go +++ b/tempodb/encoding/vparquet4/block_traceql_test.go @@ -715,7 +715,7 @@ func TestBackendBlockSelectAll(t *testing.T) { b := makeBackendBlockWithTraces(t, traces) - _, _, _, req, err := traceql.NewEngine().Compile("{}") + _, _, _, req, err := traceql.Compile("{}") require.NoError(t, err) req.SecondPass = func(inSS *traceql.Spanset) ([]*traceql.Spanset, error) { return []*traceql.Spanset{inSS}, nil } req.SecondPassSelectAll = true