Commit b03e9b9f authored by Nick Thomas's avatar Nick Thomas

Add helper.IsContentType and use it everywhere

parent 2de1f1e3
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"mime"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
...@@ -288,6 +287,5 @@ func bufferResponse(r io.Reader) (*bytes.Buffer, error) { ...@@ -288,6 +287,5 @@ func bufferResponse(r io.Reader) (*bytes.Buffer, error) {
} }
func validResponseContentType(resp *http.Response) bool { func validResponseContentType(resp *http.Response) bool {
parsed, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) return helper.IsContentType(ResponseContentType, resp.Header.Get("Content-Type"))
return err == nil && parsed == ResponseContentType
} }
...@@ -44,7 +44,7 @@ func (b *blocker) WriteHeader(status int) { ...@@ -44,7 +44,7 @@ func (b *blocker) WriteHeader(status int) {
return return
} }
if b.Header().Get("Content-Type") == ResponseContentType { if helper.IsContentType(ResponseContentType, b.Header().Get("Content-Type")) {
b.status = 500 b.status = 500
b.Header().Del("Content-Length") b.Header().Del("Content-Length")
b.hijacked = true b.hijacked = true
......
...@@ -3,6 +3,7 @@ package helper ...@@ -3,6 +3,7 @@ package helper
import ( import (
"errors" "errors"
"log" "log"
"mime"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
...@@ -160,3 +161,8 @@ func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) { ...@@ -160,3 +161,8 @@ func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) {
newHeaders.Set("X-Forwarded-For", header) newHeaders.Set("X-Forwarded-For", header)
} }
} }
func IsContentType(expected, actual string) bool {
parsed, _, err := mime.ParseMediaType(actual)
return err == nil && parsed == expected
}
package upstream package upstream
import ( import (
"mime"
"net/http" "net/http"
"path" "path"
"regexp" "regexp"
...@@ -67,8 +66,7 @@ func wsRoute(regexpStr string, handler http.Handler, matchers ...matcherFunc) ro ...@@ -67,8 +66,7 @@ func wsRoute(regexpStr string, handler http.Handler, matchers ...matcherFunc) ro
// Creates matcherFuncs for a particular content type. // Creates matcherFuncs for a particular content type.
func isContentType(contentType string) func(*http.Request) bool { func isContentType(contentType string) func(*http.Request) bool {
return func(r *http.Request) bool { return func(r *http.Request) bool {
parsed, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) return helper.IsContentType(contentType, r.Header.Get("Content-Type"))
return err == nil && contentType == parsed
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment