Skip to content

Commit

Permalink
Lock certificates
Browse files Browse the repository at this point in the history
  • Loading branch information
rekby authored May 1, 2019
2 parents 985bb96 + 4970c9a commit fa915b2
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 42 deletions.
8 changes: 4 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ before_deploy:
- git config --local user.name "$GIT_NAME"
- git config --local user.email "$GIT_EMAIL"
- export TRAVIS_TAG="$TAG_PREFIX.$TRAVIS_BUILD_NUMBER"
- BUILD_TIME=$(TZ=UTC date --rfc-3339=seconds)
- BUILD_TIME=$(TZ=UTC date --rfc-3339=seconds | cut -d '+' -f 1 | tr -d - | tr ' ' '-' | tr -d :)
- git tag $TRAVIS_TAG
- go get -mod= github.com/mitchellh/gox
- mkdir -p output
- OS_ARCH_BUILDS="darwin/amd64 linux/386 linux/amd64 linux/arm freebsd/386 freebsd/amd64 freebsd/arm windows/386 windows/amd64"
- gox --mod=vendor -osarch "$OS_ARCH_BUILDS" --ldflags "-X \"main.VERSION=$TRAVIS_TAG commit $TRAVIS_COMMIT builded '$BUILD_TIME' buildNumber $TRAVIS_BUILD_NUMBER\"" --output="output/lets-proxy_{{.OS}}_{{.Arch}}" -verbose --rebuild ./cmd/
- gox --mod=vendor -osarch "$OS_ARCH_BUILDS" --ldflags "-X \"main.VERSION=$TRAVIS_TAG+build-$TRAVIS_BUILD_NUMBER, Build time $BUILD_TIME, commit $TRAVIS_COMMIT\"" --output="output/lets-proxy_{{.OS}}_{{.Arch}}" -verbose --rebuild ./cmd/
- bash tests/make_archives.sh

