Skip to content

Commit

Permalink
Clean up runtime manager test initialization (#3150)
Browse files Browse the repository at this point in the history
The runtime manager uses a mysterious lookup table, `waitReady`, that is
injected into the TLS authentication callbacks to allow a special list
of unregistered components to connect. However this turns out to only
ever be used by the unit tests for a single connection, which itself is
only used to probe whether the RPC server is listening yet. This can be
done more simply by just setting an atomic flag when the server loop
starts, so this PR does that and removes the extra synchronization
baggage.

Has no functional effect except when using the unit test helper
`waitForReady`, and should effectively still be a no-op there (just with
fewer redundant network connections).

## Checklist

- [x] My code follows the style guidelines of this project
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] ~~I have made corresponding changes to the documentation~~
- [ ] ~~I have made corresponding change to the default configuration
files~~
- [ ] ~~I have added tests that prove my fix is effective or that my
feature works~~
- [ ] ~~I have added an entry in `./changelog/fragments` using the
[changelog tool](https://github.com/elastic/elastic-agent#changelog)~~
- [ ] ~~I have added an integration test or an E2E test~~
  • Loading branch information
faec authored Aug 1, 2023
1 parent f6a4f65 commit 539a5f2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 106 deletions.
92 changes: 4 additions & 88 deletions pkg/component/runtime/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,8 @@ type Manager struct {
listener net.Listener
server *grpc.Server

// waitMx synchronizes the access to waitReady only
waitMx sync.RWMutex
waitReady map[string]waitForReady
// Set when the RPC server is ready to receive requests, for use by tests.
serverReady *atomic.Bool

// updateMx protects the call to update to ensure that
// only one call to update occurs at a time
Expand Down Expand Up @@ -151,13 +150,13 @@ func NewManager(
listenAddr: listenAddr,
agentInfo: agentInfo,
tracer: tracer,
waitReady: make(map[string]waitForReady),
current: make(map[string]*componentRuntimeState),
shipperConns: make(map[string]*shipperConn),
subscriptions: make(map[string][]*Subscription),
errCh: make(chan error),
monitor: monitor,
grpcConfig: grpcConfig,
serverReady: atomic.NewBool(false),
}
return m, nil
}
Expand Down Expand Up @@ -216,6 +215,7 @@ func (m *Manager) Run(ctx context.Context) error {
wg.Add(1)
go func() {
defer wg.Done()
m.serverReady.Store(true)
for {
err := server.Serve(lis)
if err != nil {
Expand All @@ -242,73 +242,6 @@ func (m *Manager) Run(ctx context.Context) error {
return ctx.Err()
}

// waitForReady waits until the manager is ready to be used.
// Used for testing.
//
// This verifies that the GRPC server is up and running.
func (m *Manager) waitForReady(ctx context.Context) error {
tk, err := uuid.NewV4()
if err != nil {
return err
}
token := tk.String()
name, err := genServerName()
if err != nil {
return err
}
pair, err := m.ca.GeneratePairWithName(name)
if err != nil {
return err
}
cert, err := tls.X509KeyPair(pair.Crt, pair.Key)
if err != nil {
return err
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(m.ca.Crt())
trans := credentials.NewTLS(&tls.Config{
ServerName: name,
Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
})

m.waitMx.Lock()
m.waitReady[token] = waitForReady{
name: name,
cert: pair,
}
m.waitMx.Unlock()

defer func() {
m.waitMx.Lock()
delete(m.waitReady, token)
m.waitMx.Unlock()
}()

for {
m.netMx.RLock()
lis := m.listener
srv := m.server
m.netMx.RUnlock()
if lis != nil && srv != nil {
addr := m.getListenAddr()
c, err := grpc.Dial(addr, grpc.WithTransportCredentials(trans))
if err == nil {
_ = c.Close()
return nil
}
}

t := time.NewTimer(100 * time.Millisecond)
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
}
}
}

// Errors returns channel that errors are reported on.
func (m *Manager) Errors() <-chan error {
return m.errCh
Expand Down Expand Up @@ -917,18 +850,6 @@ func (m *Manager) getCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, er
return cert, nil
}

m.waitMx.RLock()
for _, waiter := range m.waitReady {
if waiter.name == chi.ServerName {
cert = waiter.cert.Certificate
break
}
}
m.waitMx.RUnlock()
if cert != nil {
return cert, nil
}

return nil, errors.New("no supported TLS certificate")
}

Expand Down Expand Up @@ -1057,8 +978,3 @@ func (m *Manager) performDiagAction(ctx context.Context, comp component.Componen
}
return res.Diagnostic, nil
}

type waitForReady struct {
name string
cert *authority.Pair
}
48 changes: 30 additions & 18 deletions pkg/component/runtime/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestManager_SimpleComponentErr(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -186,7 +186,7 @@ func TestManager_FakeInput_StartStop(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -313,7 +313,7 @@ func TestManager_FakeInput_Features(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -502,7 +502,7 @@ func TestManager_FakeInput_BadUnitToGood(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -668,7 +668,7 @@ func TestManager_FakeInput_GoodUnitToBad(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -824,7 +824,7 @@ func TestManager_FakeInput_NoDeadlock(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -958,7 +958,7 @@ func TestManager_FakeInput_Configure(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -1078,7 +1078,7 @@ func TestManager_FakeInput_RemoveUnit(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -1231,7 +1231,7 @@ func TestManager_FakeInput_ActionState(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -1355,7 +1355,7 @@ func TestManager_FakeInput_Restarts(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -1490,7 +1490,7 @@ func TestManager_FakeInput_Restarts_ConfigKill(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -1632,7 +1632,7 @@ func TestManager_FakeInput_KeepsRestarting(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -1774,7 +1774,7 @@ func TestManager_FakeInput_RestartsOnMissedCheckins(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -1889,7 +1889,7 @@ func TestManager_FakeInput_InvalidAction(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -2014,7 +2014,7 @@ func TestManager_FakeInput_MultiComponent(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -2227,7 +2227,7 @@ func TestManager_FakeInput_LogLevel(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -2371,7 +2371,7 @@ func TestManager_FakeShipper(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -2672,7 +2672,7 @@ func TestManager_FakeInput_OutputChange(t *testing.T) {

waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second)
defer waitCancel()
if err := m.waitForReady(waitCtx); err != nil {
if err := waitForReady(waitCtx, m); err != nil {
require.NoError(t, err)
}

Expand Down Expand Up @@ -2998,3 +2998,15 @@ func newTestMonitoringMgr() *testMonitoringManager { return &testMonitoringManag
func (*testMonitoringManager) EnrichArgs(_ string, _ string, args []string) []string { return args }
func (*testMonitoringManager) Prepare(_ string) error { return nil }
func (*testMonitoringManager) Cleanup(string) error { return nil }

// waitForReady waits until the RPC server is ready to be used.
func waitForReady(ctx context.Context, m *Manager) error {
for !m.serverReady.Load() {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(100 * time.Millisecond):
}
}
return nil
}

0 comments on commit 539a5f2

Please sign in to comment.