From b5046a001fed91ffe4f4da0340cd02df34c67767 Mon Sep 17 00:00:00 2001 From: Santhosh Prabhu <6684582+santhoshmprabhu@users.noreply.github.com> Date: Mon, 14 Oct 2024 13:54:36 -0700 Subject: [PATCH] Make throttling nmagent fetches for nodesubnet more dynamic (#3023) * feat(CNS): Early work on better throttling in NMAgent fetch for nodesubnet * feat(CNS): Update NMAgent fetches to be async with binary exponential backoff * chore: check for empty nmagent response * test: update test for empty response * style: make linter happy * chore: fix some comments * fix: Fix bug in refresh * refactor: Address comments * refactor: ignore primary ip * refactor: move refresh out of ipfetcher * test: add ip fetcher tests * fix: remove broken import * fix: fix import * fix: fix linting * fix: fix some failing tests * chore: Remove unused function * test: test updates * fix: address comments * chore: add missed file * chore: add comment about static interval * feat: address Evan's comment to require Equal method on cached results * chore: add missed file * feat: more efficient equality * refactor: address Evan's comment * refactor: address Tim's comment * fix: undo accidental commit * fix: make linter happy * fix: make linter happy --- cns/nodesubnet/helper_for_ip_fetcher_test.go | 9 -- cns/nodesubnet/ip_fetcher.go | 87 +++++++--- cns/nodesubnet/ip_fetcher_test.go | 131 +++++++++------ nmagent/equality.go | 51 ++++++ nmagent/macaddress.go | 12 ++ refresh/equaler.go | 5 + refresh/fetcher.go | 114 +++++++++++++ refresh/fetcher_test.go | 161 +++++++++++++++++++ refresh/helper_for_fetcher_test.go | 5 + refresh/logger.go | 8 + refresh/mocktickprovider.go | 41 +++++ refresh/refreshticker.go | 37 +++++ 12 files changed, 581 insertions(+), 80 deletions(-) delete mode 100644 cns/nodesubnet/helper_for_ip_fetcher_test.go create mode 100644 nmagent/equality.go create mode 100644 refresh/equaler.go create mode 100644 refresh/fetcher.go create mode 100644 refresh/fetcher_test.go create mode 100644 refresh/helper_for_fetcher_test.go create mode 100644 refresh/logger.go create mode 100644 refresh/mocktickprovider.go create mode 100644 refresh/refreshticker.go diff --git a/cns/nodesubnet/helper_for_ip_fetcher_test.go b/cns/nodesubnet/helper_for_ip_fetcher_test.go deleted file mode 100644 index f8eda641f4..0000000000 --- a/cns/nodesubnet/helper_for_ip_fetcher_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package nodesubnet - -import "time" - -// This method is in this file (_test.go) because it is a test helper method. -// The following method is built during tests, and is not part of the main code. -func (c *IPFetcher) SetSecondaryIPQueryInterval(interval time.Duration) { - c.secondaryIPQueryInterval = interval -} diff --git a/cns/nodesubnet/ip_fetcher.go b/cns/nodesubnet/ip_fetcher.go index 5c2233786d..7457d529e8 100644 --- a/cns/nodesubnet/ip_fetcher.go +++ b/cns/nodesubnet/ip_fetcher.go @@ -7,9 +7,17 @@ import ( "time" "github.com/Azure/azure-container-networking/nmagent" + "github.com/Azure/azure-container-networking/refresh" "github.com/pkg/errors" ) +const ( + // Default minimum time between secondary IP fetches + DefaultMinRefreshInterval = 4 * time.Second + // Default maximum time between secondary IP fetches + DefaultMaxRefreshInterval = 1024 * time.Second +) + var ErrRefreshSkipped = errors.New("refresh skipped due to throttling") // InterfaceRetriever is an interface is implemented by the NMAgent Client, and also a mock client for testing. @@ -17,39 +25,75 @@ type InterfaceRetriever interface { GetInterfaceIPInfo(ctx context.Context) (nmagent.Interfaces, error) } -type IPFetcher struct { - // Node subnet state - secondaryIPQueryInterval time.Duration // Minimum time between secondary IP fetches - secondaryIPLastRefreshTime time.Time // Time of last secondary IP fetch +// IPConsumer is an interface implemented by whoever consumes the secondary IPs fetched in nodesubnet +type IPConsumer interface { + UpdateIPsForNodeSubnet([]netip.Addr) error +} - ipFectcherClient InterfaceRetriever +// IPFetcher fetches secondary IPs from NMAgent at regular intervals. The +// interval will vary within the range of minRefreshInterval and +// maxRefreshInterval. When no diff is observed after a fetch, the interval +// doubles (subject to the maximum interval). When a diff is observed, the +// interval resets to the minimum. +type IPFetcher struct { + // Node subnet config + intfFetcherClient InterfaceRetriever + consumer IPConsumer + fetcher *refresh.Fetcher[nmagent.Interfaces] } -func NewIPFetcher(nmaClient InterfaceRetriever, queryInterval time.Duration) *IPFetcher { - return &IPFetcher{ - ipFectcherClient: nmaClient, - secondaryIPQueryInterval: queryInterval, +// NewIPFetcher creates a new IPFetcher. If minInterval is 0, it will default to 4 seconds. +// If maxInterval is 0, it will default to 1024 seconds (or minInterval, if it is higher). +func NewIPFetcher( + client InterfaceRetriever, + consumer IPConsumer, + minInterval time.Duration, + maxInterval time.Duration, + logger refresh.Logger, +) *IPFetcher { + if minInterval == 0 { + minInterval = DefaultMinRefreshInterval + } + + if maxInterval == 0 { + maxInterval = DefaultMaxRefreshInterval + } + + maxInterval = max(maxInterval, minInterval) + + newIPFetcher := &IPFetcher{ + intfFetcherClient: client, + consumer: consumer, + fetcher: nil, } + fetcher := refresh.NewFetcher[nmagent.Interfaces](client.GetInterfaceIPInfo, minInterval, maxInterval, newIPFetcher.ProcessInterfaces, logger) + newIPFetcher.fetcher = fetcher + return newIPFetcher +} + +// Start the IPFetcher. +func (c *IPFetcher) Start(ctx context.Context) { + c.fetcher.Start(ctx) } -func (c *IPFetcher) RefreshSecondaryIPsIfNeeded(ctx context.Context) (ips []netip.Addr, err error) { - // If secondaryIPQueryInterval has elapsed since the last fetch, fetch secondary IPs - if time.Since(c.secondaryIPLastRefreshTime) < c.secondaryIPQueryInterval { - return nil, ErrRefreshSkipped +// Fetch IPs from NMAgent and pass to the consumer +func (c *IPFetcher) ProcessInterfaces(response nmagent.Interfaces) error { + if len(response.Entries) == 0 { + return errors.New("no interfaces found in response from NMAgent") } - c.secondaryIPLastRefreshTime = time.Now() - response, err := c.ipFectcherClient.GetInterfaceIPInfo(ctx) + _, secondaryIPs := flattenIPListFromResponse(&response) + err := c.consumer.UpdateIPsForNodeSubnet(secondaryIPs) if err != nil { - return nil, errors.Wrap(err, "getting interface IPs") + return errors.Wrap(err, "updating secondary IPs") } - res := flattenIPListFromResponse(&response) - return res, nil + return nil } // Get the list of secondary IPs from fetched Interfaces -func flattenIPListFromResponse(resp *nmagent.Interfaces) (res []netip.Addr) { +func flattenIPListFromResponse(resp *nmagent.Interfaces) (primary netip.Addr, secondaryIPs []netip.Addr) { + var primaryIP netip.Addr // For each interface... for _, intf := range resp.Entries { if !intf.IsPrimary { @@ -63,15 +107,16 @@ func flattenIPListFromResponse(resp *nmagent.Interfaces) (res []netip.Addr) { for _, a := range s.IPAddress { // Primary addresses are reserved for the host. if a.IsPrimary { + primaryIP = netip.Addr(a.Address) continue } - res = append(res, netip.Addr(a.Address)) + secondaryIPs = append(secondaryIPs, netip.Addr(a.Address)) addressCount++ } log.Printf("Got %d addresses from subnet %s", addressCount, s.Prefix) } } - return res + return primaryIP, secondaryIPs } diff --git a/cns/nodesubnet/ip_fetcher_test.go b/cns/nodesubnet/ip_fetcher_test.go index 6a2e425126..b981fd552b 100644 --- a/cns/nodesubnet/ip_fetcher_test.go +++ b/cns/nodesubnet/ip_fetcher_test.go @@ -2,75 +2,102 @@ package nodesubnet_test import ( "context" - "errors" + "net/netip" "testing" - "time" + "github.com/Azure/azure-container-networking/cns/logger" "github.com/Azure/azure-container-networking/cns/nodesubnet" "github.com/Azure/azure-container-networking/nmagent" ) -// Mock client that simply tracks if refresh has been called -type TestClient struct { - fetchCalled bool +// Mock client that simply consumes fetched IPs +type TestConsumer struct { + consumeCount int + secondaryIPCount int } +// FetchConsumeCount atomically fetches the consume count +func (c *TestConsumer) FetchConsumeCount() int { + return c.consumeCount +} + +// FetchSecondaryIPCount atomically fetches the last IP count +func (c *TestConsumer) FetchSecondaryIPCount() int { + return c.consumeCount +} + +// UpdateConsumeCount atomically updates the consume count +func (c *TestConsumer) updateCounts(ipCount int) { + c.consumeCount++ + c.secondaryIPCount = ipCount +} + +// Mock IP update +func (c *TestConsumer) UpdateIPsForNodeSubnet(ips []netip.Addr) error { + c.updateCounts(len(ips)) + return nil +} + +var _ nodesubnet.IPConsumer = &TestConsumer{} + +// Mock client that simply satisfies the interface +type TestClient struct{} + // Mock refresh func (c *TestClient) GetInterfaceIPInfo(_ context.Context) (nmagent.Interfaces, error) { - c.fetchCalled = true return nmagent.Interfaces{}, nil } -func TestRefreshSecondaryIPsIfNeeded(t *testing.T) { - getTests := []struct { - name string - shouldCall bool - interval time.Duration - }{ - { - "fetch called", - true, - -1 * time.Second, // Negative timeout to force refresh - }, - { - "no refresh needed", - false, - 10 * time.Hour, // High timeout to avoid refresh +func TestEmptyResponse(t *testing.T) { + consumerPtr := &TestConsumer{} + fetcher := nodesubnet.NewIPFetcher(&TestClient{}, consumerPtr, 0, 0, logger.Log) + err := fetcher.ProcessInterfaces(nmagent.Interfaces{}) + checkErr(t, err, true) + + // No consumes, since the responses are empty + if consumerPtr.FetchConsumeCount() > 0 { + t.Error("Consume called unexpectedly, shouldn't be called since responses are empty") + } +} + +func TestFlatten(t *testing.T) { + interfaces := nmagent.Interfaces{ + Entries: []nmagent.Interface{ + { + MacAddress: nmagent.MACAddress{0x00, 0x0D, 0x3A, 0xF9, 0xDC, 0xA6}, + IsPrimary: true, + InterfaceSubnets: []nmagent.InterfaceSubnet{ + { + Prefix: "10.240.0.0/16", + IPAddress: []nmagent.NodeIP{ + { + Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 5})), + IsPrimary: true, + }, + { + Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 6})), + IsPrimary: false, + }, + }, + }, + }, + }, }, } + consumerPtr := &TestConsumer{} + fetcher := nodesubnet.NewIPFetcher(&TestClient{}, consumerPtr, 0, 0, logger.Log) + err := fetcher.ProcessInterfaces(interfaces) + checkErr(t, err, false) - clientPtr := &TestClient{} - fetcher := nodesubnet.NewIPFetcher(clientPtr, 0) - - for _, test := range getTests { - test := test - t.Run(test.name, func(t *testing.T) { // Do not parallelize, as we are using a shared client - fetcher.SetSecondaryIPQueryInterval(test.interval) - ctx, cancel := testContext(t) - defer cancel() - clientPtr.fetchCalled = false - _, err := fetcher.RefreshSecondaryIPsIfNeeded(ctx) - - if test.shouldCall { - if err != nil && errors.Is(err, nodesubnet.ErrRefreshSkipped) { - t.Error("refresh expected, but didn't happen") - } - - checkErr(t, err, false) - } else if err == nil || !errors.Is(err, nodesubnet.ErrRefreshSkipped) { - t.Error("refresh not expected, but happened") - } - }) + // 1 consume to be called + if consumerPtr.FetchConsumeCount() != 1 { + t.Error("Consume expected to be called, but not called") } -} -// testContext creates a context from the provided testing.T that will be -// canceled if the test suite is terminated. -func testContext(t *testing.T) (context.Context, context.CancelFunc) { - if deadline, ok := t.Deadline(); ok { - return context.WithDeadline(context.Background(), deadline) + // 1 consume to be called + if consumerPtr.FetchSecondaryIPCount() != 1 { + t.Error("Wrong number of secondary IPs ", consumerPtr.FetchSecondaryIPCount()) } - return context.WithCancel(context.Background()) } // checkErr is an assertion of the presence or absence of an error @@ -84,3 +111,7 @@ func checkErr(t *testing.T, err error, shouldErr bool) { t.Fatal("expected error but received none") } } + +func init() { + logger.InitLogger("testlogs", 0, 0, "./") +} diff --git a/nmagent/equality.go b/nmagent/equality.go new file mode 100644 index 0000000000..67381e9897 --- /dev/null +++ b/nmagent/equality.go @@ -0,0 +1,51 @@ +package nmagent + +// Equal compares two Interfaces objects for equality. +func (i Interfaces) Equal(other Interfaces) bool { + if len(i.Entries) != len(other.Entries) { + return false + } + for idx, entry := range i.Entries { + if !entry.Equal(other.Entries[idx]) { + return false + } + } + return true +} + +// Equal compares two Interface objects for equality. +func (i Interface) Equal(other Interface) bool { + if len(i.InterfaceSubnets) != len(other.InterfaceSubnets) { + return false + } + for idx, subnet := range i.InterfaceSubnets { + if !subnet.Equal(other.InterfaceSubnets[idx]) { + return false + } + } + if i.IsPrimary != other.IsPrimary || !i.MacAddress.Equal(other.MacAddress) { + return false + } + return true +} + +// Equal compares two InterfaceSubnet objects for equality. +func (s InterfaceSubnet) Equal(other InterfaceSubnet) bool { + if len(s.IPAddress) != len(other.IPAddress) { + return false + } + if s.Prefix != other.Prefix { + return false + } + for idx, ip := range s.IPAddress { + if !ip.Equal(other.IPAddress[idx]) { + return false + } + } + return true +} + +// Equal compares two NodeIP objects for equality. +func (ip NodeIP) Equal(other NodeIP) bool { + return ip.IsPrimary == other.IsPrimary && ip.Address.Equal(other.Address) +} diff --git a/nmagent/macaddress.go b/nmagent/macaddress.go index 97c5385162..fa81afc7ef 100644 --- a/nmagent/macaddress.go +++ b/nmagent/macaddress.go @@ -14,6 +14,18 @@ const ( type MACAddress net.HardwareAddr +func (h MACAddress) Equal(other MACAddress) bool { + if len(h) != len(other) { + return false + } + for i := range h { + if h[i] != other[i] { + return false + } + } + return true +} + func (h *MACAddress) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { var macStr string if err := d.DecodeElement(&macStr, &start); err != nil { diff --git a/refresh/equaler.go b/refresh/equaler.go new file mode 100644 index 0000000000..96a42f413e --- /dev/null +++ b/refresh/equaler.go @@ -0,0 +1,5 @@ +package refresh + +type equaler[T any] interface { + Equal(T) bool +} diff --git a/refresh/fetcher.go b/refresh/fetcher.go new file mode 100644 index 0000000000..a509e0dc8c --- /dev/null +++ b/refresh/fetcher.go @@ -0,0 +1,114 @@ +package refresh + +import ( + "context" + "time" +) + +const ( + DefaultMinInterval = 4 * time.Second + DefaultMaxInterval = 1024 * time.Second +) + +// Fetcher fetches data at regular intervals. The interval will vary within the range of minInterval and +// maxInterval. When no diff is observed after a fetch, the interval doubles (subject to the maximum interval). +// When a diff is observed, the interval resets to the minimum. The interval can be made unchanging by setting +// minInterval and maxInterval to the same desired value. + +type Fetcher[T equaler[T]] struct { + fetchFunc func(context.Context) (T, error) + cache T + minInterval time.Duration + maxInterval time.Duration + currentInterval time.Duration + ticker TickProvider + consumeFunc func(T) error + logger Logger +} + +// NewFetcher creates a new Fetcher. If minInterval is 0, it will default to 4 seconds. +func NewFetcher[T equaler[T]]( + fetchFunc func(context.Context) (T, error), + minInterval time.Duration, + maxInterval time.Duration, + consumeFunc func(T) error, + logger Logger, +) *Fetcher[T] { + if minInterval == 0 { + minInterval = DefaultMinInterval + } + + if maxInterval == 0 { + maxInterval = DefaultMaxInterval + } + + maxInterval = max(minInterval, maxInterval) + + return &Fetcher[T]{ + fetchFunc: fetchFunc, + minInterval: minInterval, + maxInterval: maxInterval, + currentInterval: minInterval, + consumeFunc: consumeFunc, + logger: logger, + } +} + +func (f *Fetcher[T]) Start(ctx context.Context) { + go func() { + // do an initial fetch + res, err := f.fetchFunc(ctx) + if err != nil { + f.logger.Printf("Error invoking fetch: %v", err) + } + + f.cache = res + if f.consumeFunc != nil { + if err := f.consumeFunc(res); err != nil { + f.logger.Errorf("Error consuming data: %v", err) + } + } + + if f.ticker == nil { + f.ticker = NewTimedTickProvider(f.currentInterval) + } + + defer f.ticker.Stop() + + for { + select { + case <-ctx.Done(): + f.logger.Printf("Fetcher stopped") + return + case <-f.ticker.C(): + result, err := f.fetchFunc(ctx) + if err != nil { + f.logger.Errorf("Error fetching data: %v", err) + } else { + if result.Equal(f.cache) { + f.updateFetchIntervalForNoObservedDiff() + f.logger.Printf("No diff observed in fetch, not invoking the consumer") + } else { + f.cache = result + f.updateFetchIntervalForObservedDiff() + if f.consumeFunc != nil { + if err := f.consumeFunc(result); err != nil { + f.logger.Errorf("Error consuming data: %v", err) + } + } + } + } + + f.ticker.Reset(f.currentInterval) + } + } + }() +} + +func (f *Fetcher[T]) updateFetchIntervalForNoObservedDiff() { + f.currentInterval = min(f.currentInterval*2, f.maxInterval) // nolint:gomnd // doubling logic +} + +func (f *Fetcher[T]) updateFetchIntervalForObservedDiff() { + f.currentInterval = f.minInterval +} diff --git a/refresh/fetcher_test.go b/refresh/fetcher_test.go new file mode 100644 index 0000000000..0e686a358e --- /dev/null +++ b/refresh/fetcher_test.go @@ -0,0 +1,161 @@ +package refresh_test + +import ( + "context" + "fmt" + "net/netip" + "sync" + "testing" + + "github.com/Azure/azure-container-networking/cns/logger" + "github.com/Azure/azure-container-networking/cns/nodesubnet" + "github.com/Azure/azure-container-networking/nmagent" + "github.com/Azure/azure-container-networking/refresh" +) + +// Mock client that simply tracks if refresh has been called +type TestClient struct { + refreshCount int + responses []nmagent.Interfaces + mu sync.Mutex +} + +// FetchRefreshCount atomically fetches the refresh count +func (c *TestClient) FetchRefreshCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.refreshCount +} + +// UpdateRefreshCount atomically updates the refresh count +func (c *TestClient) UpdateRefreshCount() { + c.mu.Lock() + defer c.mu.Unlock() + c.refreshCount++ +} + +// Mock refresh +func (c *TestClient) GetInterfaceIPInfo(_ context.Context) (nmagent.Interfaces, error) { + defer c.UpdateRefreshCount() + + if c.refreshCount >= len(c.responses) { + return c.responses[len(c.responses)-1], nil + } + + return c.responses[c.refreshCount], nil +} + +var _ nodesubnet.InterfaceRetriever = &TestClient{} + +// Mock client that simply consumes fetched IPs +type TestConsumer struct { + consumeCount int + mu sync.Mutex +} + +// FetchConsumeCount atomically fetches the consume count +func (c *TestConsumer) FetchConsumeCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.consumeCount +} + +// UpdateConsumeCount atomically updates the consume count +func (c *TestConsumer) UpdateConsumeCount() { + c.mu.Lock() + defer c.mu.Unlock() + c.consumeCount++ +} + +// Mock IP update +func (c *TestConsumer) ConsumeInterfaces(intfs nmagent.Interfaces) error { + fmt.Printf("Consumed interfaces: %v\n", intfs) + c.UpdateConsumeCount() + return nil +} + +func TestRefresh(t *testing.T) { + clientPtr := &TestClient{ + refreshCount: 0, + responses: []nmagent.Interfaces{ + { + Entries: []nmagent.Interface{ + { + MacAddress: nmagent.MACAddress{0x00, 0x0D, 0x3A, 0xF9, 0xDC, 0xA6}, + IsPrimary: true, + InterfaceSubnets: []nmagent.InterfaceSubnet{ + { + Prefix: "10.240.0.0/16", + IPAddress: []nmagent.NodeIP{ + { + Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 5})), + IsPrimary: true, + }, + { + Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 6})), + IsPrimary: false, + }, + }, + }, + }, + }, + }, + }, + { + Entries: []nmagent.Interface{ + { + MacAddress: nmagent.MACAddress{0x00, 0x0D, 0x3A, 0xF9, 0xDC, 0xA6}, + IsPrimary: true, + InterfaceSubnets: []nmagent.InterfaceSubnet{ + { + Prefix: "10.240.0.0/16", + IPAddress: []nmagent.NodeIP{ + { + Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 5})), + IsPrimary: true, + }, + }, + }, + }, + }, + }, + }, + }, + mu: sync.Mutex{}, + } + + consumerPtr := &TestConsumer{} + fetcher := refresh.NewFetcher[nmagent.Interfaces](clientPtr.GetInterfaceIPInfo, 0, 0, consumerPtr.ConsumeInterfaces, logger.Log) + ticker := refresh.NewMockTickProvider() + fetcher.SetTicker(ticker) + ctx, cancel := testContext(t) + defer cancel() + fetcher.Start(ctx) + ticker.Tick() // Trigger a refresh + ticker.Tick() // This tick will be read only after previous refresh is done + ticker.Tick() // This call will block until the prevous tick is read + + // At least 2 refreshes - one initial and one after the first tick should be done + if clientPtr.FetchRefreshCount() < 2 { + t.Error("Not enough refreshes") + } + + // Exactly 2 consumes - one initial and one after the first tick should be done (responses are different). + // Then no more, since the response is unchanged + if consumerPtr.FetchConsumeCount() != 2 { + t.Error("Exactly two consumes expected (for two different responses)") + } +} + +// testContext creates a context from the provided testing.T that will be +// canceled if the test suite is terminated. +func testContext(t *testing.T) (context.Context, context.CancelFunc) { + if deadline, ok := t.Deadline(); ok { + return context.WithDeadline(context.Background(), deadline) + } + return context.WithCancel(context.Background()) +} + +func init() { + logger.InitLogger("testlogs", 0, 0, "./") +} diff --git a/refresh/helper_for_fetcher_test.go b/refresh/helper_for_fetcher_test.go new file mode 100644 index 0000000000..fa6a6554eb --- /dev/null +++ b/refresh/helper_for_fetcher_test.go @@ -0,0 +1,5 @@ +package refresh + +func (f *Fetcher[T]) SetTicker(t TickProvider) { + f.ticker = t +} diff --git a/refresh/logger.go b/refresh/logger.go new file mode 100644 index 0000000000..d3a8ea66d0 --- /dev/null +++ b/refresh/logger.go @@ -0,0 +1,8 @@ +package refresh + +type Logger interface { + Debugf(format string, v ...interface{}) + Printf(format string, v ...interface{}) + Warnf(format string, v ...interface{}) + Errorf(format string, v ...interface{}) +} diff --git a/refresh/mocktickprovider.go b/refresh/mocktickprovider.go new file mode 100644 index 0000000000..34b4190b50 --- /dev/null +++ b/refresh/mocktickprovider.go @@ -0,0 +1,41 @@ +package refresh + +import "time" + +// MockTickProvider is a mock implementation of the TickProvider interface +type MockTickProvider struct { + tickChan chan time.Time + currentDuration time.Duration +} + +// NewMockTickProvider creates a new MockTickProvider +func NewMockTickProvider() *MockTickProvider { + return &MockTickProvider{ + tickChan: make(chan time.Time, 1), + } +} + +// C returns the channel on which ticks are delivered +func (m *MockTickProvider) C() <-chan time.Time { + return m.tickChan +} + +// Stop stops the ticker +func (m *MockTickProvider) Stop() { + close(m.tickChan) +} + +// Tick manually sends a tick to the channel +func (m *MockTickProvider) Tick() { + m.tickChan <- time.Now() +} + +func (m *MockTickProvider) Reset(d time.Duration) { + m.currentDuration = d +} + +func (m *MockTickProvider) GetCurrentDuration() time.Duration { + return m.currentDuration +} + +var _ TickProvider = &MockTickProvider{} diff --git a/refresh/refreshticker.go b/refresh/refreshticker.go new file mode 100644 index 0000000000..20ad268718 --- /dev/null +++ b/refresh/refreshticker.go @@ -0,0 +1,37 @@ +package refresh + +import "time" + +// TickProvider defines an interface for a type that provides a channel that ticks at a regular interval +type TickProvider interface { + Stop() + Reset(d time.Duration) + C() <-chan time.Time +} + +// TimedTickProvider wraps a time.Ticker to implement TickProvider +type TimedTickProvider struct { + ticker *time.Ticker +} + +var _ TickProvider = &TimedTickProvider{} + +// NewTimedTickProvider creates a new TimedTickProvider +func NewTimedTickProvider(d time.Duration) *TimedTickProvider { + return &TimedTickProvider{ticker: time.NewTicker(d)} +} + +// Stop stops the ticker +func (tw *TimedTickProvider) Stop() { + tw.ticker.Stop() +} + +// Reset resets the ticker with a new duration +func (tw *TimedTickProvider) Reset(d time.Duration) { + tw.ticker.Reset(d) +} + +// C returns the ticker's channel +func (tw *TimedTickProvider) C() <-chan time.Time { + return tw.ticker.C +}