deploy:
skip_cleanup: true
provider: releases
draft: true
tags: true
on:
repo: rekby/lets-proxy2
branch: master
tags: true
api_key: $GITHUB_TOKEN
file_glob: true
file: output/*
Expand Down
18 changes: 17 additions & 1 deletion internal/cert_manager/cert-state.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type certState struct {
issueContext context.Context // nil if no issue process now
issueContextCancel func()
cert *tls.Certificate
locked bool
lastError error
}

Expand Down Expand Up @@ -102,10 +103,25 @@ func (s *certState) Cert() (cert *tls.Certificate, lastError error) {
return cert, lastError
}

func (s *certState) CertSet(ctx context.Context, cert *tls.Certificate) {
func (s *certState) CertSet(ctx context.Context, locked bool, cert *tls.Certificate) {
zc.L(ctx).Debug("Store certificate in local state", log.Cert(cert))

s.mu.Lock()
s.cert = cert
s.locked = locked
s.lastError = nil
s.mu.Unlock()
}

func (s *certState) SetLocked() {
s.mu.Lock()
s.locked = true
s.mu.Unlock()
}

func (s *certState) GetLocked() bool {
s.mu.RLock()
defer s.mu.RUnlock()

return s.locked
}
10 changes: 7 additions & 3 deletions internal/cert_manager/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func flatByteSlices(in [][]byte) []byte {
}

// Return valid parced certificate or error
func validCertDer(domains []DomainName, der [][]byte, key crypto.PrivateKey, now time.Time) (cert *tls.Certificate, err error) {
func validCertDer(domains []DomainName, der [][]byte, key crypto.PrivateKey, locked bool, now time.Time) (cert *tls.Certificate, err error) {
// parse public part(s)
x509Cert, err := x509.ParseCertificates(flatByteSlices(der))
if err != nil || len(x509Cert) == 0 {
Expand All @@ -68,10 +68,10 @@ func validCertDer(domains []DomainName, der [][]byte, key crypto.PrivateKey, now
Leaf: leaf,
}

return validCertTLS(cert, domains, now)
return validCertTLS(cert, domains, locked, now)
}

func validCertTLS(cert *tls.Certificate, domains []DomainName, now time.Time) (validCert *tls.Certificate, err error) {
func validCertTLS(cert *tls.Certificate, domains []DomainName, locked bool, now time.Time) (validCert *tls.Certificate, err error) {
if cert == nil {
return nil, errors.New("certificate is nil")
}
Expand All @@ -88,6 +88,10 @@ func validCertTLS(cert *tls.Certificate, domains []DomainName, now time.Time) (v
return nil, errors.New("certificate has no public key")
}

if locked {
return cert, nil
}

// ensure the leaf corresponds to the private key and matches the certKey type
switch pub := cert.Leaf.PublicKey.(type) {
case *rsa.PublicKey:
Expand Down
42 changes: 34 additions & 8 deletions internal/cert_manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (resultCert *tls.Ce
if cert != nil {
logger.Debug("Got certificate from local state", log.Cert(cert))

// TODO: Disable check for locked certificates https://github.com/rekby/lets-proxy2/issues/48
cert, err = validCertDer([]DomainName{needDomain}, cert.Certificate, cert.PrivateKey, now)
cert, err = validCertDer([]DomainName{needDomain}, cert.Certificate, cert.PrivateKey, certState.GetLocked(), now)
logger.Debug("Validate certificate from local state", zap.Error(err))
if err == nil {
return cert, nil
Expand All @@ -144,24 +143,29 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (resultCert *tls.Ce
logger.Debug("Can't get certificate from local state", zap.Error(err))
}

locked, err := isCertLocked(ctx, m.Cache, certName)
log.DebugDPanic(logger, err, "Check if certificate locked")

cert, err = getCertificate(ctx, m.Cache, certName, keyRSA)
if err == nil {
logger.Debug("Certificate loaded from cache")

// TODO: Disable check for locked certificates https://github.com/rekby/lets-proxy2/issues/48
cert, err = validCertDer([]DomainName{needDomain}, cert.Certificate, cert.PrivateKey, now)
cert, err = validCertDer([]DomainName{needDomain}, cert.Certificate, cert.PrivateKey, locked, now)
logger.Debug("Check if certificate ok", zap.Error(err))
if err == nil {
certState.CertSet(ctx, cert)
certState.CertSet(ctx, locked, cert)
return cert, nil
}
}

if locked {
return nil, errHaveNoCert
}

// TODO: check domain
certIssueContext, cancelFunc := context.WithTimeout(ctx, m.CertificateIssueTimeout)
defer cancelFunc()

// TODO: receive cert for domain and subdomains same time
res, err := m.createCertificateForDomains(certIssueContext, certName, domainNamesFromCertificateName(certName),
needDomain)
if err == nil {
Expand Down Expand Up @@ -376,7 +380,7 @@ func (m *Manager) issueCertificate(ctx context.Context, certName certNameType, d
return nil, err
}

cert, err := validCertDer(domains, der, key, time.Now())
cert, err := validCertDer(domains, der, key, false, time.Now())
log.DebugDPanic(logger, err, "Check certificate is valid")
if err == nil {
storeCertificate(ctx, m.Cache, certName, cert)
Expand Down Expand Up @@ -524,6 +528,11 @@ func storeCertificate(ctx context.Context, cache cache.Cache, certName certNameT
return
}

locked, _ := isCertLocked(ctx, cache, certName)
if locked {
logger.Panic("Logical error - try to save to locked certificate")
}

var keyType keyType

var certBuf bytes.Buffer
Expand Down Expand Up @@ -601,7 +610,11 @@ func getCertificate(ctx context.Context, cache cache.Cache, certName certNameTyp
return nil, err
}
}
return validCertTLS(&cert2, nil, time.Now())
locked, err := isCertLocked(ctx, cache, certName)
if err != nil {
return nil, err
}
return validCertTLS(&cert2, nil, locked, time.Now())
}

func getCertificateKeyBytes(ctx context.Context, cache cache.Cache, certName certNameType, keyType keyType) ([]byte, error) {
Expand Down Expand Up @@ -674,3 +687,16 @@ func isNeedRenew(cert *tls.Certificate, now time.Time) bool {
}
return cert.Leaf.NotAfter.Add(-time.Hour * 24 * 30).Before(now)
}

func isCertLocked(ctx context.Context, storage cache.Cache, certName certNameType) (bool, error) {
lockName := certName.String() + ".lock"
_, err := storage.Get(ctx, lockName)
switch err {
case cache.ErrCacheMiss:
return false, nil
case nil:
return true, nil
default:
return false, err
}
}
47 changes: 32 additions & 15 deletions internal/cert_manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,26 @@ func (c contextConnection) GetContext() context.Context {

//nolint:gochecknoinits
func init() {
logger, err := zap.NewDevelopment()
if err != nil {
panic(err)
}
zc.SetDefaultLogger(logger)
zc.SetDefaultLogger(zap.NewNop())
}

func TestManager_GetCertificateTls(t *testing.T) {
logger, err := zap.NewDevelopment()
if err != nil {
t.Fatal(err)
}

ctx, flush := th.TestContext()
defer flush()

logger := zc.L(ctx)

mc := minimock.NewController(t)
defer mc.Finish()

cacheMock := NewCacheMock(mc)
cacheMock.GetMock.Set(func(ctx context.Context, key string) (ba1 []byte, err error) {
zc.L(ctx).Debug("Cache mock get", zap.String("key", key))

if key == "locked.ru.lock" {
return []byte{}, nil
}

return nil, cache.ErrCacheMiss
})
cacheMock.PutMock.Set(func(ctx context.Context, key string, data []byte) (err error) {
Expand Down Expand Up @@ -137,6 +135,14 @@ func TestManager_GetCertificateTls(t *testing.T) {
}
})

t.Run("Locked", func(t *testing.T) {
domain := "locked.ru"

cert, err := manager.GetCertificate(&tls.ClientHelloInfo{ServerName: domain, Conn: contextConnection{Context: ctx}})
td.CmpError(t, err)
td.CmpNil(t, cert)
})

t.Run("punycode-domain", func(t *testing.T) {
domain := "xn--80adjurfhd.xn--p1ai" // проверка.рф

Expand Down Expand Up @@ -216,20 +222,22 @@ func TestManager_GetCertificateTls(t *testing.T) {
}

func TestManager_GetCertificateHttp01(t *testing.T) {
logger, err := zap.NewDevelopment()
if err != nil {
t.Fatal(err)
}

ctx, flush := th.TestContext()
defer flush()

logger := zc.L(ctx)

mc := minimock.NewController(t)
defer mc.Finish()

cacheMock := NewCacheMock(mc)
cacheMock.GetMock.Set(func(ctx context.Context, key string) (ba1 []byte, err error) {
zc.L(ctx).Debug("Cache mock get", zap.String("key", key))

if key == "locked.ru.lock" {
return []byte{}, nil
}

return nil, cache.ErrCacheMiss
})
cacheMock.PutMock.Set(func(ctx context.Context, key string, data []byte) (err error) {
Expand Down Expand Up @@ -293,6 +301,14 @@ func TestManager_GetCertificateHttp01(t *testing.T) {
}
})

t.Run("Locked", func(t *testing.T) {
domain := "locked.ru"

cert, err := manager.GetCertificate(&tls.ClientHelloInfo{ServerName: domain, Conn: contextConnection{Context: ctx}})
td.CmpError(t, err)
td.CmpNil(t, cert)
})

t.Run("punycode-domain", func(t *testing.T) {
domain := "xn--80adjurfhd.xn--p1ai" // проверка.рф

Expand Down Expand Up @@ -419,6 +435,7 @@ func TestStoreCertificate(t *testing.T) {
fmt.Println(string(data))
return nil
})
cacheMock.GetMock.Return(nil, cache.ErrCacheMiss)

storeCertificate(ctx, cacheMock, "asd", cert)
}
Expand Down
12 changes: 1 addition & 11 deletions internal/th/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,16 @@ package th

import (
"context"
"fmt"
"time"

zc "github.com/rekby/zapcontext"

"go.uber.org/zap"
)

func TestContext() (ctx context.Context, flush func()) {
logger, err := zap.NewDevelopment()
if err != nil {
fmt.Print()
panic(err)
}

ctx, cancel := context.WithCancel(zc.WithLogger(context.Background(), logger))
ctx, cancel := context.WithCancel(zc.WithLogger(context.Background(), zap.NewNop()))
flush = func() {
cancel()
time.Sleep(time.Millisecond)
_ = logger.Sync()
}
return ctx, flush
}
Expand Down

0 comments on commit fa915b2

Please sign in to comment.