Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow regexp on AllowOrigins #36

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@ language: go
sudo: false

go:
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- tip
- 1.11.x
- 1.12.x
- 1.13.x
- master

matrix:
fast_finish: true
include:
- go: 1.11.x
env: GO111MODULE=on
- go: 1.12.x
env: GO111MODULE=on

script:
- go test -v -covermode=atomic -coverprofile=coverage.out
Expand Down
62 changes: 60 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package cors

import (
"net/http"
"regexp"
"strings"

"github.com/gin-gonic/gin"
)
Expand All @@ -14,19 +16,42 @@ type cors struct {
exposeHeaders []string
normalHeaders http.Header
preflightHeaders http.Header
wildcardOrigins [][]string
}

var (
DefaultSchemas = []string{
arizz96 marked this conversation as resolved.
Show resolved Hide resolved
"http://",
"https://",
}
ExtensionSchemas = []string{
arizz96 marked this conversation as resolved.
Show resolved Hide resolved
"chrome-extension://",
"safari-extension://",
"moz-extension://",
"ms-browser-extension://",
}
FileSchemas = []string{
arizz96 marked this conversation as resolved.
Show resolved Hide resolved
"file://",
}
WebSocketSchemas = []string{
arizz96 marked this conversation as resolved.
Show resolved Hide resolved
"ws://",
"wss://",
}
)

func newCors(config Config) *cors {
if err := config.Validate(); err != nil {
panic(err.Error())
}

return &cors{
allowOriginFunc: config.AllowOriginFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
wildcardOrigins: config.parseWildcardRules(),
}
}

Expand All @@ -36,14 +61,22 @@ func (cors *cors) applyCors(c *gin.Context) {
// request is not a CORS request
return
}
host := c.Request.Header.Get("Host")

if origin == "http://"+host || origin == "https://"+host {
// request is not a CORS request but have origin header.
// for example, use fetch api
return
}

if !cors.validateOrigin(origin) {
c.AbortWithStatus(http.StatusForbidden)
return
}

if c.Request.Method == "OPTIONS" {
cors.handlePreflight(c)
defer c.AbortWithStatus(200)
defer c.AbortWithStatus(http.StatusNoContent) // Using 204 is better than 200 when the request status is OPTIONS
} else {
cors.handleNormal(c)
}
Expand All @@ -53,15 +86,40 @@ func (cors *cors) applyCors(c *gin.Context) {
}
}

func (cors *cors) validateWildcardOrigin(origin string) bool {
for _, w := range cors.wildcardOrigins {
if w[0] == "*" && strings.HasSuffix(origin, w[1]) {
return true
}
if w[1] == "*" && strings.HasPrefix(origin, w[0]) {
return true
}
if strings.HasPrefix(origin, w[0]) && strings.HasSuffix(origin, w[1]) {
return true
}
}

return false
}

func (cors *cors) validateOrigin(origin string) bool {
if cors.allowAllOrigins {
return true
}
r, _ := regexp.Compile("^\\/(.+)\\/[gimuy]?$")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move regexp.Compile to the top block as a global variable for better performance.

for _, value := range cors.allowOrigins {
if value == origin {
if r.MatchString(value) {
match, _ := regexp.MatchString(r.FindStringSubmatch(value)[1], origin)
if match {
return true
}
} else if value == origin {
return true
}
}
if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) {
return true
}
if cors.allowOriginFunc != nil {
return cors.allowOriginFunc(origin)
}
Expand Down
96 changes: 88 additions & 8 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cors

import (
"errors"
"regexp"
"strings"
"time"

Expand All @@ -12,21 +13,21 @@ import (
type Config struct {
AllowAllOrigins bool

// AllowedOrigins is a list of origins a cross-domain request can be executed from.
// AllowOrigins is a list of origins a cross-domain request can be executed from.
// If the special "*" value is present in the list, all origins will be allowed.
// Default value is []
AllowOrigins []string

// AllowOriginFunc is a custom function to validate the origin. It take the origin
// as argument and returns true if allowed or false otherwise. If this option is
// set, the content of AllowedOrigins is ignored.
// set, the content of AllowOrigins is ignored.
AllowOriginFunc func(origin string) bool

// AllowedMethods is a list of methods the client is allowed to use with
// AllowMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET and POST)
AllowMethods []string

// AllowedHeaders is list of non simple headers the client is allowed to use with
// AllowHeaders is list of non simple headers the client is allowed to use with
// cross-domain requests.
AllowHeaders []string

Expand All @@ -41,6 +42,18 @@ type Config struct {
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached
MaxAge time.Duration

// Allows to add origins like http://some-domain/*, https://api.* or http://some.*.subdomain.com
AllowWildcard bool

// Allows usage of popular browser extensions schemas
AllowBrowserExtensions bool

// Allows usage of WebSocket protocol
AllowWebSockets bool

// Allows usage of file:// schema (dangerous!) use it only when you 100% sure it's needed
AllowFiles bool
}

// AddAllowMethods is allowed to add custom methods
Expand All @@ -58,22 +71,89 @@ func (c *Config) AddExposeHeaders(headers ...string) {
c.ExposeHeaders = append(c.ExposeHeaders, headers...)
}

func (c Config) getAllowedSchemas() []string {
allowedSchemas := DefaultSchemas
if c.AllowBrowserExtensions {
allowedSchemas = append(allowedSchemas, ExtensionSchemas...)
}
if c.AllowWebSockets {
allowedSchemas = append(allowedSchemas, WebSocketSchemas...)
}
if c.AllowFiles {
allowedSchemas = append(allowedSchemas, FileSchemas...)
}
return allowedSchemas
}

func (c Config) validateAllowedSchemas(origin string) bool {
allowedSchemas := c.getAllowedSchemas()

r, _ := regexp.Compile("^\\/(.+)\\/[gimuy]?$")
if r.MatchString(origin) {
// Normalize regexp-based origins
origin = r.FindStringSubmatch(origin)[1]
origin = strings.Replace(origin, "?", "", 1)
}

for _, schema := range allowedSchemas {
if strings.HasPrefix(origin, schema) {
return true
}
}
return false
}

// Validate is check configuration of user defined.
func (c Config) Validate() error {
func (c *Config) Validate() error {
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed")
}
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
return errors.New("conflict settings: all origins disabled")
}
for _, origin := range c.AllowOrigins {
if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
return errors.New("bad origin: origins must either be '*' or include http:// or https://")
if origin == "*" {
c.AllowAllOrigins = true
return nil
} else if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) {
return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
}
}
return nil
}

func (c Config) parseWildcardRules() [][]string {
var wRules [][]string

if !c.AllowWildcard {
return wRules
}

for _, o := range c.AllowOrigins {
if !strings.Contains(o, "*") {
continue
}

if c := strings.Count(o, "*"); c > 1 {
panic(errors.New("only one * is allowed").Error())
}

i := strings.Index(o, "*")
if i == 0 {
wRules = append(wRules, []string{"*", o[1:]})
continue
}
if i == (len(o) - 1) {
wRules = append(wRules, []string{o[:i-1], "*"})
continue
}

wRules = append(wRules, []string{o[:i], o[i+1:]})
}

return wRules
}

// DefaultConfig returns a generic default configuration mapped to localhost.
func DefaultConfig() Config {
return Config{
Expand Down
Loading