diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index d08b7cf..eecebbe 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -13,10 +13,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

       - name: Setup Go
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v5
         with:
           go-version: "stable"

@@ -30,3 +30,8 @@ jobs:
         working-directory: v2/
         run: |
           make test
+
+      - name: Linters
+        working-directory: v2/
+        run: |
+          make linters
diff --git a/v2/Makefile b/v2/Makefile
index 1dbf7b3..f14670a 100644
--- a/v2/Makefile
+++ b/v2/Makefile
@@ -30,8 +30,8 @@ local-dynamodb:
 	wget -O local-dynamodb/latest.zip https://s3.us-west-2.amazonaws.com/dynamodb-local/dynamodb_local_latest.zip
 	(cd local-dynamodb; unzip latest.zip)

-test: linters local-dynamodb
+test: local-dynamodb
 	GOEXPERIMENT=loopvar go test -v

 test-race:
-	GOEXPERIMENT=loopvar go test -race -count=1000
\ No newline at end of file
+	GOEXPERIMENT=loopvar go test -race -count=1000
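Note on the v2/client.go change that follows: the new WithSessionMonitor documentation ends with the tuning rule safeTime = leaseDuration - (3 * heartbeatPeriod). A minimal sketch of the rule in use; electLeader and "leader" are illustrative only, and the durations must match the WithLeaseDuration and WithHeartbeatPeriod options the client was constructed with:

    // electLeader is a hypothetical caller; c is a *dynamolock.Client built
    // the same way as in the tests elsewhere in this diff.
    func electLeader(c *dynamolock.Client) error {
        leaseDuration := 15 * time.Second  // matches WithLeaseDuration
        heartbeatPeriod := 3 * time.Second // matches WithHeartbeatPeriod
        // Rule of thumb from the documentation below: leave roughly three
        // heartbeat attempts of margin so the callback can fire while the
        // lease is still recoverable.
        safeTime := leaseDuration - 3*heartbeatPeriod // 6s here
        _, err := c.AcquireLock("leader",
            dynamolock.WithSessionMonitor(safeTime, func() {
                // e.g. step down from leadership before the lock is lost
            }),
        )
        return err
    }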
diff --git a/v2/client.go b/v2/client.go
index 627a435..1382a53 100644
--- a/v2/client.go
+++ b/v2/client.go
@@ -24,6 +24,7 @@ import (
 	"fmt"
 	"io"
 	"log"
+	"runtime"
 	"sync"
 	"time"

@@ -276,6 +277,12 @@ func WithAdditionalAttributes(attr map[string]types.AttributeValue) AcquireLockOption {
 // Consider an example which uses this mechanism for leader election. One
 // way to make use of this SessionMonitor is to register a callback that
 // kills the instance in case the leader's lock enters the danger zone.
+//
+// The SessionMonitor will not trigger if the lock is already expired by the
+// time it is evaluated. Therefore, tune the lease duration, the heartbeat
+// period, and the safe time to reduce the likelihood that the lock is lost
+// right when the session monitor would be evaluated. A good rule of thumb is
+// to set safeTime to leaseDuration-(3*heartbeatPeriod).
 func WithSessionMonitor(safeTime time.Duration, callback func()) AcquireLockOption {
 	return func(opt *acquireLockOptions) {
 		opt.sessionMonitor = &sessionMonitor{
@@ -673,22 +680,50 @@ func randString() string {
 	}
 	return base32Encoder.EncodeToString(randomBytes)
 }
-
-func (c *Client) heartbeat(ctx context.Context) {
-	c.logger.Println(ctx, "starting heartbeats")
+func (c *Client) heartbeat(rootCtx context.Context) {
+	c.logger.Println(rootCtx, "heartbeats starting")
+	defer c.logger.Println(rootCtx, "heartbeats done")
 	tick := time.NewTicker(c.heartbeatPeriod)
 	defer tick.Stop()
-	for range tick.C {
+	for {
+		select {
+		case <-rootCtx.Done():
+			c.logger.Println(rootCtx, "client closed, stopping heartbeat")
+			return
+		case t := <-tick.C:
+			c.logger.Println(rootCtx, "heartbeat at:", t)
+		}
+		var (
+			wg        sync.WaitGroup
+			maxProcs  = runtime.GOMAXPROCS(0)
+			lockItems = make(chan *Lock, maxProcs)
+		)
+		c.logger.Println(rootCtx, "heartbeat concurrency level:", maxProcs)
+		for i := 0; i < maxProcs; i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				for lockItem := range lockItems {
+					c.heartbeatLock(rootCtx, lockItem)
+				}
+			}()
+		}
 		c.locks.Range(func(_ string, lockItem *Lock) bool {
-			if err := c.SendHeartbeat(lockItem); err != nil {
-				c.logger.Println(ctx, "error sending heartbeat to", lockItem.partitionKey, ":", err)
-			}
+			lockItems <- lockItem
 			return true
 		})
-		if ctx.Err() != nil {
-			c.logger.Println(ctx, "client closed, stopping heartbeat")
-			return
-		}
+		close(lockItems)
+		c.logger.Println(rootCtx, "all heartbeats are dispatched")
+		wg.Wait()
+		c.logger.Println(rootCtx, "all heartbeats are processed")
+	}
+}
+
+func (c *Client) heartbeatLock(rootCtx context.Context, lockItem *Lock) {
+	ctx, cancel := context.WithTimeout(rootCtx, c.heartbeatPeriod)
+	defer cancel()
+	if err := c.SendHeartbeatWithContext(ctx, lockItem); err != nil {
+		c.logger.Println(ctx, "error sending heartbeat to", lockItem.partitionKey, ":", err)
 	}
 }

@@ -1030,18 +1065,19 @@ func (c *Client) removeKillSessionMonitor(monitorName string) {
 	cancel()
 }

-func (c *Client) lockSessionMonitorChecker(ctx context.Context,
-	monitorName string, lock *Lock) {
+func (c *Client) lockSessionMonitorChecker(ctx context.Context, monitorName string, lock *Lock) {
 	go func() {
 		defer c.sessionMonitorCancellations.Delete(monitorName)
 		for {
 			lock.semaphore.Lock()
-			timeUntilDangerZone, err := lock.timeUntilDangerZoneEntered()
+			isExpired := lock.isExpired()
+			timeUntilDangerZone := lock.timeUntilDangerZoneEntered()
 			lock.semaphore.Unlock()
-			if err != nil {
-				c.logger.Println(ctx, "cannot run session monitor because", err)
+			if isExpired {
+				c.logger.Println(ctx, "lock expired", timeUntilDangerZone)
 				return
 			}
+			c.logger.Println(ctx, "lockSessionMonitorChecker", "monitorName:", monitorName, "timeUntilDangerZone:", timeUntilDangerZone, time.Now().Add(timeUntilDangerZone))
 			if timeUntilDangerZone <= 0 {
 				go lock.sessionMonitor.callback()
 				return
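Note on the heartbeat rewrite above: the serial Range-and-send loop becomes a bounded fan-out, with one worker per GOMAXPROCS fed through a channel, and heartbeatLock caps each send at one heartbeatPeriod via context.WithTimeout, so a single slow or stuck lock no longer delays every other lock's heartbeat. A self-contained sketch of the same pattern, with generic names rather than the package's types:

    package main

    import (
        "fmt"
        "runtime"
        "sync"
    )

    func main() {
        locks := []string{"lock-a", "lock-b", "lock-c"}
        var (
            wg       sync.WaitGroup
            maxProcs = runtime.GOMAXPROCS(0)
            queue    = make(chan string, maxProcs)
        )
        for i := 0; i < maxProcs; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                for name := range queue {
                    fmt.Println("heartbeating", name) // stand-in for heartbeatLock
                }
            }()
        }
        for _, name := range locks {
            queue <- name // dispatch; blocks only while all workers are busy
        }
        close(queue) // no more work: workers drain the channel and exit
        wg.Wait()    // mirrors the "all heartbeats are processed" log point
    }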
diff --git a/v2/client_heartbeat.go b/v2/client_heartbeat.go
index a8e6656..1135b1e 100644
--- a/v2/client_heartbeat.go
+++ b/v2/client_heartbeat.go
@@ -105,7 +105,9 @@ func (c *Client) SendHeartbeatWithContext(ctx context.Context, lockItem *Lock, o
 	}
 	targetRVN := c.generateRecordVersionNumber()
 	err := c.sendHeartbeat(ctx, sho, currentRVN, targetRVN)
-	if err != nil {
+	if errors.Is(err, ctx.Err()) {
+		return ctx.Err()
+	} else if err != nil {
 		err = c.retryHeartbeat(ctx, err, sho, currentRVN, targetRVN)
 		err = parseDynamoDBError(err, "already acquired lock, stopping heartbeats")
 		if errors.As(err, new(*LockNotGrantedError)) {
diff --git a/v2/client_session_monitor_test.go b/v2/client_session_monitor_test.go
index d2baad0..bad6374 100644
--- a/v2/client_session_monitor_test.go
+++ b/v2/client_session_monitor_test.go
@@ -17,7 +17,11 @@ limitations under the License.
 package dynamolock_test

 import (
+	"bytes"
+	"context"
 	"errors"
+	"fmt"
+	"log"
 	"sync"
 	"testing"
 	"time"
@@ -190,3 +194,119 @@ func TestSessionMonitorFullCycle(t *testing.T) {
 		t.Error("lockedItem should be already expired:", ok, err)
 	}
 }
+
+func TestSessionMonitorMissedCall(t *testing.T) {
+	t.Parallel()
+
+	cases := []struct {
+		leaseDuration   time.Duration
+		heartbeatPeriod time.Duration
+	}{
+		{6 * time.Second, 1 * time.Second},
+		{15 * time.Second, 1 * time.Second},
+		{15 * time.Second, 3 * time.Second},
+		{20 * time.Second, 5 * time.Second},
+	}
+	for _, tt := range cases {
+		tt := tt
+		safeZone := tt.leaseDuration - (3 * tt.heartbeatPeriod)
+		t.Run(fmt.Sprintf("%s/%s/%s", tt.leaseDuration, tt.heartbeatPeriod, safeZone), func(t *testing.T) {
+			t.Parallel()
+			lockName := randStr()
+			t.Log("lockName:", lockName)
+			cfg, proxyCloser := proxyConfig(t)
+			svc := dynamodb.NewFromConfig(cfg)
+			logger := &bufferedLogger{}
+			c, err := dynamolock.New(svc,
+				"locks",
+				dynamolock.WithLeaseDuration(tt.leaseDuration),
+				dynamolock.WithOwnerName("TestSessionMonitorMissedCall#1"),
+				dynamolock.WithHeartbeatPeriod(tt.heartbeatPeriod),
+				dynamolock.WithPartitionKeyName("key"),
+				dynamolock.WithLogger(logger),
+			)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			t.Log("ensuring table exists")
+			_, _ = c.CreateTable("locks",
+				dynamolock.WithProvisionedThroughput(&types.ProvisionedThroughput{
+					ReadCapacityUnits:  aws.Int64(5),
+					WriteCapacityUnits: aws.Int64(5),
+				}),
+				dynamolock.WithCustomPartitionKeyName("key"),
+			)
+
+			sessionMonitorWasTriggered := make(chan struct{})
+
+			data := []byte("some content a")
+			lockedItem, err := c.AcquireLock(lockName,
+				dynamolock.WithData(data),
+				dynamolock.ReplaceData(),
+				dynamolock.WithSessionMonitor(safeZone, func() {
+					close(sessionMonitorWasTriggered)
+				}),
+			)
+			if err != nil {
+				t.Fatal(err)
+			}
+			t.Log("lock acquired, closing proxy")
+			proxyCloser()
+			t.Log("proxy closed")
+
+			t.Log("waiting", tt.leaseDuration)
+			select {
+			case <-time.After(tt.leaseDuration):
+				t.Error("session monitor was not triggered")
+			case <-sessionMonitorWasTriggered:
+				t.Log("session monitor was triggered")
+			}
+			t.Log("isExpired", lockedItem.IsExpired())
+
+			t.Log(logger.String())
+		})
+	}
+}
+
+type bufferedLogger struct {
+	mu     sync.Mutex
+	buf    bytes.Buffer
+	logger *log.Logger
+}
+
+func (bl *bufferedLogger) String() string {
+	bl.mu.Lock()
+	defer bl.mu.Unlock()
+	return bl.buf.String()
+}
+
+func (bl *bufferedLogger) Println(a ...any) {
+	bl.mu.Lock()
+	defer bl.mu.Unlock()
+	if bl.logger == nil {
+		bl.logger = log.New(&bl.buf, "", 0)
+	}
+	bl.logger.Println(a...)
+}
+
+type bufferedContextLogger struct {
+	mu     sync.Mutex
+	buf    bytes.Buffer
+	logger *log.Logger
+}
+
+func (bl *bufferedContextLogger) String() string {
+	bl.mu.Lock()
+	defer bl.mu.Unlock()
+	return bl.buf.String()
+}
+
+func (bl *bufferedContextLogger) Println(_ context.Context, a ...any) {
+	bl.mu.Lock()
+	defer bl.mu.Unlock()
+	if bl.logger == nil {
+		bl.logger = log.New(&bl.buf, "", 0)
+	}
+	bl.logger.Println(a...)
+}
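Note on bufferedLogger and bufferedContextLogger above: they likely replace the t-backed test loggers because the client's heartbeat goroutines can outlive an individual test, and the testing package panics when t.Log is called after the test has returned. The buffers collect output under a mutex and the tests emit it once, via t.Log(logger.String()), inside t.Cleanup. A minimal sketch of the hazard being avoided (TestUnsafeLogging is hypothetical):

    func TestUnsafeLogging(t *testing.T) {
        go func() {
            time.Sleep(time.Second)
            // If the test already returned, this panics with a message like
            // "Log in goroutine after TestUnsafeLogging has completed".
            t.Log("late heartbeat log")
        }()
    }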
diff --git a/v2/client_sort_key_test.go b/v2/client_sort_key_test.go
index bb0aeba..f0fcca0 100644
--- a/v2/client_sort_key_test.go
+++ b/v2/client_sort_key_test.go
@@ -296,7 +296,7 @@ func TestSortKeyReadLockContentAfterDeleteOnRelease(t *testing.T) {
 		sortKeyTable,
 		dynamolock.WithLeaseDuration(3*time.Second),
 		dynamolock.WithHeartbeatPeriod(1*time.Second),
-		dynamolock.WithOwnerName("TestReadLockContentAfterDeleteOnRelease#1"),
+		dynamolock.WithOwnerName("TestSortKeyReadLockContentAfterDeleteOnRelease#1"),
 		dynamolock.WithPartitionKeyName("key"),
 		dynamolock.WithSortKey("sortkey", "sortvalue"),
 	)
@@ -328,7 +328,7 @@ func TestSortKeyReadLockContentAfterDeleteOnRelease(t *testing.T) {
 		sortKeyTable,
 		dynamolock.WithLeaseDuration(3*time.Second),
 		dynamolock.WithHeartbeatPeriod(1*time.Second),
-		dynamolock.WithOwnerName("TestReadLockContentAfterDeleteOnRelease#2"),
+		dynamolock.WithOwnerName("TestSortKeyReadLockContentAfterDeleteOnRelease#2"),
 		dynamolock.WithSortKey("sortkey", "sortvalue"),
 	)
 	if err != nil {
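Note on the v2/client_test.go change that follows: proxyConfig stands up a byte-copying TCP proxy in front of the local DynamoDB on localhost:8000 and returns a closer that severs the listener and every tracked connection, which lets a test simulate a network partition; the large RetryMaxAttempts keeps the SDK retrying instead of failing fast. A minimal sketch of a test built on it; TestLockUnderPartition is hypothetical, while proxyConfig and randStr are the helpers from this diff:

    func TestLockUnderPartition(t *testing.T) {
        t.Parallel()
        cfg, proxyCloser := proxyConfig(t)
        svc := dynamodb.NewFromConfig(cfg)
        c, err := dynamolock.New(svc, "locks",
            dynamolock.WithLeaseDuration(3*time.Second),
            dynamolock.WithHeartbeatPeriod(1*time.Second),
            dynamolock.WithPartitionKeyName("key"),
        )
        if err != nil {
            t.Fatal(err)
        }
        lockedItem, err := c.AcquireLock(randStr())
        if err != nil {
            t.Fatal(err)
        }
        proxyCloser()               // cut every connection: heartbeats stop reaching DynamoDB
        time.Sleep(4 * time.Second) // one full lease plus slack
        // TestSessionMonitorMissedCall logs rather than asserts here, since
        // the local expiry bookkeeping is timing-dependent.
        t.Log("isExpired:", lockedItem.IsExpired())
    }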
diff --git a/v2/client_test.go b/v2/client_test.go
index 1e7ebb1..0063eb2 100644
--- a/v2/client_test.go
+++ b/v2/client_test.go
@@ -22,7 +22,9 @@ import (
 	"errors"
 	"flag"
 	"fmt"
+	"io"
 	"log"
+	"math/rand"
 	"net"
 	"os"
 	"os/exec"
@@ -68,7 +70,8 @@ func TestMain(m *testing.M) {
 	os.Exit(exitCode)
 }

-func defaultConfig(_ *testing.T) aws.Config {
+func defaultConfig(t *testing.T) aws.Config {
+	t.Helper()
 	return aws.Config{
 		Region: "us-west-2",
 		EndpointResolverWithOptions: aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
@@ -83,20 +86,75 @@ func defaultConfig(_ *testing.T) aws.Config {
 	}
 }

+func proxyConfig(t *testing.T) (aws.Config, func()) {
+	t.Helper()
+	l, err := net.Listen("tcp4", "localhost:0")
+	if err != nil {
+		t.Fatal("cannot start proxy:", err)
+	}
+	var (
+		outboundConns  sync.Map
+		proxyCloseOnce = sync.OnceFunc(func() {
+			l.Close()
+			t.Log("proxy listener stopped")
+			outboundConns.Range(func(key, value any) bool {
+				t.Log("proxy connection closed", key)
+				conn := value.(net.Conn)
+				conn.Close()
+				return true
+			})
+			t.Log("proxy stopped")
+		})
+	)
+	t.Cleanup(proxyCloseOnce)
+	go func() {
+		for {
+			inboundConn, err := l.Accept()
+			if err != nil {
+				return
+			}
+			outboundConn, err := net.Dial("tcp4", "localhost:8000")
+			if err != nil {
+				return
+			}
+			outboundConns.Store(inboundConn.RemoteAddr().String(), outboundConn)
+			go func() { _, _ = io.Copy(inboundConn, outboundConn) }()
+			go func() { _, _ = io.Copy(outboundConn, inboundConn) }()
+		}
+	}()
+	return aws.Config{
+		Region: "us-west-2",
+		EndpointResolverWithOptions: aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
+			return aws.Endpoint{URL: "http://" + l.Addr().String() + "/"}, nil
+		}),
+		Credentials: credentials.StaticCredentialsProvider{
+			Value: aws.Credentials{
+				AccessKeyID:     "fakeMyKeyId",
+				SecretAccessKey: "fakeSecretAccessKey",
+			},
+		},
+		RetryMaxAttempts: 1_000_000,
+	}, proxyCloseOnce
+}
+
 func TestClientBasicFlow(t *testing.T) {
 	t.Parallel()
 	svc := dynamodb.NewFromConfig(defaultConfig(t))
+	logger := &bufferedContextLogger{}
 	c, err := dynamolock.New(svc,
 		"locks",
 		dynamolock.WithLeaseDuration(3*time.Second),
 		dynamolock.WithHeartbeatPeriod(1*time.Second),
 		dynamolock.WithOwnerName("TestClientBasicFlow#1"),
-		dynamolock.WithContextLogger(&testContextLogger{t: t}),
+		dynamolock.WithContextLogger(logger),
 		dynamolock.WithPartitionKeyName("key"),
 	)
 	if err != nil {
 		t.Fatal(err)
 	}
+	t.Cleanup(func() {
+		t.Log(logger.String())
+	})

 	t.Log("ensuring table exists")
 	_, _ = c.CreateTable("locks",
@@ -343,6 +401,8 @@ func TestReadLockContentAfterRelease(t *testing.T) {

 func TestReadLockContentAfterDeleteOnRelease(t *testing.T) {
 	t.Parallel()
+	lockName := randStr()
+
 	svc := dynamodb.NewFromConfig(defaultConfig(t))
 	c, err := dynamolock.New(svc,
 		"locks",
@@ -365,8 +425,8 @@ func TestReadLockContentAfterDeleteOnRelease(t *testing.T) {
 		dynamolock.WithCustomPartitionKeyName("key"),
 	)

-	data := []byte("some content for uhura")
-	lockedItem, err := c.AcquireLock("uhura",
+	data := []byte("some content for " + lockName)
+	lockedItem, err := c.AcquireLock(lockName,
 		dynamolock.WithData(data),
 		dynamolock.ReplaceData(),
 		dynamolock.WithDeleteLockOnRelease(),
@@ -391,7 +451,7 @@ func TestReadLockContentAfterDeleteOnRelease(t *testing.T) {
 		t.Fatal(err)
 	}

-	lockItemRead, err := c2.Get("uhura")
+	lockItemRead, err := c2.Get(lockName)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -734,17 +794,21 @@ func TestClientClose(t *testing.T) {

 func TestInvalidReleases(t *testing.T) {
 	t.Parallel()
 	svc := dynamodb.NewFromConfig(defaultConfig(t))
+	logger := &bufferedLogger{}
 	c, err := dynamolock.New(svc,
 		"locks",
 		dynamolock.WithLeaseDuration(3*time.Second),
 		dynamolock.WithHeartbeatPeriod(1*time.Second),
 		dynamolock.WithOwnerName("TestInvalidReleases#1"),
-		dynamolock.WithLogger(&testLogger{t: t}),
+		dynamolock.WithLogger(logger),
 		dynamolock.WithPartitionKeyName("key"),
 	)
 	if err != nil {
 		t.Fatal(err)
 	}
+	t.Cleanup(func() {
+		t.Log(logger.String())
+	})

 	t.Log("ensuring table exists")
 	_, _ = c.CreateTable("locks",
@@ -1111,3 +1175,14 @@ func TestTableTags(t *testing.T) {
 		t.Fatal("API request missed tags")
 	}
 }
+
+var chars = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+
+func randStr() string {
+	const length = 32
+	var b bytes.Buffer
+	for i := 0; i < length; i++ {
+		b.WriteByte(chars[rand.Intn(len(chars))])
+	}
+	return b.String()
+}
diff --git a/v2/lock.go b/v2/lock.go
index eb69af5..86c5b28 100644
--- a/v2/lock.go
+++ b/v2/lock.go
@@ -122,10 +122,13 @@ func (l *Lock) IsAlmostExpired() (bool, error) {
 	}
 	l.semaphore.Lock()
 	defer l.semaphore.Unlock()
-	t, err := l.timeUntilDangerZoneEntered()
-	if err != nil {
-		return false, err
+	if l.sessionMonitor == nil {
+		return false, ErrSessionMonitorNotSet
+	}
+	if l.isExpired() {
+		return false, ErrLockAlreadyReleased
 	}
+	t := l.timeUntilDangerZoneEntered()
 	return t <= 0, nil
 }

@@ -137,12 +140,6 @@ var (
 	ErrOwnerMismatched = errors.New("lock owner mismatched")
 )

-func (l *Lock) timeUntilDangerZoneEntered() (time.Duration, error) {
-	if l.sessionMonitor == nil {
-		return 0, ErrSessionMonitorNotSet
-	}
-	if l.isExpired() {
-		return 0, ErrLockAlreadyReleased
-	}
-	return l.sessionMonitor.timeUntilLeaseEntersDangerZone(l.lookupTime), nil
+func (l *Lock) timeUntilDangerZoneEntered() time.Duration {
+	return l.sessionMonitor.timeUntilLeaseEntersDangerZone(l.lookupTime)
 }
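Note on the v2/lock.go refactor above: timeUntilDangerZoneEntered becomes a pure duration computation, and the nil-session-monitor and already-expired checks move into its callers, so IsAlmostExpired still returns the same sentinel errors as before. A caller-side sketch of the preserved contract, assuming the usual errors/log imports; checkDangerZone is illustrative, not part of the package:

    func checkDangerZone(lockedItem *dynamolock.Lock) {
        almost, err := lockedItem.IsAlmostExpired()
        switch {
        case errors.Is(err, dynamolock.ErrSessionMonitorNotSet):
            // acquired without WithSessionMonitor; nothing to watch
        case errors.Is(err, dynamolock.ErrLockAlreadyReleased):
            // the lease lapsed before the check ran
        case err != nil:
            log.Println("cannot check danger zone:", err)
        case almost:
            log.Println("inside the danger zone: act before the lease lapses")
        }
    }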