package middleware import ( "net/http" "strconv" "strings" ) // copied from https://github.com/gorilla/handlers/blob/main/cors.go // with some editing to fit in // // CORSOption represents a functional option for configuring the CORS middleware. type CORSOption struct { // AllowedOrigins list, including "*" will allow all AllowedOrigins []string // AllowedHeaders are a list of headers clients are allowed to use with. // default: []string{"Accept", "Accept-Language", "Content-Language", "Origin"} AllowedHeaders []string // AllowedMethods are a list of methods clients are allowed to use. // // default: []string{"HEAD", "GET", "POST"} AllowedMethods []string ExposedHeaders []string // MaxAge in seconds, max allowed value is 600 MaxAge uint AllowCredentials bool } type cors struct { h http.Handler allowedHeaders []string allowedMethods []string allowedOrigins []string allowedOriginValidator OriginValidator exposedHeaders []string maxAge uint ignoreOptions bool allowCredentials bool optionStatusCode int } // OriginValidator takes an origin string and returns whether that origin is allowed. type OriginValidator func(string) bool var ( defaultCorsOptionStatusCode = http.StatusOK defaultCorsMethods = []string{http.MethodHead, http.MethodGet, http.MethodPost} defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Type", "Content-Language", "Origin"} // (WebKit/Safari v9 sends the Origin header by default in AJAX requests). ) const ( corsOptionMethod string = http.MethodOptions corsAllowOriginHeader string = "Access-Control-Allow-Origin" corsExposeHeadersHeader string = "Access-Control-Expose-Headers" corsMaxAgeHeader string = "Access-Control-Max-Age" corsAllowMethodsHeader string = "Access-Control-Allow-Methods" corsAllowHeadersHeader string = "Access-Control-Allow-Headers" corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials" corsRequestMethodHeader string = "Access-Control-Request-Method" corsRequestHeadersHeader string = "Access-Control-Request-Headers" corsOriginHeader string = "Origin" corsVaryHeader string = "Vary" corsOriginMatchAll string = "*" ) // CORS provides Cross-Origin Resource Sharing middleware. // Example: // // import ( // "net/http" // "gitserver.in/patialtech/mux/middleware" // "gitserver.in/patialtech/mux" // ) // // func main() { // r := mux.NewRouter() // r.Use(middleware.CORS(middleware.CORSOption{ // AllowedOrigins: []string{"*"}, // MaxAge: 60, // })) // // r.Get("/", func(w http.ResponseWriter, r *http.Request) { // w.Write([]byte("hello there")) // }) // // r.Serve(func(srv *http.Server) error { // srv.Addr = ":3001" // slog.Info("listening on http://localhost" + srv.Addr) // return srv.ListenAndServe() // }) // } func CORS(opts CORSOption) func(http.Handler) http.Handler { return func(h http.Handler) http.Handler { ch := &cors{ h: h, allowedMethods: defaultCorsMethods, allowedHeaders: defaultCorsHeaders, allowedOrigins: []string{}, optionStatusCode: defaultCorsOptionStatusCode, } ch.setAllowedOrigins(opts.AllowedOrigins) ch.setAllowedHeaders(opts.AllowedHeaders) ch.setAllowedMethods(opts.AllowedMethods) ch.setExposedHeaders(opts.ExposedHeaders) ch.setMaxAge(opts.MaxAge) ch.maxAge = opts.MaxAge ch.allowCredentials = opts.AllowCredentials return ch } } func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get(corsOriginHeader) if !ch.isOriginAllowed(origin) { if r.Method != corsOptionMethod || ch.ignoreOptions { ch.h.ServeHTTP(w, r) } return } if r.Method == corsOptionMethod { if ch.ignoreOptions { ch.h.ServeHTTP(w, r) return } if _, ok := r.Header[corsRequestMethodHeader]; !ok { w.WriteHeader(http.StatusBadRequest) return } method := r.Header.Get(corsRequestMethodHeader) if !ch.isMatch(method, ch.allowedMethods) { w.WriteHeader(http.StatusMethodNotAllowed) return } requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") var allowedHeaders []string for _, v := range requestHeaders { canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { continue } if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { w.WriteHeader(http.StatusForbidden) return } allowedHeaders = append(allowedHeaders, canonicalHeader) } if len(allowedHeaders) > 0 { w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) } if ch.maxAge > 0 { w.Header().Set(corsMaxAgeHeader, strconv.Itoa(int(ch.maxAge))) } if !ch.isMatch(method, defaultCorsMethods) { w.Header().Set(corsAllowMethodsHeader, method) } } else if len(ch.exposedHeaders) > 0 { w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) } if ch.allowCredentials { w.Header().Set(corsAllowCredentialsHeader, "true") } if len(ch.allowedOrigins) > 1 { w.Header().Set(corsVaryHeader, corsOriginHeader) } returnOrigin := origin if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { returnOrigin = "*" } else { for _, o := range ch.allowedOrigins { // A configuration of * is different from explicitly setting an allowed // origin. Returning arbitrary origin headers in an access control allow // origin header is unsafe and is not required by any use case. if o == corsOriginMatchAll { returnOrigin = "*" break } } } w.Header().Set(corsAllowOriginHeader, returnOrigin) if r.Method == corsOptionMethod { w.WriteHeader(ch.optionStatusCode) return } ch.h.ServeHTTP(w, r) } // AllowedOrigins sets the allowed origins for CORS requests, as used in the // 'Allow-Access-Control-Origin' HTTP header. // Note: Passing in a []string{"*"} will allow any domain. func (ch *cors) setAllowedOrigins(origins []string) { // look for "*" for _, v := range origins { if v == corsOriginMatchAll { ch.allowedOrigins = []string{corsOriginMatchAll} return } } ch.allowedOrigins = origins } // setAllowedHeaders adds the provided headers to the list of allowed headers in a // CORS request. // This is an appended operation, so the headers Accept, Accept-Language, // and Content-Language are always allowed. // Content-Type must be explicitly declared if accepting Content-Types other than // application/x-www-form-urlencoded, multipart/form-data, or text/plain. func (ch *cors) setAllowedHeaders(headers []string) { for _, v := range headers { normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) if normalizedHeader == "" { continue } if !ch.isMatch(normalizedHeader, ch.allowedHeaders) { ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader) } } } // AllowedMethods can be used to explicitly allow methods in the // Access-Control-Allow-Methods header. // This is a replacement operation, so you must also // pass GET, HEAD, and POST if you wish to support those methods. func (ch *cors) setAllowedMethods(methods []string) { if len(methods) == 0 { return } ch.allowedMethods = []string{} for _, v := range methods { normalizedMethod := strings.ToUpper(strings.TrimSpace(v)) if normalizedMethod == "" { continue } if !ch.isMatch(normalizedMethod, ch.allowedMethods) { ch.allowedMethods = append(ch.allowedMethods, normalizedMethod) } } } // ExposedHeaders can be used to specify headers that are available // and will not be stripped out by the user-agent. func (ch *cors) setExposedHeaders(headers []string) { ch.exposedHeaders = []string{} for _, v := range headers { normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) if normalizedHeader == "" { continue } if !ch.isMatch(normalizedHeader, ch.exposedHeaders) { ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader) } } } // MaxAge determines the maximum age (in seconds) between preflight requests. A // maximum of 10 minutes is allowed. An age above this value will default to 10 // minutes. func (ch *cors) setMaxAge(age uint) { // Maximum of 10 minutes. if age > 600 { age = 600 } ch.maxAge = age } func (ch *cors) isOriginAllowed(origin string) bool { if origin == "" { return false } if ch.allowedOriginValidator != nil { return ch.allowedOriginValidator(origin) } if len(ch.allowedOrigins) == 0 { return true } for _, allowedOrigin := range ch.allowedOrigins { if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { return true } } return false } func (ch *cors) isMatch(needle string, haystack []string) bool { for _, v := range haystack { if v == needle { return true } } return false }