Skip to content

Commit

Permalink
Merge pull request #211 use safe mutex instead usual mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
rekby authored Sep 14, 2023
2 parents 65eaa18 + 2b2facb commit 0d57fd7
Show file tree
Hide file tree
Showing 22 changed files with 563 additions and 205 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ require (
github.com/jonboulle/clockwork v0.4.0
github.com/letsencrypt/pebble/v2 v2.4.0
github.com/rekby/fastuuid v0.9.0
github.com/rekby/safemutex v0.2.0
golang.org/x/time v0.3.0
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ github.com/rekby/fastuuid v0.9.0 h1:iQk8V/AyqSrgQAtKRdqx/CVep+CaKwaSWeerw1yEP3Q=
github.com/rekby/fastuuid v0.9.0/go.mod h1:qP8Lh0BH2+4rNGVRDHmDpkvE/ZuLUhjmKpRWjx+WesY=
github.com/rekby/fixenv v0.3.1 h1:zOPocbQmcsxSIjiVu5U+9JAfeu6WeLN7a9ryZkGTGJY=
github.com/rekby/fixenv v0.3.1/go.mod h1:/b5LRc06BYJtslRtHKxsPWFT/ySpHV+rWvzTg+XWk4c=
github.com/rekby/safemutex v0.2.0 h1:iEfcPqsR3EApwWHwdHvp+srN9Wfna+IG8bSpN467Jmk=
github.com/rekby/safemutex v0.2.0/go.mod h1:6I/yJdmctX0RmxEp00RzYBJJXl3ona8PsBiIDqg0v+U=
github.com/rekby/zapcontext v0.0.4 h1:85600nHTteGCLcuOhGp/SzXHymm9QcCA5sn+MPKCodY=
github.com/rekby/zapcontext v0.0.4/go.mod h1:lTIxvHAwWXBZBPPfEvmAEXPbVEcTwd52VaASZWZWcxI=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
Expand Down
208 changes: 118 additions & 90 deletions internal/acme_client_manager/client_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/rekby/safemutex"
"net/http"
"sync"
"time"
Expand Down Expand Up @@ -38,18 +39,21 @@ type AcmeManager struct {
AgreeFunction func(tosurl string) bool
RenewAccountInterval time.Duration

ctx context.Context
ctxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelFunc
cache cache.Bytes
httpClient *http.Client

background sync.WaitGroup
mu safemutex.MutexWithPointers[acmeManagerSynced]
}

type acmeManagerSynced struct {
lastAccountIndex int
accounts []clientAccount
stateLoaded bool
closed bool
ctxAutorenewCompleted context.Context
cache cache.Bytes
httpClient *http.Client

background sync.WaitGroup
mu sync.Mutex
lastAccountIndex int
accounts []clientAccount
stateLoaded bool
closed bool
}

type clientAccount struct {
Expand All @@ -67,19 +71,22 @@ func New(ctx context.Context, cache cache.Bytes) *AcmeManager {
AgreeFunction: acme.AcceptTOS,
RenewAccountInterval: renewAccountInterval,
httpClient: http.DefaultClient,
lastAccountIndex: -1,
mu: safemutex.NewWithPointers(acmeManagerSynced{lastAccountIndex: -1}),
}
}

func (m *AcmeManager) Close() error {
logger := zc.L(m.ctx)
logger.Debug("Start close")
m.mu.Lock()
alreadyClosed := m.closed
ctxAutorenewCompleted := m.ctxAutorenewCompleted
m.closed = true
m.ctxCancel()
m.mu.Unlock()
var alreadyClosed bool
var ctxAutorenewCompleted context.Context
m.mu.Lock(func(value acmeManagerSynced) (newValue acmeManagerSynced) {
alreadyClosed = value.closed
ctxAutorenewCompleted = value.ctxAutorenewCompleted
value.closed = true
m.ctxCancel()
return value
})
logger.Debug("Set closed flag", zap.Any("autorenew_context", ctxAutorenewCompleted))

if alreadyClosed {
Expand All @@ -95,71 +102,90 @@ func (m *AcmeManager) Close() error {
return nil
}

func (m *AcmeManager) GetClient(ctx context.Context) (_ *acme.Client, disableFunc func(), err error) {
func (m *AcmeManager) GetClient(ctx context.Context) (resClient *acme.Client, disableFunc func(), err error) {
if ctx.Err() != nil {
return nil, nil, errors.New("acme manager context closed")
}

m.mu.Lock()
defer m.mu.Unlock()

if m.closed {
return nil, nil, xerrors.Errorf("GetClient: %w", errClosed)
fail := func(resErr error) {
resClient = nil
disableFunc = nil
err = resErr
}
good := func(c *acme.Client, f func()) {
resClient = c
disableFunc = f
err = nil
}

m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
if synced.closed {
fail(xerrors.Errorf("GetClient: %w", errClosed))
return synced
}

createDisableFunc := func(index int) func() {
return func() {
wasEnabled := m.disableAccountSelfSync(index)
if wasEnabled {
time.AfterFunc(disableDuration, func() {
m.accountEnableSelfSync(index)
})
createDisableFunc := func(index int) func() {
return func() {
wasEnabled := m.disableAccountSelfSync(index)
if wasEnabled {
time.AfterFunc(disableDuration, func() {
m.accountEnableSelfSync(index)
})
}
}
}
}

if !m.stateLoaded && m.cache != nil && !m.IgnoreCacheLoad {
err := m.loadFromCache(ctx)
if err != nil && err != cache.ErrCacheMiss {
return nil, nil, err
if !synced.stateLoaded && m.cache != nil && !m.IgnoreCacheLoad {
err := m.loadFromCacheLocked(ctx, &synced)
if err != nil && !errors.Is(err, cache.ErrCacheMiss) {
fail(err)
return synced
}
synced.stateLoaded = true
}
m.stateLoaded = true
}

if index, ok := m.nextEnabledClientIndex(); ok {
return m.accounts[index].client, createDisableFunc(index), nil
}
if index, ok := m.nextEnabledClientIndexLocked(&synced); ok {
good(synced.accounts[index].client, createDisableFunc(index))
return synced
}

acc, err := m.registerAccount(ctx)
m.accounts = append(m.accounts, acc)
acc, err := m.registerAccount(ctx)
synced.accounts = append(synced.accounts, acc)

m.background.Add(1)
// handlepanic: in accountRenewSelfSync
go func(index int) {
defer m.background.Done()
m.accountRenewSelfSync(index)
}(len(m.accounts) - 1)
m.background.Add(1)
// handlepanic: in accountRenewSelfSync
go func(index int) {
defer m.background.Done()
m.accountRenewSelfSync(index)
}(len(synced.accounts) - 1)

if err != nil {
return nil, nil, err
}
if err != nil {
fail(err)
return synced
}

if err = m.saveState(ctx); err != nil {
return nil, nil, err
}
if err = m.saveStateLocked(ctx, &synced); err != nil {
fail(err)
return synced
}

return acc.client, createDisableFunc(len(m.accounts) - 1), nil
good(acc.client, createDisableFunc(len(synced.accounts)-1))
return synced
})
return resClient, disableFunc, err
}

func (m *AcmeManager) accountRenewSelfSync(index int) {
logger := zc.L(m.ctx)
ctx, ctxCancel := context.WithCancel(m.ctx)
defer ctxCancel()

m.mu.Lock()
m.ctxAutorenewCompleted = ctx
acc := m.accounts[index]
m.mu.Unlock()
var acc clientAccount
m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
synced.ctxAutorenewCompleted = ctx
acc = synced.accounts[index]
return synced
})

if m.ctx.Err() != nil {
return
Expand All @@ -183,37 +209,39 @@ func (m *AcmeManager) accountRenewSelfSync(index int) {
newAccount = renewTos(m.ctx, acc.client, acc.account)
}()
acc.account = newAccount
m.mu.Lock()
m.accounts[index] = acc
m.mu.Unlock()
m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
synced.accounts[index] = acc
return synced
})
}
}
}

func (m *AcmeManager) disableAccountSelfSync(index int) (wasEnabled bool) {
m.mu.Lock()
defer m.mu.Unlock()

if m.accounts[index].enabled {
m.accounts[index].enabled = false
return true
}

return false
m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
if synced.accounts[index].enabled {
synced.accounts[index].enabled = false
wasEnabled = true
return synced
}
wasEnabled = false
return synced
})
return wasEnabled
}

