Skip to content

Commit

Permalink
feat(core): allow custom cors setups (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayuhito authored Jun 13, 2024
1 parent e72b4ce commit 5842e53
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 11 deletions.
2 changes: 1 addition & 1 deletion core/Taskfile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ tasks:
dev:
deps: [generate]
cmds:
- go run -tags "no_duckdb_arrow" ./cmd/ {{.CLI_ARGS}} -logger=pretty -level=debug
- go run -tags "no_duckdb_arrow" ./cmd/ {{.CLI_ARGS}} -logger=pretty -level=debug -corsorigins=http://localhost:8080,http://localhost:5173
env:
CGO_ENABLED: "1"

Expand Down
3 changes: 3 additions & 0 deletions core/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ type ServerConfig struct {
// Cache settings
CacheCleanupInterval time.Duration

// CORS Settings
CORSAllowedOrigins []string `env:"CORS_ALLOWED_ORIGINS" envSeparator:","`

// Logging settings
Logger string `env:"LOGGER"`
Level string `env:"LOGGER_LEVEL"`
Expand Down
20 changes: 10 additions & 10 deletions core/cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/medama-io/medama/util"
"github.com/medama-io/medama/util/logger"
"github.com/ogen-go/ogen/middleware"
"github.com/rs/cors"
)

type StartCommand struct {
Expand Down Expand Up @@ -64,12 +63,19 @@ func (s *StartCommand) ParseFlags(args []string) error {
fs.StringVar(&s.Server.Level, "level", DefaultLoggerLevel, "Logger level (debug, info, warn, error)")
fs.Int64Var(&s.Server.Port, "port", DefaultPort, "Port to listen on")

// Handle array type flags
allowedOrigins := fs.String("corsorigins", "", "Comma separated list of allowed origins on API routes")

// Parse flags
err := fs.Parse(args)
if err != nil {
return errors.Wrap(err, "failed to parse flags")
}

if *allowedOrigins != "" {
s.Server.CORSAllowedOrigins = strings.Split(*allowedOrigins, ",")
}

return nil
}

Expand All @@ -80,6 +86,7 @@ func (s *StartCommand) Run(ctx context.Context) error {
return errors.Wrap(err, "failed to setup logger")
}
log.Info().Msg(GetVersion())
log.Debug().Interface("config", s).Msg("")

// Setup database
sqlite, err := sqlite.NewClient(s.AppDB.Host)
Expand Down Expand Up @@ -178,15 +185,8 @@ func (s *StartCommand) Run(ctx context.Context) error {
}
}))

// Apply CORS headers.
cors := cors.New(cors.Options{
// TODO: Allow for configurable allowed origins. Typically this won't be needed
// as the client will be served from the same domain. But it is useful for development
// and external dashboards.
AllowedOrigins: []string{"http://localhost:8080", "http://localhost:5173"},
AllowCredentials: true,
})
handler := cors.Handler(mux)
// Apply custom CORS middleware to the mux handler
handler := middlewares.CORSAllowedOriginsMiddleware(s.Server.CORSAllowedOrigins)(mux)

srv := &http.Server{
Addr: ":" + strconv.FormatInt(s.Server.Port, 10),
Expand Down
41 changes: 41 additions & 0 deletions core/middlewares/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package middlewares

import (
"net/http"
"path"
"strings"

"github.com/medama-io/medama/util/logger"
"github.com/rs/cors"
)

// CORSAllowedOriginsMiddleware creates a middleware to apply CORS headers based on the allowed origins.
// Typically this won't need a custom list of allowed origins as the client will be served from the same domain.
// But it is useful for development and external dashboards as we need to pass credentials from different domains.
func CORSAllowedOriginsMiddleware(allowedOrigins []string) func(http.Handler) http.Handler {
// Create a CORS handler with custom options for the allowed origins
customCORS := cors.New(cors.Options{
AllowedOrigins: allowedOrigins,
AllowCredentials: true,
Debug: true,
})

// Create a default CORS handler
defaultCORS := cors.Default()

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uPath := path.Clean(r.URL.Path)
log := logger.Get()

if allowedOrigins != nil && strings.HasPrefix(uPath, "/api") && !strings.HasPrefix(uPath, "/api/event") {
// Apply modified CORS headers for API routes.
log.Debug().Str("allowed_origins", strings.Join(allowedOrigins, ",")).Str("path", uPath).Msg("Applying custom CORS")
customCORS.Handler(next).ServeHTTP(w, r)
} else {
// Apply default CORS headers
defaultCORS.Handler(next).ServeHTTP(w, r)
}
})
}
}

0 comments on commit 5842e53

Please sign in to comment.