Skip to content

Commit

Permalink
Add optional authentication to bridge command requests (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
Fizzadar authored Dec 19, 2023
1 parent eaa4ac7 commit abd50b9
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 2 deletions.
7 changes: 7 additions & 0 deletions cmd/registration_relay/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ func main() {
"Metrics listen address",
)

validateAuthURL := flag.String(
"validateAuthURL",
flagenv.StringEnvWithDefault("REGISTRATION_RELAY_VALIDATE_AUTH_URL", ""),
"Validate auth header URL",
)

flag.Parse()

if *prettyLogs {
Expand All @@ -48,6 +54,7 @@ func main() {

cfg := config.Config{}
cfg.API.Listen = *listenAddr
cfg.API.ValidateAuthURL = *validateAuthURL

log.Info().Str("commit", Commit).Str("build_time", BuildTime).Msg("registration-relay starting")

Expand Down
10 changes: 9 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@ func NewAPI(cfg config.Config) *api {
r.Get("/health", health.Health)

r.Get("/api/v1/provider", api.providerWebsocket)
r.Post("/api/v1/bridge/{command}", api.bridgeExecuteCommand)

commandHandler := api.bridgeExecuteCommand
if cfg.API.ValidateAuthURL != "" {
commandHandler = api.requireAuthHandler(
cfg.API.ValidateAuthURL,
commandHandler,
)
}
r.Post("/api/v1/bridge/{command}", commandHandler)

api.server = &http.Server{Addr: cfg.API.Listen, Handler: r}

Expand Down
76 changes: 76 additions & 0 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package api

import (
"encoding/json"
"net/http"

"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
)

var httpClient = &http.Client{}

type authResp struct {
Identifier string `json:"identifier"`
}

func (a *api) requireAuthHandler(
validateURL string,
next http.HandlerFunc,
) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authToken := r.Header.Get("Authorization")

if authToken == "" {
a.log.Warn().Msg("Request missing auth header")
w.WriteHeader(http.StatusUnauthorized)
return
}

req, err := http.NewRequest(http.MethodGet, validateURL, nil)
if err != nil {
a.log.Err(err).Msg("Failed to create request to auth validation URL")
w.WriteHeader(http.StatusInternalServerError)
return
}

req.Header.Add("Authorization", authToken)

resp, err := httpClient.Do(req)
if err != nil {
a.log.Err(err).Msg("Failed to make request to auth validation URL")
w.WriteHeader(http.StatusInternalServerError)
return
}
defer resp.Body.Close()

if resp.StatusCode >= 500 {
a.log.Error().
Int("status_code", resp.StatusCode).
Msg("Unexpected status from auth URL")
w.WriteHeader(http.StatusInternalServerError)
return
}

if resp.StatusCode != 200 {
a.log.Warn().
Int("status_code", resp.StatusCode).
Msg("Unauthorized status from auth URL")
w.WriteHeader(http.StatusUnauthorized)
return
}

var response authResp
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
a.log.Err(err).Msg("Failed to decode auth response")
w.WriteHeader(http.StatusInternalServerError)
return
}

hlog.FromRequest(r).UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("identifier", response.Identifier)
})

next(w, r)
}
}
3 changes: 2 additions & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package config
type Config struct {
Version string
API struct {
Listen string
Listen string
ValidateAuthURL string
}
}

0 comments on commit abd50b9

Please sign in to comment.