From b1dd8efc9c675234524251672a5fa6b2433b7d3d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 20 Dec 2023 00:14:44 +0200 Subject: [PATCH] Require secret to reconnect as provider (#1) * Require secret to reconnect as provider * Disconnect websocket on error --------- Co-authored-by: Brad Murray --- cmd/registration_relay/main.go | 11 ++++++ internal/api/api.go | 4 +- internal/api/routes.go | 2 +- internal/config/config.go | 1 + internal/provider/provider.go | 71 ++++++++++++++++++++++------------ internal/provider/types.go | 7 +++- 6 files changed, 68 insertions(+), 28 deletions(-) diff --git a/cmd/registration_relay/main.go b/cmd/registration_relay/main.go index 2e9d742..338d4af 100644 --- a/cmd/registration_relay/main.go +++ b/cmd/registration_relay/main.go @@ -1,6 +1,7 @@ package main import ( + "encoding/base64" "flag" "os" "os/signal" @@ -28,6 +29,11 @@ func main() { flagenv.StringEnvWithDefault("REGISTRATION_RELAY_LISTEN", ":8000"), "Listen address", ) + secret := flag.String( + "secret", + flagenv.StringEnvWithDefault("REGISTRATION_RELAY_SECRET", ""), + "Secret (32 bytes encoded as base64)", + ) metricsListenAddr := flag.String( "metricsListen", flagenv.StringEnvWithDefault("REGISTRATION_RELAY_METRICS_LISTEN", ":5000"), @@ -54,6 +60,11 @@ func main() { cfg := config.Config{} cfg.API.Listen = *listenAddr + var err error + cfg.Secret, err = base64.StdEncoding.DecodeString(*secret) + if err != nil || len(cfg.Secret) != 32 { + log.Fatal().Err(err).Int("secret_len", len(cfg.Secret)).Msg("Invalid secret") + } cfg.API.ValidateAuthURL = *validateAuthURL log.Info().Str("commit", Commit).Str("build_time", BuildTime).Msg("registration-relay starting") diff --git a/internal/api/api.go b/internal/api/api.go index a556e2c..ae07f65 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -21,6 +21,7 @@ import ( type api struct { log zerolog.Logger server *http.Server + secret []byte } func NewAPI(cfg config.Config) *api { @@ -29,7 +30,8 @@ func NewAPI(cfg config.Config) *api { Logger() api := api{ - log: logger, + log: logger, + secret: cfg.Secret, } r := chi.NewRouter() diff --git a/internal/api/routes.go b/internal/api/routes.go index 13dd7ad..308e71a 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -47,7 +47,7 @@ func (a *api) providerWebsocket(w http.ResponseWriter, r *http.Request) { } defer conn.Close() - provider := provider.NewProvider(conn) + provider := provider.NewProvider(conn, a.secret) provider.WebsocketLoop() a.log.Info().Msg("Websocket connection closed") diff --git a/internal/config/config.go b/internal/config/config.go index 46cfa56..ed34310 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,7 @@ package config type Config struct { Version string + Secret []byte API struct { Listen string ValidateAuthURL string diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 8d95ccd..11d1fe1 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -1,7 +1,11 @@ package provider import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" "encoding/json" + "fmt" "sync" "github.com/gorilla/websocket" @@ -31,27 +35,37 @@ func GetProvider(code string) (*provider, bool) { return p, exists } -func RegisterProvider(code string, provider *provider) (string, error) { +func calculateSecret(globalSecret []byte, code string) string { + h := hmac.New(sha256.New, globalSecret) + h.Write([]byte(code)) + return base64.RawStdEncoding.EncodeToString(h.Sum(nil)) +} + +func RegisterProvider(data registerCommandData, provider *provider) (*registerCommandData, error) { codeToProviderLock.Lock() defer codeToProviderLock.Unlock() - if existing, exists := codeToProvider[code]; exists { - existing.log.Warn(). - Str("code", code). - Msg("New provider with same code registering, exiting websocket") - existing.ws.Close() - } - - if code == "" { + if data.Code == "" { var err error - code, err = util.GenerateProviderCode() + data.Code, err = util.GenerateProviderCode() if err != nil { - return "", err + return nil, err + } + data.Secret = calculateSecret(provider.globalSecret, data.Code) + } else { + if calculateSecret(provider.globalSecret, data.Code) != data.Secret { + return nil, fmt.Errorf("invalid secret") + } + if existing, exists := codeToProvider[data.Code]; exists { + existing.log.Warn(). + Str("code", data.Code). + Msg("New provider with same code registering, exiting websocket") + existing.ws.Close() } } - codeToProvider[code] = provider - return code, nil + codeToProvider[data.Code] = provider + return &data, nil } func UnregisterProvider(key string) { @@ -70,18 +84,21 @@ type provider struct { ws *websocket.Conn resultsCh chan json.RawMessage reqID int + + globalSecret []byte } -func NewProvider(ws *websocket.Conn) *provider { +func NewProvider(ws *websocket.Conn, secret []byte) *provider { logger := log.With(). Str("component", "provider"). Logger() return &provider{ - log: logger, - ws: ws, - resultsCh: make(chan json.RawMessage), - reqID: 1, + log: logger, + ws: ws, + resultsCh: make(chan json.RawMessage), + reqID: 1, + globalSecret: secret, } } @@ -91,6 +108,7 @@ func (p *provider) WebsocketLoop() { registerCode := "" +Loop: for { _, message, err := p.ws.ReadMessage() if err != nil { @@ -110,21 +128,24 @@ func (p *provider) WebsocketLoop() { var request registerCommandData if err := json.Unmarshal(rawCommand.Data, &request); err != nil { p.log.Err(err).Msg("Failed to decode register request") - break + break Loop } - registerCode, err = RegisterProvider(request.Code, p) + response, err := RegisterProvider(request, p) if err != nil { p.log.Err(err).Msg("Failed to register provider") - break + buf, err := json.Marshal(RawCommand[errorData]{Command: "response", Data: errorData{"invalid token"}, ReqID: rawCommand.ReqID}) + if err == nil { + p.ws.WriteMessage(websocket.TextMessage, buf) + } + break Loop } p.log.Debug().Msg("Registered provider") // Send back register response before setting the flag (ws is single writer) - response := registerCommandData{registerCode} - buf, err := json.Marshal(RawCommand[registerCommandData]{Command: "response", Data: response, ReqID: rawCommand.ReqID}) + buf, err := json.Marshal(RawCommand[registerCommandData]{Command: "response", Data: *response, ReqID: rawCommand.ReqID}) if err != nil { p.log.Err(err).Msg("Failed to encode register response") - break + break Loop } p.ws.WriteMessage(websocket.TextMessage, buf) @@ -134,7 +155,7 @@ func (p *provider) WebsocketLoop() { buf, err := json.Marshal(RawCommand[struct{}]{Command: "pong", ReqID: rawCommand.ReqID}) if err != nil { p.log.Err(err).Msg("Failed to encode ping response") - break + break Loop } p.ws.WriteMessage(websocket.TextMessage, buf) case "response": diff --git a/internal/provider/types.go b/internal/provider/types.go index 1c75e49..6c9fad8 100644 --- a/internal/provider/types.go +++ b/internal/provider/types.go @@ -7,5 +7,10 @@ type RawCommand[T any] struct { } type registerCommandData struct { - Code string `json:"code"` + Code string `json:"code"` + Secret string `json:"secret"` +} + +type errorData struct { + Error string `json:"error"` }