Commit e1175065 authored by Jacob Vosmaer's avatar Jacob Vosmaer Committed by Nick Thomas

Update CI matrix to go1.10 + go1.11 and fix ResponseWriter bugs

parent 98d76e5c
image: golang:1.9 image: golang:1.10
verify: verify:
image: golang:1.10
script: script:
- make verify - make verify
...@@ -18,12 +17,11 @@ verify: ...@@ -18,12 +17,11 @@ verify:
- go version - go version
- make test - make test
test using go 1.9: test using go 1.10:
image: golang:1.9
<<: *test_definition <<: *test_definition
test using go 1.10: test using go 1.11:
image: golang:1.10 image: golang:1.11
<<: *test_definition <<: *test_definition
test:release: test:release:
......
...@@ -12,7 +12,7 @@ import ( ...@@ -12,7 +12,7 @@ import (
func Block(h http.Handler) http.Handler { func Block(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := &blocker{rw: w, r: r} rw := &blocker{rw: w, r: r}
defer rw.Flush() defer rw.flush()
h.ServeHTTP(rw, r) h.ServeHTTP(rw, r)
}) })
} }
...@@ -33,7 +33,7 @@ func (b *blocker) Write(data []byte) (int, error) { ...@@ -33,7 +33,7 @@ func (b *blocker) Write(data []byte) (int, error) {
b.WriteHeader(http.StatusOK) b.WriteHeader(http.StatusOK)
} }
if b.hijacked { if b.hijacked {
return 0, nil return len(data), nil
} }
return b.rw.Write(data) return b.rw.Write(data)
...@@ -56,6 +56,6 @@ func (b *blocker) WriteHeader(status int) { ...@@ -56,6 +56,6 @@ func (b *blocker) WriteHeader(status int) {
b.rw.WriteHeader(b.status) b.rw.WriteHeader(b.status)
} }
func (b *blocker) Flush() { func (b *blocker) flush() {
b.WriteHeader(http.StatusOK) b.WriteHeader(http.StatusOK)
} }
package api
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestBlocker(t *testing.T) {
upstreamResponse := "hello world"
testCases := []struct {
desc string
contentType string
out string
}{
{
desc: "blocked",
contentType: ResponseContentType,
out: "Internal server error\n",
},
{
desc: "pass",
contentType: "text/plain",
out: upstreamResponse,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
r, err := http.NewRequest("GET", "/foo", nil)
require.NoError(t, err)
rw := httptest.NewRecorder()
bl := &blocker{rw: rw, r: r}
bl.Header().Set("Content-Type", tc.contentType)
upstreamBody := []byte(upstreamResponse)
n, err := bl.Write(upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
rw.Flush()
body := rw.Result().Body
data, err := ioutil.ReadAll(body)
require.NoError(t, err)
require.NoError(t, body.Close())
require.Equal(t, tc.out, string(data))
})
}
}
...@@ -24,12 +24,12 @@ func (c *countingResponseWriter) Header() http.Header { ...@@ -24,12 +24,12 @@ func (c *countingResponseWriter) Header() http.Header {
return c.rw.Header() return c.rw.Header()
} }
func (c *countingResponseWriter) Write(data []byte) (n int, err error) { func (c *countingResponseWriter) Write(data []byte) (int, error) {
if c.status == 0 { if c.status == 0 {
c.WriteHeader(http.StatusOK) c.WriteHeader(http.StatusOK)
} }
n, err = c.rw.Write(data) n, err := c.rw.Write(data)
c.count += int64(n) c.count += int64(n)
return n, err return n, err
} }
......
...@@ -62,11 +62,11 @@ func (l *statsCollectingResponseWriter) Header() http.Header { ...@@ -62,11 +62,11 @@ func (l *statsCollectingResponseWriter) Header() http.Header {
return l.rw.Header() return l.rw.Header()
} }
func (l *statsCollectingResponseWriter) Write(data []byte) (n int, err error) { func (l *statsCollectingResponseWriter) Write(data []byte) (int, error) {
if !l.wroteHeader { if !l.wroteHeader {
l.WriteHeader(http.StatusOK) l.WriteHeader(http.StatusOK)
} }
n, err = l.rw.Write(data) n, err := l.rw.Write(data)
l.written += int64(n) l.written += int64(n)
return n, err return n, err
......
...@@ -45,7 +45,7 @@ func SendData(h http.Handler, injecters ...Injecter) http.Handler { ...@@ -45,7 +45,7 @@ func SendData(h http.Handler, injecters ...Injecter) http.Handler {
req: r, req: r,
injecters: injecters, injecters: injecters,
} }
defer s.Flush() defer s.flush()
h.ServeHTTP(&s, r) h.ServeHTTP(&s, r)
}) })
} }
...@@ -54,12 +54,12 @@ func (s *sendDataResponseWriter) Header() http.Header { ...@@ -54,12 +54,12 @@ func (s *sendDataResponseWriter) Header() http.Header {
return s.rw.Header() return s.rw.Header()
} }
func (s *sendDataResponseWriter) Write(data []byte) (n int, err error) { func (s *sendDataResponseWriter) Write(data []byte) (int, error) {
if s.status == 0 { if s.status == 0 {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
if s.hijacked { if s.hijacked {
return return len(data), nil
} }
return s.rw.Write(data) return s.rw.Write(data)
} }
...@@ -100,6 +100,6 @@ func (s *sendDataResponseWriter) tryInject() bool { ...@@ -100,6 +100,6 @@ func (s *sendDataResponseWriter) tryInject() bool {
return false return false
} }
func (s *sendDataResponseWriter) Flush() { func (s *sendDataResponseWriter) flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
package senddata package senddata
import ( import (
"io"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func TestHeaderDelete(t *testing.T) { func TestHeaderDelete(t *testing.T) {
...@@ -18,3 +23,60 @@ func TestHeaderDelete(t *testing.T) { ...@@ -18,3 +23,60 @@ func TestHeaderDelete(t *testing.T) {
} }
} }
} }
func TestWriter(t *testing.T) {
upstreamResponse := "hello world"
testCases := []struct {
desc string
headerValue string
out string
}{
{
desc: "inject",
headerValue: testInjecterName + ":" + testInjecterName,
out: testInjecterData,
},
{
desc: "pass",
headerValue: "",
out: upstreamResponse,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
recorder := httptest.NewRecorder()
rw := &sendDataResponseWriter{rw: recorder, injecters: []Injecter{&testInjecter{}}}
rw.Header().Set(HeaderKey, tc.headerValue)
n, err := rw.Write([]byte(upstreamResponse))
require.NoError(t, err)
require.Equal(t, len(upstreamResponse), n, "bytes written")
recorder.Flush()
body := recorder.Result().Body
data, err := ioutil.ReadAll(body)
require.NoError(t, err)
require.NoError(t, body.Close())
require.Equal(t, tc.out, string(data))
})
}
}
const (
testInjecterName = "test-injecter"
testInjecterData = "hello this is injected data"
)
type testInjecter struct{}
func (ti *testInjecter) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
io.WriteString(w, testInjecterData)
}
func (ti *testInjecter) Match(s string) bool { return strings.HasPrefix(s, testInjecterName+":") }
func (ti *testInjecter) Name() string { return testInjecterName }
...@@ -58,7 +58,7 @@ func SendFile(h http.Handler) http.Handler { ...@@ -58,7 +58,7 @@ func SendFile(h http.Handler) http.Handler {
} }
// Advertise to upstream (Rails) that we support X-Sendfile // Advertise to upstream (Rails) that we support X-Sendfile
req.Header.Set("X-Sendfile-Type", "X-Sendfile") req.Header.Set("X-Sendfile-Type", "X-Sendfile")
defer s.Flush() defer s.flush()
h.ServeHTTP(s, req) h.ServeHTTP(s, req)
}) })
} }
...@@ -67,12 +67,12 @@ func (s *sendFileResponseWriter) Header() http.Header { ...@@ -67,12 +67,12 @@ func (s *sendFileResponseWriter) Header() http.Header {
return s.rw.Header() return s.rw.Header()
} }
func (s *sendFileResponseWriter) Write(data []byte) (n int, err error) { func (s *sendFileResponseWriter) Write(data []byte) (int, error) {
if s.status == 0 { if s.status == 0 {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
if s.hijacked { if s.hijacked {
return return len(data), nil
} }
return s.rw.Write(data) return s.rw.Write(data)
} }
...@@ -134,6 +134,6 @@ func countSendFileMetrics(size int64, r *http.Request) { ...@@ -134,6 +134,6 @@ func countSendFileMetrics(size int64, r *http.Request) {
sendFileBytes.WithLabelValues(requestType).Add(float64(size)) sendFileBytes.WithLabelValues(requestType).Add(float64(size))
} }
func (s *sendFileResponseWriter) Flush() { func (s *sendFileResponseWriter) flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
package sendfile
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestResponseWriter(t *testing.T) {
upstreamResponse := "hello world"
fixturePath := "testdata/sent-file.txt"
fixtureContent, err := ioutil.ReadFile(fixturePath)
require.NoError(t, err)
testCases := []struct {
desc string
sendfileHeader string
out string
}{
{
desc: "send a file",
sendfileHeader: fixturePath,
out: string(fixtureContent),
},
{
desc: "pass through unaltered",
sendfileHeader: "",
out: upstreamResponse,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
r, err := http.NewRequest("GET", "/foo", nil)
require.NoError(t, err)
rw := httptest.NewRecorder()
sf := &sendFileResponseWriter{rw: rw, req: r}
sf.Header().Set(sendFileResponseHeader, tc.sendfileHeader)
upstreamBody := []byte(upstreamResponse)
n, err := sf.Write(upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
rw.Flush()
body := rw.Result().Body
data, err := ioutil.ReadAll(body)
require.NoError(t, err)
require.NoError(t, body.Close())
require.Equal(t, tc.out, string(data))
})
}
}
This file is sent with X-SendFile
...@@ -36,12 +36,12 @@ func (s *errorPageResponseWriter) Header() http.Header { ...@@ -36,12 +36,12 @@ func (s *errorPageResponseWriter) Header() http.Header {
return s.rw.Header() return s.rw.Header()
} }
func (s *errorPageResponseWriter) Write(data []byte) (n int, err error) { func (s *errorPageResponseWriter) Write(data []byte) (int, error) {
if s.status == 0 { if s.status == 0 {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
if s.hijacked { if s.hijacked {
return 0, nil return len(data), nil
} }
return s.rw.Write(data) return s.rw.Write(data)
} }
...@@ -76,7 +76,7 @@ func (s *errorPageResponseWriter) WriteHeader(status int) { ...@@ -76,7 +76,7 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.rw.WriteHeader(status) s.rw.WriteHeader(status)
} }
func (s *errorPageResponseWriter) Flush() { func (s *errorPageResponseWriter) flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
...@@ -89,7 +89,7 @@ func (st *Static) ErrorPagesUnless(disabled bool, handler http.Handler) http.Han ...@@ -89,7 +89,7 @@ func (st *Static) ErrorPagesUnless(disabled bool, handler http.Handler) http.Han
rw: w, rw: w,
path: st.DocumentRoot, path: st.DocumentRoot,
} }
defer rw.Flush() defer rw.flush()
handler.ServeHTTP(&rw, r) handler.ServeHTTP(&rw, r)
}) })
} }
...@@ -9,6 +9,8 @@ import ( ...@@ -9,6 +9,8 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
) )
...@@ -25,7 +27,10 @@ func TestIfErrorPageIsPresented(t *testing.T) { ...@@ -25,7 +27,10 @@ func TestIfErrorPageIsPresented(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, "Not Found") upstreamBody := "Not Found"
n, err := fmt.Fprint(w, upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
}) })
st := &Static{dir} st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, h).ServeHTTP(w, nil)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment