chore: check ratelimit before validation #2659

Merged (4 commits) on Nov 10, 2023
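
In short: the distributor now sizes the incoming request (label bytes plus uncompressed profile bytes) and applies the per-tenant ingestion rate limit before running profile validation, so over-limit requests are rejected without paying the validation and normalisation cost. The sketch below restates that flow with hypothetical names (`rateLimiter`, `pushRequest`, and `push` are illustrations for this summary, not the actual Pyroscope types):

```go
package sketch

import (
	"errors"
	"fmt"
	"time"
)

// rateLimiter mirrors the AllowN call used by the distributor's ingestion
// limiter; the interface itself is an assumption made for this sketch.
type rateLimiter interface {
	AllowN(now time.Time, tenantID string, n int) bool
}

type sample struct{ sizeBytes int }

type series struct {
	labelBytes int
	samples    []sample
}

type pushRequest struct {
	series                 []series
	totalProfiles          int
	totalBytesUncompressed int
}

var errRateLimited = errors.New("push rate limit exceeded")

// push illustrates the order introduced by the PR: size the whole request and
// consult the limiter first, then validate only what was admitted.
func push(limiter rateLimiter, tenantID string, req *pushRequest, validate func(series) error) error {
	// 1. Account for the full request size: label bytes plus uncompressed profile bytes.
	for _, s := range req.series {
		req.totalBytesUncompressed += s.labelBytes
		for _, smp := range s.samples {
			req.totalProfiles++
			req.totalBytesUncompressed += smp.sizeBytes
		}
	}
	// 2. Rate limit before any expensive validation or normalisation work.
	if !limiter.AllowN(time.Now(), tenantID, req.totalBytesUncompressed) {
		return fmt.Errorf("%w while adding %d bytes", errRateLimited, req.totalBytesUncompressed)
	}
	// 3. Validate only requests that passed the limiter.
	for _, s := range req.series {
		if err := validate(s); err != nil {
			return err
		}
	}
	return nil
}
```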
51 changes: 28 additions & 23 deletions pkg/distributor/distributor.go
@@ -259,33 +259,24 @@ func (d *Distributor) PushParsed(ctx context.Context, req *distributormodel.Push
d.metrics.receivedCompressedBytes.WithLabelValues(string(profName), tenantID).Observe(float64(req.RawProfileSize))
}

if err := d.rateLimit(tenantID, req); err != nil {
return nil, err
}

for _, series := range req.Series {
// include the labels in the size calculation
for _, lbs := range series.Labels {
req.TotalBytesUncompressed += int64(len(lbs.Name))
req.TotalBytesUncompressed += int64(len(lbs.Value))
}
profName := phlaremodel.Labels(series.Labels).Get(ProfileName)
for _, raw := range series.Samples {
usagestats.NewCounter(fmt.Sprintf("distributor_profile_type_%s_received", profName)).Inc(1)
d.profileReceivedStats.Inc(1)
if haveRawPprof {
d.metrics.receivedCompressedBytes.WithLabelValues(profName, tenantID).Observe(float64(len(raw.RawProfile)))
}
req.TotalProfiles++
p := raw.Profile
var decompressedSize int
if haveRawPprof {
decompressedSize = p.SizeBytes()
} else {
decompressedSize = p.SizeVT()
}
decompressedSize := p.SizeVT()
d.metrics.receivedDecompressedBytes.WithLabelValues(profName, tenantID).Observe(float64(decompressedSize))
d.metrics.receivedSamples.WithLabelValues(profName, tenantID).Observe(float64(len(p.Sample)))
req.TotalBytesUncompressed += int64(decompressedSize)

if err = validation.ValidateProfile(d.limits, tenantID, p.Profile, decompressedSize, series.Labels, now); err != nil {
// todo this actually discards more if multiple Samples in a Series request
_ = level.Debug(d.logger).Log("msg", "invalid profile", "err", err)
validation.DiscardedProfiles.WithLabelValues(string(validation.ReasonOf(err)), tenantID).Add(float64(req.TotalProfiles))
validation.DiscardedBytes.WithLabelValues(string(validation.ReasonOf(err)), tenantID).Add(float64(req.TotalBytesUncompressed))
@@ -302,15 +293,6 @@ func (d *Distributor) PushParsed(ctx context.Context, req *distributormodel.Push
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("no profiles received"))
}

// rate limit the request
if !d.ingestionRateLimiter.AllowN(time.Now(), tenantID, int(req.TotalBytesUncompressed)) {
validation.DiscardedProfiles.WithLabelValues(string(validation.RateLimited), tenantID).Add(float64(req.TotalProfiles))
validation.DiscardedBytes.WithLabelValues(string(validation.RateLimited), tenantID).Add(float64(req.TotalBytesUncompressed))
return nil, connect.NewError(connect.CodeResourceExhausted,
fmt.Errorf("push rate limit (%s) exceeded while adding %s", humanize.IBytes(uint64(d.limits.IngestionRateBytes(tenantID))), humanize.IBytes(uint64(req.TotalBytesUncompressed))),
)
}

// Normalisation is quite an expensive operation,
// therefore it should be done after the rate limit check.
for _, series := range req.Series {
@@ -684,6 +666,29 @@ func (d *Distributor) limitMaxSessionsPerSeries(tenantID string, labels phlaremo
return labels
}

func (d *Distributor) rateLimit(tenantID string, req *distributormodel.PushRequest) error {
for _, series := range req.Series {
// include the labels in the size calculation
for _, lbs := range series.Labels {
req.TotalBytesUncompressed += int64(len(lbs.Name))
req.TotalBytesUncompressed += int64(len(lbs.Value))
}
for _, raw := range series.Samples {
req.TotalProfiles += 1
req.TotalBytesUncompressed += int64(raw.Profile.SizeVT())
}
}
// rate limit the request
if !d.ingestionRateLimiter.AllowN(time.Now(), tenantID, int(req.TotalBytesUncompressed)) {
validation.DiscardedProfiles.WithLabelValues(string(validation.RateLimited), tenantID).Add(float64(req.TotalProfiles))
validation.DiscardedBytes.WithLabelValues(string(validation.RateLimited), tenantID).Add(float64(req.TotalBytesUncompressed))
return connect.NewError(connect.CodeResourceExhausted,
fmt.Errorf("push rate limit (%s) exceeded while adding %s", humanize.IBytes(uint64(d.limits.IngestionRateBytes(tenantID))), humanize.IBytes(uint64(req.TotalBytesUncompressed))),
)
}
return nil
}

// mergeSeriesAndSampleLabels merges sample labels with
// series labels. Series labels take precedence.
func mergeSeriesAndSampleLabels(p *googlev1.Profile, sl []*typesv1.LabelPair, pl []*googlev1.Label) []*typesv1.LabelPair {
94 changes: 88 additions & 6 deletions pkg/distributor/distributor_test.go
@@ -23,10 +23,13 @@ import (
"github.com/grafana/dskit/ring/client"
"github.com/grafana/dskit/services"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

testhelper2 "github.com/grafana/pyroscope/pkg/pprof/testhelper"

profilev1 "github.com/grafana/pyroscope/api/gen/proto/go/google/v1"
distributormodel "github.com/grafana/pyroscope/pkg/distributor/model"
phlaremodel "github.com/grafana/pyroscope/pkg/model"
@@ -171,6 +174,18 @@ func collectTestProfileBytes(t *testing.T) []byte {
return buf.Bytes()
}

func hugeProfileBytes(t *testing.T) []byte {
t.Helper()
b := testhelper2.NewProfileBuilderWithLabels(time.Now().UnixNano(), nil)
p := b.CPUProfile()
for i := 0; i < 10_000; i++ {
p.ForStacktraceString(fmt.Sprintf("my_%d", i), "other").AddSamples(1)
}
bs, err := p.Profile.MarshalVT()
require.NoError(t, err)
return bs
}

type fakeIngester struct {
t testing.TB
requests []*pushv1.PushRequest
@@ -203,10 +218,11 @@ func TestBuckets(t *testing.T) {

func Test_Limits(t *testing.T) {
type testCase struct {
description string
pushReq *pushv1.PushRequest
overrides *validation.Overrides
expectedCode connect.Code
description string
pushReq *pushv1.PushRequest
overrides *validation.Overrides
expectedCode connect.Code
expectedValidationReason validation.Reason
}

testCases := []testCase{
@@ -234,7 +250,32 @@ func Test_Limits(t *testing.T) {
l.IngestionBurstSizeMB = 0.0015
tenantLimits["user-1"] = l
}),
expectedCode: connect.CodeResourceExhausted,
expectedCode: connect.CodeResourceExhausted,
expectedValidationReason: validation.RateLimited,
},
{
description: "rate_limit_invalid_profile",
pushReq: &pushv1.PushRequest{
Series: []*pushv1.RawProfileSeries{
{
Labels: []*typesv1.LabelPair{
{Name: "__name__", Value: "cpu"},
{Name: phlaremodel.LabelNameServiceName, Value: "svc"},
},
Samples: []*pushv1.RawSample{{
RawProfile: hugeProfileBytes(t),
}},
},
},
},
overrides: validation.MockOverrides(func(defaults *validation.Limits, tenantLimits map[string]*validation.Limits) {
l := validation.MockDefaultLimits()
l.IngestionBurstSizeMB = 0.0015
l.MaxProfileStacktraceSamples = 100
tenantLimits["user-1"] = l
}),
expectedCode: connect.CodeResourceExhausted,
expectedValidationReason: validation.RateLimited,
},
{
description: "labels_limit",
@@ -244,6 +285,7 @@
Labels: []*typesv1.LabelPair{
{Name: "clusterdddwqdqdqdqdqdqw", Value: "us-central1"},
{Name: "__name__", Value: "cpu"},
{Name: phlaremodel.LabelNameServiceName, Value: "svc"},
},
Samples: []*pushv1.RawSample{
{
@@ -258,13 +300,15 @@
l.MaxLabelNameLength = 12
tenantLimits["user-1"] = l
}),
expectedCode: connect.CodeInvalidArgument,
expectedCode: connect.CodeInvalidArgument,
expectedValidationReason: validation.LabelNameTooLong,
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.description, func(t *testing.T) {

mux := http.NewServeMux()
ing := newFakeIngester(t, false)
d, err := New(Config{
@@ -276,6 +320,13 @@
}}, tc.overrides, nil, log.NewLogfmtLogger(os.Stdout))

require.NoError(t, err)

expectedMetricDelta := map[prometheus.Collector]float64{
validation.DiscardedBytes.WithLabelValues(string(tc.expectedValidationReason), "user-1"): float64(uncompressedProfileSize(t, tc.pushReq)),
//todo make sure pyroscope_distributor_received_decompressed_bytes_sum is not incremented
}
m1 := metricsDump(expectedMetricDelta)

mux.Handle(pushv1connect.NewPusherServiceHandler(d, connect.WithInterceptors(tenant.NewAuthInterceptor(true))))
s := httptest.NewServer(mux)
defer s.Close()
@@ -285,6 +336,7 @@
require.Error(t, err)
require.Equal(t, tc.expectedCode, connect.CodeOf(err))
require.Nil(t, resp)
expectMetricsChange(t, m1, metricsDump(expectedMetricDelta), expectedMetricDelta)
})
}
}
@@ -972,3 +1024,33 @@ func testProfile(t int64) *profilev1.Profile {
Period: 10000000,
}
}

func uncompressedProfileSize(t *testing.T, req *pushv1.PushRequest) int {
var size int
for _, s := range req.Series {
for _, label := range s.Labels {
size += len(label.Name) + len(label.Value)
}
for _, sample := range s.Samples {
p, err := pprof2.RawFromBytes(sample.RawProfile)
require.NoError(t, err)
size += p.SizeVT()
}
}
return size
}

func metricsDump(metrics map[prometheus.Collector]float64) map[prometheus.Collector]float64 {
res := make(map[prometheus.Collector]float64)
for m := range metrics {
res[m] = testutil.ToFloat64(m)
}
return res
}

func expectMetricsChange(t *testing.T, m1, m2, expectedChange map[prometheus.Collector]float64) {
for counter, expectedDelta := range expectedChange {
delta := m2[counter] - m1[counter]
assert.Equal(t, expectedDelta, delta, "metric %s", counter)
}
}
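
The todo left in Test_Limits (making sure `pyroscope_distributor_received_decompressed_bytes_sum` is not incremented) could be covered with a companion helper in the same style. The following is a hypothetical sketch, not part of this PR, and assumes the imports already present in distributor_test.go:

```go
// expectNoMetricsChange is a hypothetical companion to expectMetricsChange:
// it asserts that none of the captured collectors moved between two dumps.
func expectNoMetricsChange(t *testing.T, m1, m2 map[prometheus.Collector]float64) {
	for counter, before := range m1 {
		assert.Equal(t, before, m2[counter], "metric %s should not have changed", counter)
	}
}
```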