From 12880e59038577db55a2297dc77f419f9e5b3047 Mon Sep 17 00:00:00 2001 From: Blake Rouse Date: Tue, 18 Jul 2023 16:25:11 -0400 Subject: [PATCH] Fix tests. --- .../application/dispatcher/dispatcher_test.go | 14 +++++++++----- .../artifact/download/http/downloader_test.go | 18 ++++++++++++++---- pkg/component/runtime/log_writer.go | 16 +++++++++++----- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/internal/pkg/agent/application/dispatcher/dispatcher_test.go b/internal/pkg/agent/application/dispatcher/dispatcher_test.go index 15cb194bb71..fb3cbb1af9b 100644 --- a/internal/pkg/agent/application/dispatcher/dispatcher_test.go +++ b/internal/pkg/agent/application/dispatcher/dispatcher_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -227,7 +226,11 @@ func TestActionDispatcher(t *testing.T) { t.Run("Cancel queued action", func(t *testing.T) { def := &mockHandler{} - def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + calledCh := make(chan bool) + call := def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + call.RunFn = func(_ mock.Arguments) { + calledCh <- true + } queue := &mockQueue{} queue.On("Save").Return(nil).Once() @@ -248,10 +251,11 @@ func TestActionDispatcher(t *testing.T) { select { case err := <-d.Errors(): t.Fatalf("Unexpected error: %v", err) - case <-time.After(200 * time.Microsecond): - // we're not expecting any reset, + case <-calledCh: + // Handle was called, expected + case <-time.After(1 * time.Second): + t.Fatal("mock Handle never called") } - assert.Eventuallyf(t, func() bool { return len(def.Calls) > 0 }, 100*time.Millisecond, 100*time.Microsecond, "mock handler for cancel actions has not been called") def.AssertExpectations(t) queue.AssertExpectations(t) }) diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/downloader_test.go b/internal/pkg/agent/application/upgrade/artifact/download/http/downloader_test.go index 11784e2d0f5..a49c9b6d154 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/http/downloader_test.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/downloader_test.go @@ -30,19 +30,20 @@ func TestDownloadBodyError(t *testing.T) { // part way through the download, while copying the response body. type connKey struct{} - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() conn, ok := r.Context().Value(connKey{}).(net.Conn) if ok { - conn.Close() + _ = conn.Close() } })) - defer srv.Close() - client := srv.Client() srv.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { return context.WithValue(ctx, connKey{}, c) } + srv.Start() + defer srv.Close() + client := srv.Client() targetDir, err := ioutil.TempDir(os.TempDir(), "") if err != nil { @@ -64,6 +65,9 @@ func TestDownloadBodyError(t *testing.T) { t.Fatal("expected Download to return an error") } + log.lock.RLock() + defer log.lock.RUnlock() + require.GreaterOrEqual(t, len(log.info), 1, "download error not logged at info level") assert.True(t, containsMessage(log.info, "download from %s failed at %s @ %sps: %s")) require.GreaterOrEqual(t, len(log.warn), 1, "download error not logged at warn level") @@ -113,6 +117,9 @@ func TestDownloadLogProgressWithLength(t *testing.T) { os.Remove(artifactPath) require.NoError(t, err, "Download should not have errored") + log.lock.RLock() + defer log.lock.RUnlock() + // 2 files are downloaded so 4 log messages are expected in the info level and only the complete is over the warn // window as 2 log messages for warn. require.Len(t, log.info, 4) @@ -167,6 +174,9 @@ func TestDownloadLogProgressWithoutLength(t *testing.T) { os.Remove(artifactPath) require.NoError(t, err, "Download should not have errored") + log.lock.RLock() + defer log.lock.RUnlock() + // 2 files are downloaded so 4 log messages are expected in the info level and only the complete is over the warn // window as 2 log messages for warn. require.Len(t, log.info, 4) diff --git a/pkg/component/runtime/log_writer.go b/pkg/component/runtime/log_writer.go index ee277c26fff..960c7f07a1a 100644 --- a/pkg/component/runtime/log_writer.go +++ b/pkg/component/runtime/log_writer.go @@ -36,8 +36,9 @@ type logWriter struct { loggerCore zapcoreWriter logCfg component.CommandLogSpec logLevel zap.AtomicLevel + + mx sync.Mutex unitLevels map[string]zapcore.Level - levelMx sync.RWMutex remainder []byte // inheritLevel is the level that will be used for a log message in the case it doesn't define a log level @@ -60,9 +61,10 @@ func newLogWriter(core zapcoreWriter, logCfg component.CommandLogSpec, ll zapcor } func (r *logWriter) SetLevels(ll zapcore.Level, unitLevels map[string]zapcore.Level) { + // must hold to lock so Write doesn't access the unitLevels + r.mx.Lock() + defer r.mx.Unlock() r.logLevel.SetLevel(ll) - r.levelMx.Lock() - defer r.levelMx.Unlock() r.unitLevels = unitLevels } @@ -71,6 +73,12 @@ func (r *logWriter) Write(p []byte) (int, error) { // nothing to do return 0, nil } + + // hold the lock so SetLevels and the remainder is not touched + // from multiple go routines + r.mx.Lock() + defer r.mx.Unlock() + offset := 0 for { idx := bytes.IndexByte(p[offset:], '\n') @@ -127,13 +135,11 @@ func (r *logWriter) handleJSON(line string) bool { allowedLvl := r.logLevel.Level() unitId := getUnitId(evt) if unitId != "" { - r.levelMx.RLock() if r.unitLevels != nil { if unitLevel, ok := r.unitLevels[unitId]; ok { allowedLvl = unitLevel } } - r.levelMx.RUnlock() } if allowedLvl.Enabled(lvl) { _ = r.loggerCore.Write(zapcore.Entry{