Commit 8ec0f244 authored by Adrien Kohlbecker's avatar Adrien Kohlbecker Committed by Jacob Vosmaer

Fix health checks routes incorrectly intercepting errors

parent 6f5b7b91
package staticpages package staticpages
import ( import (
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
...@@ -21,6 +22,14 @@ var ( ...@@ -21,6 +22,14 @@ var (
) )
) )
type ErrorFormat int
const (
ErrorFormatHTML ErrorFormat = iota
ErrorFormatJSON
ErrorFormatText
)
func init() { func init() {
prometheus.MustRegister(staticErrorResponses) prometheus.MustRegister(staticErrorResponses)
} }
...@@ -30,6 +39,7 @@ type errorPageResponseWriter struct { ...@@ -30,6 +39,7 @@ type errorPageResponseWriter struct {
status int status int
hijacked bool hijacked bool
path string path string
format ErrorFormat
} }
func (s *errorPageResponseWriter) Header() http.Header { func (s *errorPageResponseWriter) Header() http.Header {
...@@ -53,41 +63,77 @@ func (s *errorPageResponseWriter) WriteHeader(status int) { ...@@ -53,41 +63,77 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.status = status s.status = status
if 400 <= s.status && s.status <= 599 && if s.status < 400 || s.status > 599 || s.rw.Header().Get("X-GitLab-Custom-Error") != "" {
s.rw.Header().Get("X-GitLab-Custom-Error") == "" && s.rw.WriteHeader(status)
s.rw.Header().Get("Content-Type") != "application/json" { return
}
var contentType string
var data []byte
switch s.format {
case ErrorFormatText:
contentType, data = s.writeText()
case ErrorFormatJSON:
contentType, data = s.writeJSON()
default:
contentType, data = s.writeHTML()
}
if contentType == "" {
s.rw.WriteHeader(status)
return
}
s.hijacked = true
staticErrorResponses.WithLabelValues(fmt.Sprintf("%d", s.status)).Inc()
helper.SetNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", contentType)
s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
s.rw.Header().Del("Transfer-Encoding")
s.rw.WriteHeader(s.status)
s.rw.Write(data)
}
func (s *errorPageResponseWriter) writeHTML() (string, []byte) {
if s.rw.Header().Get("Content-Type") != "application/json" {
errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status)) errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status))
// check if custom error page exists, serve this page instead // check if custom error page exists, serve this page instead
if data, err := ioutil.ReadFile(errorPageFile); err == nil { if data, err := ioutil.ReadFile(errorPageFile); err == nil {
s.hijacked = true return "text/html; charset=utf-8", data
staticErrorResponses.WithLabelValues(fmt.Sprintf("%d", s.status)).Inc()
helper.SetNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", "text/html; charset=utf-8")
s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
s.rw.Header().Del("Transfer-Encoding")
s.rw.WriteHeader(s.status)
s.rw.Write(data)
return
} }
} }
s.rw.WriteHeader(status) return "", nil
}
func (s *errorPageResponseWriter) writeJSON() (string, []byte) {
message, err := json.Marshal(map[string]interface{}{"error": http.StatusText(s.status), "status": s.status})
if err != nil {
return "", nil
}
return "application/json; charset=utf-8", append(message, "\n"...)
}
func (s *errorPageResponseWriter) writeText() (string, []byte) {
return "text/plain; charset=utf-8", []byte(http.StatusText(s.status) + "\n")
} }
func (s *errorPageResponseWriter) flush() { func (s *errorPageResponseWriter) flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
func (st *Static) ErrorPagesUnless(disabled bool, handler http.Handler) http.Handler { func (st *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler {
if disabled { if disabled {
return handler return handler
} }
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := errorPageResponseWriter{ rw := errorPageResponseWriter{
rw: w, rw: w,
path: st.DocumentRoot, path: st.DocumentRoot,
format: format,
} }
defer rw.flush() defer rw.flush()
handler.ServeHTTP(&rw, r) handler.ServeHTTP(&rw, r)
......
...@@ -33,11 +33,12 @@ func TestIfErrorPageIsPresented(t *testing.T) { ...@@ -33,11 +33,12 @@ func TestIfErrorPageIsPresented(t *testing.T) {
require.Equal(t, len(upstreamBody), n, "bytes written") require.Equal(t, len(upstreamBody), n, "bytes written")
}) })
st := &Static{dir} st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
testhelper.AssertResponseCode(t, w, 404) testhelper.AssertResponseCode(t, w, 404)
testhelper.AssertResponseBody(t, w, errorPage) testhelper.AssertResponseBody(t, w, errorPage)
testhelper.AssertResponseHeader(t, w, "Content-Type", "text/html; charset=utf-8")
} }
func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
...@@ -54,7 +55,7 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { ...@@ -54,7 +55,7 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
fmt.Fprint(w, errorResponse) fmt.Fprint(w, errorResponse)
}) })
st := &Static{dir} st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
testhelper.AssertResponseCode(t, w, 404) testhelper.AssertResponseCode(t, w, 404)
...@@ -78,7 +79,7 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { ...@@ -78,7 +79,7 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
fmt.Fprint(w, serverError) fmt.Fprint(w, serverError)
}) })
st := &Static{dir} st := &Static{dir}
st.ErrorPagesUnless(true, h).ServeHTTP(w, nil) st.ErrorPagesUnless(true, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
testhelper.AssertResponseCode(t, w, 500) testhelper.AssertResponseCode(t, w, 500)
testhelper.AssertResponseBody(t, w, serverError) testhelper.AssertResponseBody(t, w, serverError)
...@@ -102,7 +103,7 @@ func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) { ...@@ -102,7 +103,7 @@ func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) {
fmt.Fprint(w, serverError) fmt.Fprint(w, serverError)
}) })
st := &Static{dir} st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
testhelper.AssertResponseCode(t, w, 500) testhelper.AssertResponseCode(t, w, 500)
testhelper.AssertResponseBody(t, w, serverError) testhelper.AssertResponseBody(t, w, serverError)
...@@ -137,7 +138,7 @@ func TestErrorPageInterceptedByContentType(t *testing.T) { ...@@ -137,7 +138,7 @@ func TestErrorPageInterceptedByContentType(t *testing.T) {
fmt.Fprint(w, serverError) fmt.Fprint(w, serverError)
}) })
st := &Static{dir} st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
testhelper.AssertResponseCode(t, w, 500) testhelper.AssertResponseCode(t, w, 500)
...@@ -148,3 +149,43 @@ func TestErrorPageInterceptedByContentType(t *testing.T) { ...@@ -148,3 +149,43 @@ func TestErrorPageInterceptedByContentType(t *testing.T) {
} }
} }
} }
func TestIfErrorPageIsPresentedJSON(t *testing.T) {
errorPage := "{\"error\":\"Not Found\",\"status\":404}\n"
w := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404)
upstreamBody := "This string is ignored"
n, err := fmt.Fprint(w, upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
})
st := &Static{""}
st.ErrorPagesUnless(false, ErrorFormatJSON, h).ServeHTTP(w, nil)
w.Flush()
testhelper.AssertResponseCode(t, w, 404)
testhelper.AssertResponseBody(t, w, errorPage)
testhelper.AssertResponseHeader(t, w, "Content-Type", "application/json; charset=utf-8")
}
func TestIfErrorPageIsPresentedText(t *testing.T) {
errorPage := "Not Found\n"
w := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404)
upstreamBody := "This string is ignored"
n, err := fmt.Fprint(w, upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
})
st := &Static{""}
st.ErrorPagesUnless(false, ErrorFormatText, h).ServeHTTP(w, nil)
w.Flush()
testhelper.AssertResponseCode(t, w, 404)
testhelper.AssertResponseBody(t, w, errorPage)
testhelper.AssertResponseHeader(t, w, "Content-Type", "text/plain; charset=utf-8")
}
...@@ -174,8 +174,10 @@ func (u *upstream) configureRoutes() { ...@@ -174,8 +174,10 @@ func (u *upstream) configureRoutes() {
defaultUpstream := static.ServeExisting( defaultUpstream := static.ServeExisting(
u.URLPrefix, u.URLPrefix,
staticpages.CacheDisabled, staticpages.CacheDisabled,
static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, uploadAccelerateProxy)), static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatHTML, uploadAccelerateProxy)),
) )
probeUpstream := static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatJSON, proxy)
healthUpstream := static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatText, proxy)
u.Routes = []routeEntry{ u.Routes = []routeEntry{
// Git Clone // Git Clone
...@@ -235,7 +237,13 @@ func (u *upstream) configureRoutes() { ...@@ -235,7 +237,13 @@ func (u *upstream) configureRoutes() {
// To prevent anybody who knows/guesses the URL of a user-uploaded file // To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass // from downloading it we make sure requests to /uploads/ do _not_ pass
// through static.ServeExisting. // through static.ServeExisting.
route("", `^/uploads/`, static.ErrorPagesUnless(u.DevelopmentMode, proxy)), route("", `^/uploads/`, static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatHTML, proxy)),
// health checks don't intercept errors and go straight to rails
// TODO: We should probably not return a HTML deploy page?
// https://gitlab.com/gitlab-org/gitlab-workhorse/issues/230
route("", "^/-/(readiness|liveness)$", static.DeployPage(probeUpstream)),
route("", "^/-/health$", static.DeployPage(healthUpstream)),
// This route lets us filter out health checks from our metrics. // This route lets us filter out health checks from our metrics.
route("", "^/-/", defaultUpstream), route("", "^/-/", defaultUpstream),
......
...@@ -615,3 +615,67 @@ func assertNginxResponseBuffering(t *testing.T, expected string, resp *http.Resp ...@@ -615,3 +615,67 @@ func assertNginxResponseBuffering(t *testing.T, expected string, resp *http.Resp
actual := resp.Header.Get(helper.NginxResponseBufferHeader) actual := resp.Header.Get(helper.NginxResponseBufferHeader)
assert.Equal(t, expected, actual, msgAndArgs...) assert.Equal(t, expected, actual, msgAndArgs...)
} }
// TestHealthChecksNoStaticHTML verifies that health endpoints pass errors through and don't return the static html error pages
func TestHealthChecksNoStaticHTML(t *testing.T) {
apiResponse := "API RESPONSE"
errorPageBody := `<html>
<body>
This is a static error page for code 503
</body>
</html>
`
require.NoError(t, setupStaticFile("503.html", errorPageBody))
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("X-Gitlab-Custom-Error", "1")
w.WriteHeader(503)
_, err := w.Write([]byte(apiResponse))
require.NoError(t, err)
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
for _, resource := range []string{
"/-/health",
"/-/readiness",
"/-/liveness",
} {
t.Run(resource, func(t *testing.T) {
resp, body := httpGet(t, ws.URL+resource, nil)
assert.Equal(t, 503, resp.StatusCode, "status code")
assert.Equal(t, apiResponse, body, "response body")
assertNginxResponseBuffering(t, "", resp, "nginx response buffering")
})
}
}
// TestHealthChecksUnreachable verifies that health endpoints return the correct content-type when the upstream is down
func TestHealthChecksUnreachable(t *testing.T) {
ws := startWorkhorseServer("http://127.0.0.1:99999") // This url should point to nothing for the test to be accurate (equivalent to upstream being down)
defer ws.Close()
testCases := []struct {
path string
content string
responseType string
}{
{path: "/-/health", content: "Bad Gateway\n", responseType: "text/plain; charset=utf-8"},
{path: "/-/readiness", content: "{\"error\":\"Bad Gateway\",\"status\":502}\n", responseType: "application/json; charset=utf-8"},
{path: "/-/liveness", content: "{\"error\":\"Bad Gateway\",\"status\":502}\n", responseType: "application/json; charset=utf-8"},
}
for _, tc := range testCases {
t.Run(tc.path, func(t *testing.T) {
resp, body := httpGet(t, ws.URL+tc.path, nil)
assert.Equal(t, 502, resp.StatusCode, "status code")
assert.Equal(t, tc.responseType, resp.Header.Get("Content-Type"), "content-type")
assert.Equal(t, tc.content, body, "response body")
assertNginxResponseBuffering(t, "", resp, "nginx response buffering")
})
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment