Commit 214b8c05 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'ak/health-checks-format' into 'master'

Fix health checks routes incorrectly intercepting errors

See merge request gitlab-org/gitlab-workhorse!424
parents 6f5b7b91 8ec0f244
package staticpages
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
......@@ -21,6 +22,14 @@ var (
)
)
type ErrorFormat int
const (
ErrorFormatHTML ErrorFormat = iota
ErrorFormatJSON
ErrorFormatText
)
func init() {
prometheus.MustRegister(staticErrorResponses)
}
......@@ -30,6 +39,7 @@ type errorPageResponseWriter struct {
status int
hijacked bool
path string
format ErrorFormat
}
func (s *errorPageResponseWriter) Header() http.Header {
......@@ -53,41 +63,77 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.status = status
if 400 <= s.status && s.status <= 599 &&
s.rw.Header().Get("X-GitLab-Custom-Error") == "" &&
s.rw.Header().Get("Content-Type") != "application/json" {
if s.status < 400 || s.status > 599 || s.rw.Header().Get("X-GitLab-Custom-Error") != "" {
s.rw.WriteHeader(status)
return
}
var contentType string
var data []byte
switch s.format {
case ErrorFormatText:
contentType, data = s.writeText()
case ErrorFormatJSON:
contentType, data = s.writeJSON()
default:
contentType, data = s.writeHTML()
}
if contentType == "" {
s.rw.WriteHeader(status)
return
}
s.hijacked = true
staticErrorResponses.WithLabelValues(fmt.Sprintf("%d", s.status)).Inc()
helper.SetNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", contentType)
s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
s.rw.Header().Del("Transfer-Encoding")
s.rw.WriteHeader(s.status)
s.rw.Write(data)
}
func (s *errorPageResponseWriter) writeHTML() (string, []byte) {
if s.rw.Header().Get("Content-Type") != "application/json" {
errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status))
// check if custom error page exists, serve this page instead
if data, err := ioutil.ReadFile(errorPageFile); err == nil {
s.hijacked = true
staticErrorResponses.WithLabelValues(fmt.Sprintf("%d", s.status)).Inc()
helper.SetNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", "text/html; charset=utf-8")
s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
s.rw.Header().Del("Transfer-Encoding")
s.rw.WriteHeader(s.status)
s.rw.Write(data)
return
return "text/html; charset=utf-8", data
}
}
s.rw.WriteHeader(status)
return "", nil
}
func (s *errorPageResponseWriter) writeJSON() (string, []byte) {
message, err := json.Marshal(map[string]interface{}{"error": http.StatusText(s.status), "status": s.status})
if err != nil {
return "", nil
}
return "application/json; charset=utf-8", append(message, "\n"...)
}
func (s *errorPageResponseWriter) writeText() (string, []byte) {
return "text/plain; charset=utf-8", []byte(http.StatusText(s.status) + "\n")
}
func (s *errorPageResponseWriter) flush() {
s.WriteHeader(http.StatusOK)
}
func (st *Static) ErrorPagesUnless(disabled bool, handler http.Handler) http.Handler {
func (st *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler {
if disabled {
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := errorPageResponseWriter{
rw: w,
path: st.DocumentRoot,
rw: w,
path: st.DocumentRoot,
format: format,
}
defer rw.flush()
handler.ServeHTTP(&rw, r)
......
......@@ -33,11 +33,12 @@ func TestIfErrorPageIsPresented(t *testing.T) {
require.Equal(t, len(upstreamBody), n, "bytes written")
})
st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil)
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush()
testhelper.AssertResponseCode(t, w, 404)
testhelper.AssertResponseBody(t, w, errorPage)
testhelper.AssertResponseHeader(t, w, "Content-Type", "text/html; charset=utf-8")
}
func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
......@@ -54,7 +55,7 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
fmt.Fprint(w, errorResponse)
})
st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil)
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush()
testhelper.AssertResponseCode(t, w, 404)
......@@ -78,7 +79,7 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
fmt.Fprint(w, serverError)
})
st := &Static{dir}
st.ErrorPagesUnless(true, h).ServeHTTP(w, nil)
st.ErrorPagesUnless(true, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush()
testhelper.AssertResponseCode(t, w, 500)
testhelper.AssertResponseBody(t, w, serverError)
......@@ -102,7 +103,7 @@ func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) {
fmt.Fprint(w, serverError)
})
st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil)
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush()
testhelper.AssertResponseCode(t, w, 500)
testhelper.AssertResponseBody(t, w, serverError)
......@@ -137,7 +138,7 @@ func TestErrorPageInterceptedByContentType(t *testing.T) {
fmt.Fprint(w, serverError)
})
st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil)
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush()
testhelper.AssertResponseCode(t, w, 500)
......@@ -148,3 +149,43 @@ func TestErrorPageInterceptedByContentType(t *testing.T) {
}
}
}
func TestIfErrorPageIsPresentedJSON(t *testing.T) {
errorPage := "{\"error\":\"Not Found\",\"status\":404}\n"
w := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404)
upstreamBody := "This string is ignored"
n, err := fmt.Fprint(w, upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
})
st := &Static{""}
st.ErrorPagesUnless(false, ErrorFormatJSON, h).ServeHTTP(w, nil)
w.Flush()
testhelper.AssertResponseCode(t, w, 404)
testhelper.AssertResponseBody(t, w, errorPage)
testhelper.AssertResponseHeader(t, w, "Content-Type", "application/json; charset=utf-8")
}
func TestIfErrorPageIsPresentedText(t *testing.T) {
errorPage := "Not Found\n"
w := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404)
upstreamBody := "This string is ignored"
n, err := fmt.Fprint(w, upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
})
st := &Static{""}
st.ErrorPagesUnless(false, ErrorFormatText, h).ServeHTTP(w, nil)
w.Flush()
testhelper.AssertResponseCode(t, w, 404)
testhelper.AssertResponseBody(t, w, errorPage)
testhelper.AssertResponseHeader(t, w, "Content-Type", "text/plain; charset=utf-8")
}
......@@ -174,8 +174,10 @@ func (u *upstream) configureRoutes() {
defaultUpstream := static.ServeExisting(
u.URLPrefix,
staticpages.CacheDisabled,
static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, uploadAccelerateProxy)),
static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatHTML, uploadAccelerateProxy)),
)
probeUpstream := static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatJSON, proxy)
healthUpstream := static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatText, proxy)
u.Routes = []routeEntry{
// Git Clone
......@@ -235,7 +237,13 @@ func (u *upstream) configureRoutes() {
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through static.ServeExisting.
route("", `^/uploads/`, static.ErrorPagesUnless(u.DevelopmentMode, proxy)),
route("", `^/uploads/`, static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatHTML, proxy)),
// health checks don't intercept errors and go straight to rails
// TODO: We should probably not return a HTML deploy page?
// https://gitlab.com/gitlab-org/gitlab-workhorse/issues/230
route("", "^/-/(readiness|liveness)$", static.DeployPage(probeUpstream)),
route("", "^/-/health$", static.DeployPage(healthUpstream)),
// This route lets us filter out health checks from our metrics.
route("", "^/-/", defaultUpstream),
......
......@@ -615,3 +615,67 @@ func assertNginxResponseBuffering(t *testing.T, expected string, resp *http.Resp
actual := resp.Header.Get(helper.NginxResponseBufferHeader)
assert.Equal(t, expected, actual, msgAndArgs...)
}
// TestHealthChecksNoStaticHTML verifies that health endpoints pass errors through and don't return the static html error pages
func TestHealthChecksNoStaticHTML(t *testing.T) {
apiResponse := "API RESPONSE"
errorPageBody := `<html>
<body>
This is a static error page for code 503
</body>
</html>
`
require.NoError(t, setupStaticFile("503.html", errorPageBody))
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("X-Gitlab-Custom-Error", "1")
w.WriteHeader(503)
_, err := w.Write([]byte(apiResponse))
require.NoError(t, err)
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
for _, resource := range []string{
"/-/health",
"/-/readiness",
"/-/liveness",
} {
t.Run(resource, func(t *testing.T) {
resp, body := httpGet(t, ws.URL+resource, nil)
assert.Equal(t, 503, resp.StatusCode, "status code")
assert.Equal(t, apiResponse, body, "response body")
assertNginxResponseBuffering(t, "", resp, "nginx response buffering")
})
}
}
// TestHealthChecksUnreachable verifies that health endpoints return the correct content-type when the upstream is down
func TestHealthChecksUnreachable(t *testing.T) {
ws := startWorkhorseServer("http://127.0.0.1:99999") // This url should point to nothing for the test to be accurate (equivalent to upstream being down)
defer ws.Close()
testCases := []struct {
path string
content string
responseType string
}{
{path: "/-/health", content: "Bad Gateway\n", responseType: "text/plain; charset=utf-8"},
{path: "/-/readiness", content: "{\"error\":\"Bad Gateway\",\"status\":502}\n", responseType: "application/json; charset=utf-8"},
{path: "/-/liveness", content: "{\"error\":\"Bad Gateway\",\"status\":502}\n", responseType: "application/json; charset=utf-8"},
}
for _, tc := range testCases {
t.Run(tc.path, func(t *testing.T) {
resp, body := httpGet(t, ws.URL+tc.path, nil)
assert.Equal(t, 502, resp.StatusCode, "status code")
assert.Equal(t, tc.responseType, resp.Header.Get("Content-Type"), "content-type")
assert.Equal(t, tc.content, body, "response body")
assertNginxResponseBuffering(t, "", resp, "nginx response buffering")
})
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment