mux/middleware/cors.go

323 lines
8.8 KiB
Go

package middleware
import (
"net/http"
"strconv"
"strings"
)
// copied from https://github.com/gorilla/handlers/blob/main/cors.go
// with some editing to fit in

// CORSOption configures the middleware returned by CORS.
type CORSOption struct {
	// AllowedOrigins lists the origins allowed to make cross-origin requests,
	// compared exactly against the request's Origin header. Including "*"
	// allows any origin; an empty list also allows any origin.
	AllowedOrigins []string

	// AllowedHeaders lists additional request headers clients may send.
	// Entries are trimmed and canonicalized, then appended to the defaults,
	// which are always allowed:
	// []string{"Accept", "Accept-Language", "Content-Type", "Content-Language", "Origin"}
	AllowedHeaders []string

	// AllowedMethods lists the methods clients may use. Entries are trimmed
	// and upper-cased. A non-empty list replaces the defaults, so include
	// HEAD, GET and POST explicitly if they are still wanted.
	//
	// default: []string{"HEAD", "GET", "POST"}
	AllowedMethods []string

	// ExposedHeaders lists response headers that browsers may expose to client
	// scripts. They are sent in Access-Control-Expose-Headers on responses to
	// non-OPTIONS requests from allowed origins.
	ExposedHeaders []string

	// MaxAge is how long, in seconds, browsers may cache a preflight response
	// (Access-Control-Max-Age). 0 omits the header. The maximum allowed value
	// is 600 (10 minutes).
	MaxAge uint

	// AllowCredentials adds Access-Control-Allow-Credentials: true. Browsers
	// reject credentialed responses whose Access-Control-Allow-Origin is "*",
	// so pair this with an explicit AllowedOrigins list.
	AllowCredentials bool
}
// cors is the http.Handler built by CORS. It wraps the next handler h and
// holds the normalized configuration consulted on every request.
type cors struct {
	// h is the wrapped handler that serves requests passed through.
	h http.Handler

	// Normalized allow/expose lists.
	allowedOrigins []string
	allowedMethods []string
	allowedHeaders []string
	exposedHeaders []string

	// allowedOriginValidator, when non-nil, decides whether an origin is
	// accepted instead of allowedOrigins.
	allowedOriginValidator OriginValidator

	maxAge           uint // seconds; 0 omits Access-Control-Max-Age
	optionStatusCode int  // status written for an accepted preflight
	ignoreOptions    bool // pass OPTIONS requests to h untouched
	allowCredentials bool // emit Access-Control-Allow-Credentials: true
}

// OriginValidator reports whether the given Origin header value is allowed.
type OriginValidator func(string) bool
// Defaults applied by CORS before user options are merged in.
var (
	// defaultCorsOptionStatusCode is the status written for an accepted
	// preflight request.
	defaultCorsOptionStatusCode = http.StatusOK

	// defaultCorsMethods are used unless CORSOption.AllowedMethods is
	// non-empty.
	defaultCorsMethods = []string{
		http.MethodHead,
		http.MethodGet,
		http.MethodPost,
	}

	// defaultCorsHeaders are always allowed; CORSOption.AllowedHeaders is
	// appended to them. Origin is included because WebKit/Safari v9 sends the
	// Origin header by default in AJAX requests.
	defaultCorsHeaders = []string{
		"Accept",
		"Accept-Language",
		"Content-Type",
		"Content-Language",
		"Origin",
	}
)
// Names and values used by the CORS protocol.
const (
	// corsOptionMethod is the HTTP method browsers use for preflight requests.
	corsOptionMethod string = http.MethodOptions

	// Request headers sent by browsers.
	corsOriginHeader         string = "Origin"
	corsRequestMethodHeader  string = "Access-Control-Request-Method"
	corsRequestHeadersHeader string = "Access-Control-Request-Headers"

	// Response headers written by the middleware.
	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
	corsMaxAgeHeader           string = "Access-Control-Max-Age"
	corsVaryHeader             string = "Vary"

	// corsOriginMatchAll is the wildcard origin entry that allows any origin.
	corsOriginMatchAll string = "*"
)
// CORS returns middleware that adds Cross-Origin Resource Sharing headers to
// responses and answers preflight (OPTIONS) requests according to opts.
//
// Example:
//
//	import (
//		"log/slog"
//		"net/http"
//
//		"gitserver.in/patialtech/mux"
//		"gitserver.in/patialtech/mux/middleware"
//	)
//
//	func main() {
//		r := mux.NewRouter()
//		r.Use(middleware.CORS(middleware.CORSOption{
//			AllowedOrigins: []string{"*"},
//			MaxAge:         60,
//		}))
//
//		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
//			w.Write([]byte("hello there"))
//		})
//
//		r.Serve(func(srv *http.Server) error {
//			srv.Addr = ":3001"
//			slog.Info("listening on http://localhost" + srv.Addr)
//			return srv.ListenAndServe()
//		})
//	}
func CORS(opts CORSOption) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		// Copy the package-level defaults so that appending to a handler's
		// lists (setAllowedHeaders does) can never write through to them.
		ch := &cors{
			h:                h,
			allowedMethods:   append([]string(nil), defaultCorsMethods...),
			allowedHeaders:   append([]string(nil), defaultCorsHeaders...),
			allowedOrigins:   []string{},
			optionStatusCode: defaultCorsOptionStatusCode,
		}
		ch.setAllowedOrigins(opts.AllowedOrigins)
		ch.setAllowedHeaders(opts.AllowedHeaders)
		ch.setAllowedMethods(opts.AllowedMethods)
		ch.setExposedHeaders(opts.ExposedHeaders)
		// setMaxAge caps the value at 600 seconds. It must not be overwritten
		// with the raw option afterwards, or the cap is lost.
		ch.setMaxAge(opts.MaxAge)
		ch.allowCredentials = opts.AllowCredentials
		return ch
	}
}
// ServeHTTP applies the CORS policy to a single request.
//
// Requests without an allowed Origin are passed to the wrapped handler
// untouched, except OPTIONS requests. Unless ignoreOptions is set, nothing is
// written for those and the wrapped handler is not called.
//
// For an allowed origin, an OPTIONS request is handled as a preflight. It is
// rejected with 400 when Access-Control-Request-Method is missing, with 405
// when that method is not allowed, and with 403 when any header listed in
// Access-Control-Request-Headers is not allowed. Otherwise the preflight is
// answered with optionStatusCode and never reaches the wrapped handler.
// Non-OPTIONS requests receive Access-Control-Expose-Headers (if configured)
// and are then served by the wrapped handler.
func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	origin := r.Header.Get(corsOriginHeader)
	if !ch.isOriginAllowed(origin) {
		if r.Method != corsOptionMethod || ch.ignoreOptions {
			ch.h.ServeHTTP(w, r)
		}
		return
	}

	if r.Method == corsOptionMethod {
		if ch.ignoreOptions {
			ch.h.ServeHTTP(w, r)
			return
		}
		if _, ok := r.Header[corsRequestMethodHeader]; !ok {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		method := r.Header.Get(corsRequestMethodHeader)
		if !ch.isMatch(method, ch.allowedMethods) {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		// Browsers list only non-safelisted headers in
		// Access-Control-Request-Headers, and every one of them must be echoed
		// in Access-Control-Allow-Headers or the preflight fails. That includes
		// defaults such as Content-Type (e.g. for "application/json" bodies),
		// so default headers count as allowed and are echoed like any other.
		var allowedHeaders []string
		for _, v := range strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") {
			canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
			if canonicalHeader == "" {
				continue
			}
			if !ch.isMatch(canonicalHeader, defaultCorsHeaders) && !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
				w.WriteHeader(http.StatusForbidden)
				return
			}
			allowedHeaders = append(allowedHeaders, canonicalHeader)
		}
		if len(allowedHeaders) > 0 {
			w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
		}
		if ch.maxAge > 0 {
			w.Header().Set(corsMaxAgeHeader, strconv.Itoa(int(ch.maxAge)))
		}
		// GET, HEAD and POST are CORS-safelisted methods and need no
		// explicit Access-Control-Allow-Methods grant.
		if !ch.isMatch(method, defaultCorsMethods) {
			w.Header().Set(corsAllowMethodsHeader, method)
		}
	} else if len(ch.exposedHeaders) > 0 {
		w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
	}

	if ch.allowCredentials {
		w.Header().Set(corsAllowCredentialsHeader, "true")
	}

	// A configuration of * is different from explicitly setting an allowed
	// origin. Returning arbitrary origin headers in an access control allow
	// origin header is unsafe and is not required by any use case, so "*" is
	// returned verbatim instead of echoing the request's origin.
	returnOrigin := origin
	if (ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0) ||
		ch.isMatch(corsOriginMatchAll, ch.allowedOrigins) {
		returnOrigin = corsOriginMatchAll
	}

	// An echoed origin makes the response depend on the request's Origin
	// header, so shared caches must key on it. Add (not Set) keeps any Vary
	// values written earlier in the handler chain.
	if returnOrigin != corsOriginMatchAll {
		w.Header().Add(corsVaryHeader, corsOriginHeader)
	}
	w.Header().Set(corsAllowOriginHeader, returnOrigin)

	if r.Method == corsOptionMethod {
		w.WriteHeader(ch.optionStatusCode)
		return
	}
	ch.h.ServeHTTP(w, r)
}
// setAllowedOrigins sets the origins accepted for CORS requests, which are
// reflected in the Access-Control-Allow-Origin response header.
// If origins contains "*", every other entry is discarded and any origin is
// allowed. With no origin validator configured, an empty list also allows any
// origin.
//
// The slice is copied so that later changes to the caller's slice (for
// example CORSOption.AllowedOrigins) cannot alter the middleware's
// configuration while it is serving requests.
func (ch *cors) setAllowedOrigins(origins []string) {
	for _, v := range origins {
		if v == corsOriginMatchAll {
			ch.allowedOrigins = []string{corsOriginMatchAll}
			return
		}
	}
	ch.allowedOrigins = append([]string(nil), origins...)
}
// setAllowedHeaders appends headers to the list of request headers clients
// may send in a CORS request. Each entry is trimmed and canonicalized with
// http.CanonicalHeaderKey; blank entries and entries already in the list are
// skipped. Because the operation appends, whatever the list already holds
// stays allowed. CORS seeds it with Accept, Accept-Language, Content-Type,
// Content-Language and Origin.
func (ch *cors) setAllowedHeaders(headers []string) {
	for _, raw := range headers {
		name := http.CanonicalHeaderKey(strings.TrimSpace(raw))
		if name == "" || ch.isMatch(name, ch.allowedHeaders) {
			continue
		}
		ch.allowedHeaders = append(ch.allowedHeaders, name)
	}
}
// setAllowedMethods replaces the list of methods accepted in
// Access-Control-Request-Method. Entries are trimmed and upper-cased; blank
// entries and duplicates are skipped. An empty methods slice leaves the
// current list untouched. Because non-empty input replaces the list, callers
// who still want HEAD, GET and POST must include them.
func (ch *cors) setAllowedMethods(methods []string) {
	if len(methods) == 0 {
		return
	}
	normalized := []string{}
	for _, raw := range methods {
		m := strings.ToUpper(strings.TrimSpace(raw))
		if m == "" || ch.isMatch(m, normalized) {
			continue
		}
		normalized = append(normalized, m)
	}
	ch.allowedMethods = normalized
}
// setExposedHeaders sets the response headers that browsers may expose to
// client scripts via Access-Control-Expose-Headers. Entries are trimmed and
// canonicalized; blank entries and duplicates are skipped. Any previously
// configured list is replaced.
func (ch *cors) setExposedHeaders(headers []string) {
	exposed := make([]string, 0, len(headers))
	for _, raw := range headers {
		name := http.CanonicalHeaderKey(strings.TrimSpace(raw))
		if name != "" && !ch.isMatch(name, exposed) {
			exposed = append(exposed, name)
		}
	}
	ch.exposedHeaders = exposed
}
// setMaxAge stores how long, in seconds, a preflight response may be cached
// (sent as Access-Control-Max-Age). Values above 600 (10 minutes) are capped
// at 600.
func (ch *cors) setMaxAge(age uint) {
	const maxPreflightAge uint = 600
	if age <= maxPreflightAge {
		ch.maxAge = age
		return
	}
	ch.maxAge = maxPreflightAge
}
// isOriginAllowed reports whether origin may make cross-origin requests.
// An empty origin (no Origin header) is never allowed. A configured
// allowedOriginValidator has the final say. Otherwise an empty allow list,
// a "*" entry, or an exact (case-sensitive) match allows the origin.
func (ch *cors) isOriginAllowed(origin string) bool {
	switch {
	case origin == "":
		return false
	case ch.allowedOriginValidator != nil:
		return ch.allowedOriginValidator(origin)
	case len(ch.allowedOrigins) == 0:
		return true
	}
	for _, candidate := range ch.allowedOrigins {
		if candidate == corsOriginMatchAll || candidate == origin {
			return true
		}
	}
	return false
}
// isMatch reports whether needle equals (case-sensitively) any element of
// haystack. It does not read any field of ch.
func (ch *cors) isMatch(needle string, haystack []string) bool {
	for i := range haystack {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}