Skip to content

Commit

Permalink
Require secret to reconnect as provider (#1)
Browse files Browse the repository at this point in the history
* Require secret to reconnect as provider

* Disconnect websocket on error

---------

Co-authored-by: Brad Murray <[email protected]>
  • Loading branch information
tulir and bradtgmurray authored Dec 19, 2023
1 parent 3a55c6f commit b1dd8ef
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 28 deletions.
11 changes: 11 additions & 0 deletions cmd/registration_relay/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"encoding/base64"
"flag"
"os"
"os/signal"
Expand Down Expand Up @@ -28,6 +29,11 @@ func main() {
flagenv.StringEnvWithDefault("REGISTRATION_RELAY_LISTEN", ":8000"),
"Listen address",
)
secret := flag.String(
"secret",
flagenv.StringEnvWithDefault("REGISTRATION_RELAY_SECRET", ""),
"Secret (32 bytes encoded as base64)",
)
metricsListenAddr := flag.String(
"metricsListen",
flagenv.StringEnvWithDefault("REGISTRATION_RELAY_METRICS_LISTEN", ":5000"),
Expand All @@ -54,6 +60,11 @@ func main() {

cfg := config.Config{}
cfg.API.Listen = *listenAddr
var err error
cfg.Secret, err = base64.StdEncoding.DecodeString(*secret)
if err != nil || len(cfg.Secret) != 32 {
log.Fatal().Err(err).Int("secret_len", len(cfg.Secret)).Msg("Invalid secret")
}
cfg.API.ValidateAuthURL = *validateAuthURL

log.Info().Str("commit", Commit).Str("build_time", BuildTime).Msg("registration-relay starting")
Expand Down
4 changes: 3 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
type api struct {
log zerolog.Logger
server *http.Server
secret []byte
}

func NewAPI(cfg config.Config) *api {
Expand All @@ -29,7 +30,8 @@ func NewAPI(cfg config.Config) *api {
Logger()

api := api{
log: logger,
log: logger,
secret: cfg.Secret,
}

r := chi.NewRouter()
Expand Down
2 changes: 1 addition & 1 deletion internal/api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (a *api) providerWebsocket(w http.ResponseWriter, r *http.Request) {
}
defer conn.Close()

provider := provider.NewProvider(conn)
provider := provider.NewProvider(conn, a.secret)
provider.WebsocketLoop()

a.log.Info().Msg("Websocket connection closed")
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

type Config struct {
Version string
Secret []byte
API struct {
Listen string
ValidateAuthURL string
Expand Down
71 changes: 46 additions & 25 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package provider

import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"sync"

"github.com/gorilla/websocket"
Expand Down Expand Up @@ -31,27 +35,37 @@ func GetProvider(code string) (*provider, bool) {
return p, exists
}

func RegisterProvider(code string, provider *provider) (string, error) {
func calculateSecret(globalSecret []byte, code string) string {
h := hmac.New(sha256.New, globalSecret)
h.Write([]byte(code))
return base64.RawStdEncoding.EncodeToString(h.Sum(nil))
}

func RegisterProvider(data registerCommandData, provider *provider) (*registerCommandData, error) {
codeToProviderLock.Lock()
defer codeToProviderLock.Unlock()

if existing, exists := codeToProvider[code]; exists {
existing.log.Warn().
Str("code", code).
Msg("New provider with same code registering, exiting websocket")
existing.ws.Close()
}

if code == "" {
if data.Code == "" {
var err error
code, err = util.GenerateProviderCode()
data.Code, err = util.GenerateProviderCode()
if err != nil {
return "", err
return nil, err
}
data.Secret = calculateSecret(provider.globalSecret, data.Code)
} else {
if calculateSecret(provider.globalSecret, data.Code) != data.Secret {
return nil, fmt.Errorf("invalid secret")
}
if existing, exists := codeToProvider[data.Code]; exists {
existing.log.Warn().
Str("code", data.Code).
Msg("New provider with same code registering, exiting websocket")
existing.ws.Close()
}
}

codeToProvider[code] = provider
return code, nil
codeToProvider[data.Code] = provider
return &data, nil
}

func UnregisterProvider(key string) {
Expand All @@ -70,18 +84,21 @@ type provider struct {
ws *websocket.Conn
resultsCh chan json.RawMessage
reqID int

globalSecret []byte
}

func NewProvider(ws *websocket.Conn) *provider {
func NewProvider(ws *websocket.Conn, secret []byte) *provider {
logger := log.With().
Str("component", "provider").
Logger()

return &provider{
log: logger,
ws: ws,
resultsCh: make(chan json.RawMessage),
reqID: 1,
log: logger,
ws: ws,
resultsCh: make(chan json.RawMessage),
reqID: 1,
globalSecret: secret,
}
}

Expand All @@ -91,6 +108,7 @@ func (p *provider) WebsocketLoop() {

registerCode := ""

Loop:
for {
_, message, err := p.ws.ReadMessage()
if err != nil {
Expand All @@ -110,21 +128,24 @@ func (p *provider) WebsocketLoop() {
var request registerCommandData
if err := json.Unmarshal(rawCommand.Data, &request); err != nil {
p.log.Err(err).Msg("Failed to decode register request")
break
break Loop
}
registerCode, err = RegisterProvider(request.Code, p)
response, err := RegisterProvider(request, p)
if err != nil {
p.log.Err(err).Msg("Failed to register provider")
break
buf, err := json.Marshal(RawCommand[errorData]{Command: "response", Data: errorData{"invalid token"}, ReqID: rawCommand.ReqID})
if err == nil {
p.ws.WriteMessage(websocket.TextMessage, buf)
}
break Loop
}
p.log.Debug().Msg("Registered provider")

// Send back register response before setting the flag (ws is single writer)
response := registerCommandData{registerCode}
buf, err := json.Marshal(RawCommand[registerCommandData]{Command: "response", Data: response, ReqID: rawCommand.ReqID})
buf, err := json.Marshal(RawCommand[registerCommandData]{Command: "response", Data: *response, ReqID: rawCommand.ReqID})
if err != nil {
p.log.Err(err).Msg("Failed to encode register response")
break
break Loop
}
p.ws.WriteMessage(websocket.TextMessage, buf)

Expand All @@ -134,7 +155,7 @@ func (p *provider) WebsocketLoop() {
buf, err := json.Marshal(RawCommand[struct{}]{Command: "pong", ReqID: rawCommand.ReqID})
if err != nil {
p.log.Err(err).Msg("Failed to encode ping response")
break
break Loop
}
p.ws.WriteMessage(websocket.TextMessage, buf)
case "response":
Expand Down
7 changes: 6 additions & 1 deletion internal/provider/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,10 @@ type RawCommand[T any] struct {
}

type registerCommandData struct {
Code string `json:"code"`
Code string `json:"code"`
Secret string `json:"secret"`
}

type errorData struct {
Error string `json:"error"`
}

0 comments on commit b1dd8ef

Please sign in to comment.