Skip to content

Commit

Permalink
Watch TLS cert file changes to update certs when needed. (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
wi1dcard authored Aug 23, 2024
1 parent cb0ae8e commit 48f9e55
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 17 deletions.
41 changes: 33 additions & 8 deletions fingerproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/wi1dcard/fingerproxy/pkg/certwatcher"
"github.com/wi1dcard/fingerproxy/pkg/debug"
"github.com/wi1dcard/fingerproxy/pkg/fingerprint"
"github.com/wi1dcard/fingerproxy/pkg/proxyserver"
Expand All @@ -36,6 +37,7 @@ var (
PrometheusLog = log.New(os.Stderr, "[metrics] ", logFlags)
ReverseProxyLog = log.New(os.Stderr, "[reverseproxy] ", logFlags)
FingerprintLog = log.New(os.Stderr, "[fingerprint] ", logFlags)
CertWatcherLog = log.New(os.Stderr, "[certwatcher] ", logFlags)
DefaultLog = log.New(os.Stderr, "[fingerproxy] ", logFlags)

// The Prometheus metric registry used by fingerproxy
Expand Down Expand Up @@ -90,8 +92,7 @@ func defaultReverseProxyHTTPHandler(forwardTo *url.URL, headerInjectors []revers
return handler
}

func defaultProxyServer(handler http.Handler, tlsConfig *tls.Config) *proxyserver.Server {
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
func defaultProxyServer(ctx context.Context, handler http.Handler, tlsConfig *tls.Config) *proxyserver.Server {
svr := proxyserver.NewServer(ctx, handler, tlsConfig)

svr.VerboseLogs = *flagVerboseLogs
Expand All @@ -108,6 +109,25 @@ func defaultProxyServer(handler http.Handler, tlsConfig *tls.Config) *proxyserve
return svr
}

func initCertWatcher() *certwatcher.CertWatcher {
certwatcher.Logger = CertWatcherLog
certwatcher.VerboseLogs = *flagVerboseLogs
cw, err := certwatcher.New(*flagCertFilename, *flagKeyFilename)
if err != nil {
DefaultLog.Fatalf(`invalid cert filename "%s" or certkey filename "%s": %s`, *flagCertFilename, *flagKeyFilename, err)
}
return cw
}

func defaultTLSConfig(cw *certwatcher.CertWatcher) *tls.Config {
return &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
GetCertificate: cw.GetCertificate,
}
}

func initFingerprint() {
fingerprint.Logger = FingerprintLog
fingerprint.VerboseLogs = *flagVerboseLogs
Expand All @@ -124,20 +144,25 @@ func Run() {
// fingerprint package
initFingerprint()

// tls cert watcher
cw := initCertWatcher()

// signal cancels context
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)

// main TLS server
server := defaultProxyServer(
ctx,
defaultReverseProxyHTTPHandler(
parseForwardURL(),
GetHeaderInjectors(),
),
&tls.Config{
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{parseTLSCerts()},
},
defaultTLSConfig(cw),
)

// start cert watcher
go cw.Start(ctx)

// metrics server
PrometheusLog.Printf("server listening on %s", *flagMetricsListenAddr)
go http.ListenAndServe(
Expand Down
9 changes: 0 additions & 9 deletions flags.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package fingerproxy

import (
"crypto/tls"
"flag"
"fmt"
"net/url"
Expand Down Expand Up @@ -141,14 +140,6 @@ func parseForwardURL() *url.URL {
return forwardURL
}

func parseTLSCerts() tls.Certificate {
tlsCert, err := tls.LoadX509KeyPair(*flagCertFilename, *flagKeyFilename)
if err != nil {
DefaultLog.Fatalf(`invalid cert filename "%s" or certkey filename "%s": %s`, *flagCertFilename, *flagKeyFilename, err)
}
return tlsCert
}

func parseDurationMetricBuckets() []float64 {
bucketStrings := strings.Split(*flagDurationMetricBuckets, ",")
buckets := []float64{}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cloudflare/circl v1.3.7 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/pprof v0.0.0-20231212022811-ec68065c825e // indirect
github.com/klauspost/compress v1.17.4 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dreadl0ck/tlsx v1.0.1-google-gopacket h1:/P3y+CGRiCQbW0nZU2jWkEwKfXLkpEgHNhbbqlnrTTM=
github.com/dreadl0ck/tlsx v1.0.1-google-gopacket/go.mod h1:amAb73WEEgPHWniMfwro6UpN6St3e5ypgq2tXM89IOo=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY=
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
Expand Down
179 changes: 179 additions & 0 deletions pkg/certwatcher/certwatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package certwatcher

import (
"context"
"crypto/tls"
"log"
"sync"

"github.com/fsnotify/fsnotify"
)

var (
VerboseLogs bool
Logger *log.Logger
)

func logf(format string, args ...any) {
if Logger != nil {
Logger.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}

func vlogf(format string, args ...any) {
if VerboseLogs {
logf(format, args...)
}
}

// CertWatcher watches certificate and key files for changes. When either file
// changes, it reads and parses both and calls an optional callback with the new
// certificate.
type CertWatcher struct {
sync.RWMutex

currentCert *tls.Certificate
watcher *fsnotify.Watcher

certPath string
keyPath string
}

// New returns a new CertWatcher watching the given certificate and key.
func New(certPath, keyPath string) (*CertWatcher, error) {
var err error

cw := &CertWatcher{
certPath: certPath,
keyPath: keyPath,
}

// Initial read of certificate and key.
if err := cw.ReadCertificate(); err != nil {
return nil, err
}

cw.watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, err
}

return cw, nil
}

// GetCertificate fetches the currently loaded certificate, which may be nil.
func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
cw.RLock()
defer cw.RUnlock()
return cw.currentCert, nil
}

// Start starts the watch on the certificate and key files.
func (cw *CertWatcher) Start(ctx context.Context) error {
files := []string{cw.certPath, cw.keyPath}

for _, f := range files {
if err := cw.watcher.Add(f); err != nil {
logf("error watching file: %s", err)
return err
}
}

go cw.Watch()

// Block until the context is done.
<-ctx.Done()

return cw.watcher.Close()
}

// Watch reads events from the watcher's channel and reacts to changes.
func (cw *CertWatcher) Watch() {
for {
select {
case event, ok := <-cw.watcher.Events:
// Channel is closed.
if !ok {
return
}

cw.handleEvent(event)

case err, ok := <-cw.watcher.Errors:
// Channel is closed.
if !ok {
return
}

logf("certificate watch error: %s", err)
}
}
}

// ReadCertificate reads the certificate and key files from disk, parses them,
// and updates the current certificate on the watcher. If a callback is set, it
// is invoked with the new certificate.
func (cw *CertWatcher) ReadCertificate() error {
cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath)
if err != nil {
return err
}

cw.Lock()
cw.currentCert = &cert
cw.Unlock()

vlogf("updated current TLS certificate")

return nil
}

func (cw *CertWatcher) handleEvent(event fsnotify.Event) {
// Only care about events which may modify the contents of the file.
if !(isWrite(event) || isRemove(event) || isCreate(event)) {
return
}

vlogf("certificate event: %s", event)

// If the file was removed, re-add the watch.
if isRemove(event) {
if err := cw.watcher.Add(event.Name); err != nil {
logf("error re-watching file: %s", err)
}
}

if err := cw.ReadCertificate(); err != nil {
logf("error re-reading certificate: %s", err)
}
}

func isWrite(event fsnotify.Event) bool {
return event.Op&fsnotify.Write == fsnotify.Write
}

func isCreate(event fsnotify.Event) bool {
return event.Op&fsnotify.Create == fsnotify.Create
}

func isRemove(event fsnotify.Event) bool {
return event.Op&fsnotify.Remove == fsnotify.Remove
}

0 comments on commit 48f9e55

Please sign in to comment.