diff --git a/cmd/jaeger/internal/extension/remotesampling/config.go b/cmd/jaeger/internal/extension/remotesampling/config.go index c19f3ab48e0..edff4707e26 100644 --- a/cmd/jaeger/internal/extension/remotesampling/config.go +++ b/cmd/jaeger/internal/extension/remotesampling/config.go @@ -21,6 +21,7 @@ var ( ) var ( + _ component.Config = (*Config)(nil) _ component.ConfigValidator = (*Config)(nil) _ confmap.Unmarshaler = (*Config)(nil) ) diff --git a/cmd/jaeger/internal/extension/remotesampling/extension_test.go b/cmd/jaeger/internal/extension/remotesampling/extension_test.go index 5e1105c07c4..de8cb300e24 100644 --- a/cmd/jaeger/internal/extension/remotesampling/extension_test.go +++ b/cmd/jaeger/internal/extension/remotesampling/extension_test.go @@ -33,26 +33,6 @@ import ( "github.com/jaegertracing/jaeger/proto-gen/api_v2" ) -type samplingHost struct { - t *testing.T - samplingExtension component.Component -} - -func (host samplingHost) GetExtensions() map[component.ID]component.Component { - return map[component.ID]component.Component{ - ID: host.samplingExtension, - } -} - -func (host samplingHost) ReportFatalError(err error) { - host.t.Fatal(err) -} - -func (samplingHost) GetFactory(_ component.Kind, _ component.Type) component.Factory { return nil } -func (samplingHost) GetExporters() map[component.DataType]map[component.ID]component.Component { - return nil -} - func makeStorageExtension(t *testing.T, memstoreName string) component.Host { telemetrySettings := component.TelemetrySettings{ Logger: zaptest.NewLogger(t), @@ -80,9 +60,7 @@ func makeStorageExtension(t *testing.T, memstoreName string) component.Host { return host } -var _ component.Config = (*Config)(nil) - -func makeRemoteSamplingExtension(t *testing.T, cfg component.Config) samplingHost { +func makeRemoteSamplingExtension(t *testing.T, cfg component.Config) component.Host { extensionFactory := NewFactory() samplingExtension, err := extensionFactory.CreateExtension( context.Background(), @@ -95,11 +73,10 @@ func makeRemoteSamplingExtension(t *testing.T, cfg component.Config) samplingHos cfg, ) require.NoError(t, err) - host := samplingHost{t: t, samplingExtension: samplingExtension} + host := storagetest.NewStorageHost().WithExtension(ID, samplingExtension) storageHost := makeStorageExtension(t, "foobar") - err = samplingExtension.Start(context.Background(), storageHost) - require.NoError(t, err) + require.NoError(t, samplingExtension.Start(context.Background(), storageHost)) t.Cleanup(func() { require.NoError(t, samplingExtension.Shutdown(context.Background())) }) return host } @@ -210,6 +187,19 @@ func TestStartAdaptiveProvider(t *testing.T) { require.NoError(t, ext.Shutdown(context.Background())) } +func TestStartAdaptiveStrategyProviderErrors(t *testing.T) { + host := storagetest.NewStorageHost() + ext := &rsExtension{ + cfg: &Config{ + Adaptive: &AdaptiveConfig{ + SamplingStore: "foobar", + }, + }, + } + err := ext.startAdaptiveStrategyProvider(host) + require.ErrorContains(t, err, "cannot find storage factory") +} + func TestGetAdaptiveSamplingComponents(t *testing.T) { // Success case host := makeRemoteSamplingExtension(t, &Config{ @@ -230,11 +220,25 @@ func TestGetAdaptiveSamplingComponents(t *testing.T) { assert.Equal(t, time.Duration(1), comps.Options.FollowerLeaseRefreshInterval) assert.Equal(t, time.Duration(1), comps.Options.LeaderLeaseRefreshInterval) assert.Equal(t, 1, comps.Options.AggregationBuckets) +} + +type wrongExtension struct{} + +func (*wrongExtension) Start(context.Context, component.Host) error { return nil } +func (*wrongExtension) Shutdown(context.Context) error { return nil } - // Error case - host = makeRemoteSamplingExtension(t, &Config{}) - _, err = GetAdaptiveSamplingComponents(host) +func TestGetAdaptiveSamplingComponentsErrors(t *testing.T) { + host := makeRemoteSamplingExtension(t, &Config{}) + _, err := GetAdaptiveSamplingComponents(host) require.ErrorContains(t, err, "extension 'remote_sampling' is not configured for adaptive sampling") + + h1 := storagetest.NewStorageHost() + _, err = GetAdaptiveSamplingComponents(h1) + require.ErrorContains(t, err, "cannot find extension") + + h2 := h1.WithExtension(ID, &wrongExtension{}) + _, err = GetAdaptiveSamplingComponents(h2) + require.ErrorContains(t, err, "is not of type") } func TestDependencies(t *testing.T) {