func (m *AcmeManager) accountEnableSelfSync(index int) {
m.mu.Lock()
defer m.mu.Unlock()

m.accounts[index].enabled = true
m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
synced.accounts[index].enabled = true
return synced
})
}

func (m *AcmeManager) initClient() *acme.Client {
return &acme.Client{DirectoryURL: m.DirectoryURL, HTTPClient: m.httpClient}
}

func (m *AcmeManager) loadFromCache(ctx context.Context) (err error) {
func (m *AcmeManager) loadFromCacheLocked(ctx context.Context, synced *acmeManagerSynced) (err error) {
defer func() {
var effectiveError error
if err == cache.ErrCacheMiss {
Expand All @@ -239,7 +267,7 @@ func (m *AcmeManager) loadFromCache(ctx context.Context) (err error) {
return xerrors.Errorf("no accounts in state")
}

m.accounts = make([]clientAccount, 0, len(state.Accounts))
synced.accounts = make([]clientAccount, 0, len(state.Accounts))
for index, stateAccount := range state.Accounts {
client := m.initClient()
client.Key = stateAccount.PrivateKey
Expand All @@ -255,34 +283,34 @@ func (m *AcmeManager) loadFromCache(ctx context.Context) (err error) {
defer m.background.Done()
m.accountRenewSelfSync(index)
}(index)
m.accounts = append(m.accounts, acc)
synced.accounts = append(synced.accounts, acc)
}

return nil
}

func (m *AcmeManager) nextEnabledClientIndex() (int, bool) {
func (m *AcmeManager) nextEnabledClientIndexLocked(synced *acmeManagerSynced) (int, bool) {
switch {
case len(m.accounts) == 0:
case len(synced.accounts) == 0:
return 0, false
case len(m.accounts) == 1 && m.accounts[0].enabled:
case len(synced.accounts) == 1 && synced.accounts[0].enabled:
return 0, true
default:
// pass
}

startIndex := m.lastAccountIndex
startIndex := synced.lastAccountIndex
if startIndex < 0 {
startIndex = len(m.accounts) - 1
startIndex = len(synced.accounts) - 1
}
index := startIndex
for {
index++
if index >= len(m.accounts) {
if index >= len(synced.accounts) {
index = 0
}
if m.accounts[index].enabled {
m.lastAccountIndex = index
if synced.accounts[index].enabled {
synced.lastAccountIndex = index
return index, true
}
if index == startIndex {
Expand Down Expand Up @@ -310,11 +338,11 @@ func (m *AcmeManager) registerAccount(ctx context.Context) (clientAccount, error
return acc, nil
}

func (m *AcmeManager) saveState(ctx context.Context) error {
func (m *AcmeManager) saveStateLocked(ctx context.Context, synced *acmeManagerSynced) error {
var state acmeManagerState
state.Accounts = make([]acmeAccountState, 0, len(m.accounts))
state.Accounts = make([]acmeAccountState, 0, len(synced.accounts))

for _, acc := range m.accounts {
for _, acc := range synced.accounts {
state.Accounts = append(state.Accounts, acmeAccountState{PrivateKey: acc.client.Key.(*rsa.PrivateKey), AcmeAccount: acc.account})
}

Expand Down
24 changes: 13 additions & 11 deletions internal/acme_client_manager/client_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,19 @@ func TestClientManager_nextEnabledClientIndex(t *testing.T) {
e, _, flush := th.NewEnv(t)
defer flush()

m := AcmeManager{
lastAccountIndex: test.lastAccountIndex,
}

for _, enabled := range test.accountsEnabled {
m.accounts = append(m.accounts, clientAccount{enabled: enabled})
}

resIndex, resOk := m.nextEnabledClientIndex()
e.Cmp(resIndex, test.resIndex)
e.Cmp(resOk, test.resOk)
m := AcmeManager{}

m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
synced.lastAccountIndex = test.lastAccountIndex
for _, enabled := range test.accountsEnabled {
synced.accounts = append(synced.accounts, clientAccount{enabled: enabled})
}

resIndex, resOk := m.nextEnabledClientIndexLocked(&synced)
e.Cmp(resIndex, test.resIndex)
e.Cmp(resOk, test.resOk)
return synced
})
})
}
}
Expand Down
Loading

0 comments on commit 0d57fd7

Please sign in to comment.