From c0aebe6350f1e09b825a2efbe8ded3f53e3acbbf Mon Sep 17 00:00:00 2001 From: Ethan Mosbaugh Date: Wed, 16 Oct 2024 11:53:24 -0700 Subject: [PATCH] fix race condition in applier-manager Signed-off-by: Ethan Mosbaugh --- pkg/applier/manager.go | 90 ++++++++++++++++++++++++++++--------- pkg/applier/manager_test.go | 6 +-- 2 files changed, 72 insertions(+), 24 deletions(-) diff --git a/pkg/applier/manager.go b/pkg/applier/manager.go index 78a3f14e1c9c..e0f430d924f1 100644 --- a/pkg/applier/manager.go +++ b/pkg/applier/manager.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "path" + "sync" "time" "github.com/k0sproject/k0s/internal/pkg/dir" @@ -39,11 +40,13 @@ type Manager struct { KubeClientFactory kubeutil.ClientFactoryInterface // client kubernetes.Interface - applier Applier - bundlePath string - cancelWatcher context.CancelFunc - log *logrus.Entry - stacks map[string]stack + applier Applier + bundlePath string + stacks map[string]stack + log *logrus.Entry + startChan chan struct{} + mux sync.Mutex + watcherCancelFn context.CancelFunc LeaderElector leaderelector.Interface } @@ -67,35 +70,80 @@ func (m *Manager) Init(ctx context.Context) error { m.applier = NewApplier(m.K0sVars.ManifestsDir, m.KubeClientFactory) - m.LeaderElector.AddAcquiredLeaseCallback(func() { - watcherCtx, cancel := context.WithCancel(ctx) - m.cancelWatcher = cancel - go func() { - _ = m.runWatchers(watcherCtx) - }() - }) - m.LeaderElector.AddLostLeaseCallback(func() { - if m.cancelWatcher != nil { - m.cancelWatcher() - } - }) - - return err + return nil } // Run runs the Manager func (m *Manager) Start(_ context.Context) error { + m.log.Debug("Starting") + m.startChan = make(chan struct{}, 1) + + m.LeaderElector.AddLostLeaseCallback(m.leaseLost) + + m.LeaderElector.AddAcquiredLeaseCallback(m.leaseAcquired) + + // It's possible that by the time we added the callback, we are already the leader, + // If this is true the callback will not be called, so we need to check if we are + // the leader and notify the channel manually + if m.LeaderElector.IsLeader() { + m.leaseAcquired() + } + + go m.watchStartChan() return nil } +func (m *Manager) watchStartChan() { + m.log.Debug("Watching start channel") + for range m.startChan { + m.log.Info("Acquired leader lease") + m.mux.Lock() + ctx, cancel := context.WithCancel(context.Background()) + // If there is a previous cancel func, call it + if m.watcherCancelFn != nil { + m.watcherCancelFn() + } + m.watcherCancelFn = cancel + m.mux.Unlock() + _ = m.runWatchers(ctx) + } + m.log.Info("Start channel closed, stopping applier-manager") +} + // Stop stops the Manager func (m *Manager) Stop() error { - if m.cancelWatcher != nil { - m.cancelWatcher() + m.log.Info("Stopping applier-manager") + // We have no guarantees on concurrency here, so use mutex + m.mux.Lock() + watcherCancelFn := m.watcherCancelFn + m.mux.Unlock() + if watcherCancelFn != nil { + watcherCancelFn() } + close(m.startChan) + m.log.Debug("Stopped applier-manager") return nil } +func (m *Manager) leaseLost() { + m.mux.Lock() + defer m.mux.Unlock() + m.log.Warn("Lost leader lease, stopping applier-manager") + + watcherCancelFn := m.watcherCancelFn + if watcherCancelFn != nil { + watcherCancelFn() + } +} + +func (m *Manager) leaseAcquired() { + m.log.Info("Acquired leader lease") + select { + case m.startChan <- struct{}{}: + default: + } +} + func (m *Manager) runWatchers(ctx context.Context) error { log := logrus.WithField("component", constant.ApplierManagerComponentName) diff --git a/pkg/applier/manager_test.go b/pkg/applier/manager_test.go index cc34cf28ce9c..2eb2b9799fb4 100644 --- a/pkg/applier/manager_test.go +++ b/pkg/applier/manager_test.go @@ -20,6 +20,7 @@ import ( "context" "embed" "os" + "path" "path/filepath" "sync" "testing" @@ -184,14 +185,13 @@ func waitFor(t *testing.T, interval, timeout time.Duration, fn wait.ConditionWit } func writeStack(t *testing.T, dst string, src string) { - dstStackDir := filepath.Join(dst, filepath.Base(src)) + dstStackDir := filepath.Join(dst, path.Base(src)) err := os.MkdirAll(dstStackDir, 0755) require.NoError(t, err) entries, err := managerTestData.ReadDir(src) require.NoError(t, err) for _, entry := range entries { - src := filepath.Join(src, entry.Name()) - data, err := managerTestData.ReadFile(src) + data, err := managerTestData.ReadFile(path.Join(src, entry.Name())) require.NoError(t, err) dst := filepath.Join(dstStackDir, entry.Name()) t.Logf("writing file %s", dst)