Commit 3b33cf3d authored by Jacob Vosmaer's avatar Jacob Vosmaer Committed by Alessio Caiazza

Use testify/require in top level tests

parent 9a4a1d48
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api" "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
...@@ -31,9 +32,7 @@ func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, ur ...@@ -31,9 +32,7 @@ func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, ur
// Create http request // Create http request
httpRequest, err := http.NewRequest("GET", "/address", nil) httpRequest, err := http.NewRequest("GET", "/address", nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
parsedURL := helper.URLMustParse(ts.URL) parsedURL := helper.URLMustParse(ts.URL)
testhelper.ConfigureSecret() testhelper.ConfigureSecret()
a := api.NewAPI(parsedURL, "123", roundtripper.NewTestBackendRoundTripper(parsedURL)) a := api.NewAPI(parsedURL, "123", roundtripper.NewTestBackendRoundTripper(parsedURL))
...@@ -70,9 +69,8 @@ func TestPreAuthorizeJsonFailure(t *testing.T) { ...@@ -70,9 +69,8 @@ func TestPreAuthorizeJsonFailure(t *testing.T) {
func TestPreAuthorizeContentTypeFailure(t *testing.T) { func TestPreAuthorizeContentTypeFailure(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := w.Write([]byte(`{"hello":"world"}`)); err != nil { _, err := w.Write([]byte(`{"hello":"world"}`))
t.Fatalf("write auth response: %v", err) require.NoError(t, err, "write auth response")
}
})) }))
defer ts.Close() defer ts.Close()
...@@ -110,27 +108,16 @@ func TestPreAuthorizeJWT(t *testing.T) { ...@@ -110,27 +108,16 @@ func TestPreAuthorizeJWT(t *testing.T) {
return secretBytes, nil return secretBytes, nil
}) })
if err != nil { require.NoError(t, err, "decode token")
t.Fatalf("decode token: %v", err)
}
claims, ok := token.Claims.(jwt.MapClaims) claims, ok := token.Claims.(jwt.MapClaims)
if !ok { require.True(t, ok, "claims cast")
t.Fatal("claims cast failed") require.True(t, token.Valid, "JWT token valid")
} require.Equal(t, "gitlab-workhorse", claims["iss"], "JWT token issuer")
if !token.Valid {
t.Fatal("JWT token invalid")
}
if claims["iss"] != "gitlab-workhorse" {
t.Fatalf("execpted issuer gitlab-workhorse, got %q", claims["iss"])
}
w.Header().Set("Content-Type", api.ResponseContentType) w.Header().Set("Content-Type", api.ResponseContentType)
if _, err := w.Write([]byte(`{"hello":"world"}`)); err != nil { _, err = w.Write([]byte(`{"hello":"world"}`))
t.Fatalf("write auth response: %v", err) require.NoError(t, err, "write auth response")
}
})) }))
defer ts.Close() defer ts.Close()
......
...@@ -2,9 +2,11 @@ package main ...@@ -2,9 +2,11 @@ package main
import ( import (
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func TestParseAuthBackend(t *testing.T) { func TestParseAuthBackendFailure(t *testing.T) {
failures := []string{ failures := []string{
"", "",
"ftp://localhost", "ftp://localhost",
...@@ -12,11 +14,14 @@ func TestParseAuthBackend(t *testing.T) { ...@@ -12,11 +14,14 @@ func TestParseAuthBackend(t *testing.T) {
} }
for _, example := range failures { for _, example := range failures {
if _, err := parseAuthBackend(example); err == nil { t.Run(example, func(t *testing.T) {
t.Errorf("error expected for %q", example) _, err := parseAuthBackend(example)
} require.Error(t, err)
})
} }
}
func TestParseAuthBackend(t *testing.T) {
successes := []struct{ input, host, scheme string }{ successes := []struct{ input, host, scheme string }{
{"http://localhost:8080", "localhost:8080", "http"}, {"http://localhost:8080", "localhost:8080", "http"},
{"localhost:3000", "localhost:3000", "http"}, {"localhost:3000", "localhost:3000", "http"},
...@@ -25,18 +30,12 @@ func TestParseAuthBackend(t *testing.T) { ...@@ -25,18 +30,12 @@ func TestParseAuthBackend(t *testing.T) {
} }
for _, example := range successes { for _, example := range successes {
result, err := parseAuthBackend(example.input) t.Run(example.input, func(t *testing.T) {
if err != nil { result, err := parseAuthBackend(example.input)
t.Errorf("parse %q: %v", example.input, err) require.NoError(t, err)
break
}
if result.Host != example.host {
t.Errorf("example %q: expected %q, got %q", example.input, example.host, result.Host)
}
if result.Scheme != example.scheme { require.Equal(t, example.host, result.Host, "host")
t.Errorf("example %q: expected %q, got %q", example.input, example.scheme, result.Scheme) require.Equal(t, example.scheme, result.Scheme, "scheme")
} })
} }
} }
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/labkit/log" "gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api" "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
...@@ -45,9 +46,7 @@ func TestChannelHappyPath(t *testing.T) { ...@@ -45,9 +46,7 @@ func TestChannelHappyPath(t *testing.T) {
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
server := (<-serverConns).conn server := (<-serverConns).conn
defer server.Close() defer server.Close()
...@@ -55,14 +54,10 @@ func TestChannelHappyPath(t *testing.T) { ...@@ -55,14 +54,10 @@ func TestChannelHappyPath(t *testing.T) {
message := "test message" message := "test message"
// channel.k8s.io: server writes to channel 1, STDOUT // channel.k8s.io: server writes to channel 1, STDOUT
if err := say(server, "\x01"+message); err != nil { require.NoError(t, say(server, "\x01"+message))
t.Fatal(err)
}
requireReadMessage(t, client, websocket.BinaryMessage, message) requireReadMessage(t, client, websocket.BinaryMessage, message)
if err := say(client, message); err != nil { require.NoError(t, say(client, message))
t.Fatal(err)
}
// channel.k8s.io: client writes get put on channel 0, STDIN // channel.k8s.io: client writes get put on channel 0, STDIN
requireReadMessage(t, server, websocket.BinaryMessage, "\x00"+message) requireReadMessage(t, server, websocket.BinaryMessage, "\x00"+message)
...@@ -78,14 +73,8 @@ func TestChannelBadTLS(t *testing.T) { ...@@ -78,14 +73,8 @@ func TestChannelBadTLS(t *testing.T) {
_, clientURL, close := wireupChannel(envTerminalPath, badCA, "channel.k8s.io") _, clientURL, close := wireupChannel(envTerminalPath, badCA, "channel.k8s.io")
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") _, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != websocket.ErrBadHandshake { require.Equal(t, websocket.ErrBadHandshake, err, "unexpected error %v", err)
t.Fatalf("Expected connection to fail ErrBadHandshake, got: %v", err)
}
if err == nil {
log.Info("TLS negotiation should have failed!")
defer client.Close()
}
} }
func TestChannelSessionTimeout(t *testing.T) { func TestChannelSessionTimeout(t *testing.T) {
...@@ -93,9 +82,7 @@ func TestChannelSessionTimeout(t *testing.T) { ...@@ -93,9 +82,7 @@ func TestChannelSessionTimeout(t *testing.T) {
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
sc := <-serverConns sc := <-serverConns
defer sc.conn.Close() defer sc.conn.Close()
...@@ -103,9 +90,7 @@ func TestChannelSessionTimeout(t *testing.T) { ...@@ -103,9 +90,7 @@ func TestChannelSessionTimeout(t *testing.T) {
client.SetReadDeadline(time.Now().Add(time.Duration(2) * time.Second)) client.SetReadDeadline(time.Now().Add(time.Duration(2) * time.Second))
_, _, err = client.ReadMessage() _, _, err = client.ReadMessage()
if !websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { require.True(t, websocket.IsCloseError(err, websocket.CloseAbnormalClosure), "Client connection was not closed, got %v", err)
t.Fatalf("Client connection was not closed, got %v", err)
}
} }
func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) { func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
...@@ -115,16 +100,12 @@ func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) { ...@@ -115,16 +100,12 @@ func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer client.Close() defer client.Close()
sc := <-serverConns sc := <-serverConns
defer sc.conn.Close() defer sc.conn.Close()
if sc.req.Header.Get("Random-Header") != "Value" { require.Equal(t, "Value", sc.req.Header.Get("Random-Header"), "Header specified by upstream not sent to remote")
t.Fatal("Header specified by upstream not sent to remote")
}
} }
func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) { func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
...@@ -134,21 +115,16 @@ func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) { ...@@ -134,21 +115,16 @@ func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
hdr := make(http.Header) hdr := make(http.Header)
hdr.Set("X-Forwarded-For", "127.0.0.2") hdr.Set("X-Forwarded-For", "127.0.0.2")
client, _, err := dialWebsocket(clientURL, hdr, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, hdr, "terminal.gitlab.com")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer client.Close() defer client.Close()
clientIP, _, err := net.SplitHostPort(client.LocalAddr().String()) clientIP, _, err := net.SplitHostPort(client.LocalAddr().String())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
sc := <-serverConns sc := <-serverConns
defer sc.conn.Close() defer sc.conn.Close()
if xff := sc.req.Header.Get("X-Forwarded-For"); xff != "127.0.0.2, "+clientIP { require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote")
t.Fatalf("X-Forwarded-For from client not sent to remote: %+v", xff)
}
} }
func wireupChannel(channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) { func wireupChannel(channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
...@@ -262,15 +238,8 @@ func say(conn *websocket.Conn, message string) error { ...@@ -262,15 +238,8 @@ func say(conn *websocket.Conn, message string) error {
func requireReadMessage(t *testing.T, conn *websocket.Conn, expectedMessageType int, expectedData string) { func requireReadMessage(t *testing.T, conn *websocket.Conn, expectedMessageType int, expectedData string) {
messageType, data, err := conn.ReadMessage() messageType, data, err := conn.ReadMessage()
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
if messageType != expectedMessageType { require.Equal(t, expectedMessageType, messageType, "message type")
t.Fatalf("Expected message, %d, got %d", expectedMessageType, messageType) require.Equal(t, expectedData, string(data), "message data")
}
if string(data) != expectedData {
t.Fatalf("Message was mangled in transit. Expected %q, got %q", expectedData, string(data))
}
} }
...@@ -167,6 +167,11 @@ func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) { ...@@ -167,6 +167,11 @@ func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) {
close(done) close(done)
}() }()
waitDone(t, done)
}
func waitDone(t *testing.T, done chan struct{}) {
t.Helper()
select { select {
case <-done: case <-done:
return return
...@@ -252,12 +257,7 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) { ...@@ -252,12 +257,7 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) {
close(done) close(done)
}() }()
select { waitDone(t, done)
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
} }
// ReaderFunc is an adapter to turn a conforming function into an io.Reader. // ReaderFunc is an adapter to turn a conforming function into an io.Reader.
...@@ -319,9 +319,7 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) { ...@@ -319,9 +319,7 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) {
m.Lock() m.Lock()
requestFinished := requestReadFinished requestFinished := requestReadFinished
m.Unlock() m.Unlock()
if !requestFinished { require.True(t, requestFinished, "response written before request was fully read")
t.Fatalf("response written before request was fully read")
}
body := string(testhelper.ReadAll(t, resp.Body)) body := string(testhelper.ReadAll(t, resp.Body))
bodySplit := strings.SplitN(body, "\000", 2) bodySplit := strings.SplitN(body, "\000", 2)
...@@ -376,12 +374,7 @@ func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) { ...@@ -376,12 +374,7 @@ func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) {
close(done) close(done)
}() }()
select { waitDone(t, done)
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
} }
func TestGetDiffProxiedToGitalySuccessfully(t *testing.T) { func TestGetDiffProxiedToGitalySuccessfully(t *testing.T) {
...@@ -447,12 +440,7 @@ func TestGetBlobProxiedToGitalyInterruptedStream(t *testing.T) { ...@@ -447,12 +440,7 @@ func TestGetBlobProxiedToGitalyInterruptedStream(t *testing.T) {
close(done) close(done)
}() }()
select { waitDone(t, done)
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
} }
func TestGetArchiveProxiedToGitalySuccessfully(t *testing.T) { func TestGetArchiveProxiedToGitalySuccessfully(t *testing.T) {
...@@ -521,12 +509,7 @@ func TestGetArchiveProxiedToGitalyInterruptedStream(t *testing.T) { ...@@ -521,12 +509,7 @@ func TestGetArchiveProxiedToGitalyInterruptedStream(t *testing.T) {
close(done) close(done)
}() }()
select { waitDone(t, done)
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
} }
func TestGetDiffProxiedToGitalyInterruptedStream(t *testing.T) { func TestGetDiffProxiedToGitalyInterruptedStream(t *testing.T) {
...@@ -553,12 +536,7 @@ func TestGetDiffProxiedToGitalyInterruptedStream(t *testing.T) { ...@@ -553,12 +536,7 @@ func TestGetDiffProxiedToGitalyInterruptedStream(t *testing.T) {
close(done) close(done)
}() }()
select { waitDone(t, done)
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
} }
func TestGetPatchProxiedToGitalyInterruptedStream(t *testing.T) { func TestGetPatchProxiedToGitalyInterruptedStream(t *testing.T) {
...@@ -585,12 +563,7 @@ func TestGetPatchProxiedToGitalyInterruptedStream(t *testing.T) { ...@@ -585,12 +563,7 @@ func TestGetPatchProxiedToGitalyInterruptedStream(t *testing.T) {
close(done) close(done)
}() }()
select { waitDone(t, done)
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
} }
func TestGetSnapshotProxiedToGitalySuccessfully(t *testing.T) { func TestGetSnapshotProxiedToGitalySuccessfully(t *testing.T) {
...@@ -634,12 +607,7 @@ func TestGetSnapshotProxiedToGitalyInterruptedStream(t *testing.T) { ...@@ -634,12 +607,7 @@ func TestGetSnapshotProxiedToGitalyInterruptedStream(t *testing.T) {
close(done) close(done)
}() }()
select { waitDone(t, done)
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
} }
func buildGetSnapshotParams(gitalyAddress string, repo *gitalypb.Repository) string { func buildGetSnapshotParams(gitalyAddress string, repo *gitalypb.Repository) string {
......
...@@ -3,12 +3,11 @@ package main ...@@ -3,12 +3,11 @@ package main
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"regexp" "regexp"
"strings"
"testing" "testing"
"time" "time"
...@@ -33,27 +32,20 @@ func newProxy(url string, rt http.RoundTripper) *proxy.Proxy { ...@@ -33,27 +32,20 @@ func newProxy(url string, rt http.RoundTripper) *proxy.Proxy {
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { require.Equal(t, "POST", r.Method, "method")
t.Fatal("Expected POST request") require.Equal(t, "test", r.Header.Get("Custom-Header"), "custom header")
} require.Equal(t, testVersion, r.Header.Get("Gitlab-Workhorse"), "version header")
if r.Header.Get("Custom-Header") != "test" { require.Regexp(
t.Fatal("Missing custom header") t,
} regexp.MustCompile(`\A1`),
r.Header.Get("Gitlab-Workhorse-Proxy-Start"),
"expect Gitlab-Workhorse-Proxy-Start to start with 1",
)
if h := r.Header.Get("Gitlab-Workhorse"); h != testVersion { body, err := ioutil.ReadAll(r.Body)
t.Fatalf("Missing GitLab-Workhorse header: want %q, got %q", testVersion, h) require.NoError(t, err, "read body")
} require.Equal(t, "REQUEST", string(body), "body contents")
if h := r.Header.Get("Gitlab-Workhorse-Proxy-Start"); !strings.HasPrefix(h, "1") {
t.Fatalf("Expect Gitlab-Workhorse-Proxy-Start to start with 1, got %q", h)
}
var body bytes.Buffer
io.Copy(&body, r.Body)
if body.String() != "REQUEST" {
t.Fatal("Expected REQUEST in request body")
}
w.Header().Set("Custom-Response-Header", "test") w.Header().Set("Custom-Response-Header", "test")
w.WriteHeader(202) w.WriteHeader(202)
...@@ -61,9 +53,7 @@ func TestProxyRequest(t *testing.T) { ...@@ -61,9 +53,7 @@ func TestProxyRequest(t *testing.T) {
}) })
httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", bytes.NewBufferString("REQUEST")) httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", bytes.NewBufferString("REQUEST"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
w := httptest.NewRecorder() w := httptest.NewRecorder()
...@@ -71,16 +61,12 @@ func TestProxyRequest(t *testing.T) { ...@@ -71,16 +61,12 @@ func TestProxyRequest(t *testing.T) {
require.Equal(t, 202, w.Code) require.Equal(t, 202, w.Code)
testhelper.RequireResponseBody(t, w, "RESPONSE") testhelper.RequireResponseBody(t, w, "RESPONSE")
if w.Header().Get("Custom-Response-Header") != "test" { require.Equal(t, "test", w.Header().Get("Custom-Response-Header"), "custom response header")
t.Fatal("Expected custom response header")
}
} }
func TestProxyError(t *testing.T) { func TestProxyError(t *testing.T) {
httpRequest, err := http.NewRequest("POST", "/url/path", bytes.NewBufferString("REQUEST")) httpRequest, err := http.NewRequest("POST", "/url/path", bytes.NewBufferString("REQUEST"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
w := httptest.NewRecorder() w := httptest.NewRecorder()
...@@ -95,9 +81,7 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -95,9 +81,7 @@ func TestProxyReadTimeout(t *testing.T) {
}) })
httpRequest, err := http.NewRequest("POST", "http://localhost/url/path", nil) httpRequest, err := http.NewRequest("POST", "http://localhost/url/path", nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
rt := badgateway.NewRoundTripper(false, &http.Transport{ rt := badgateway.NewRoundTripper(false, &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
...@@ -124,9 +108,7 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -124,9 +108,7 @@ func TestProxyHandlerTimeout(t *testing.T) {
) )
httpRequest, err := http.NewRequest("POST", "http://localhost/url/path", nil) httpRequest, err := http.NewRequest("POST", "http://localhost/url/path", nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
newProxy(ts.URL, nil).ServeHTTP(w, httpRequest) newProxy(ts.URL, nil).ServeHTTP(w, httpRequest)
......
...@@ -73,23 +73,18 @@ func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest. ...@@ -73,23 +73,18 @@ func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest.
expectSignedRequest(t, r) expectSignedRequest(t, r)
w.Header().Set("Content-Type", api.ResponseContentType) w.Header().Set("Content-Type", api.ResponseContentType)
if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil { _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir)
t.Fatal(err) require.NoError(t, err)
}
return return
} }
err := r.ParseMultipartForm(100000) require.NoError(t, r.ParseMultipartForm(100000))
if err != nil {
t.Fatal(err) const nValues = 10 // file name, path, remote_url, remote_id, size, md5, sha1, sha256, sha512, gitlab-workhorse-upload for just the upload (no metadata because we are not POSTing a valid zip file)
} require.Len(t, r.MultipartForm.Value, nValues)
nValues := 10 // file name, path, remote_url, remote_id, size, md5, sha1, sha256, sha512, gitlab-workhorse-upload for just the upload (no metadata because we are not POSTing a valid zip file)
if len(r.MultipartForm.Value) != nValues { require.Empty(t, r.MultipartForm.File, "multipart form files")
t.Errorf("Expected to receive exactly %d values", nValues)
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
if extraTests != nil { if extraTests != nil {
extraTests(r) extraTests(r)
} }
...@@ -202,43 +197,35 @@ func TestBlockingRewrittenFieldsHeader(t *testing.T) { ...@@ -202,43 +197,35 @@ func TestBlockingRewrittenFieldsHeader(t *testing.T) {
{"no multipart", "text/plain", nil, false}, {"no multipart", "text/plain", nil, false},
} }
if b, c, err := multipartBodyWithFile(); err == nil { var err error
testCases[0].contentType = c testCases[0].body, testCases[0].contentType, err = multipartBodyWithFile()
testCases[0].body = b require.NoError(t, err)
} else {
t.Fatal(err)
}
for _, tc := range testCases { for _, tc := range testCases {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
h := upload.RewrittenFieldsHeader key := upload.RewrittenFieldsHeader
if _, ok := r.Header[h]; ok != tc.present { if tc.present {
t.Errorf("Expectation of presence (%v) violated", tc.present) require.Contains(t, r.Header, key)
} } else {
if r.Header.Get(h) == canary { require.NotContains(t, r.Header, key)
t.Errorf("Found canary %q in header %q", canary, h)
} }
require.NotEqual(t, canary, r.Header.Get(key), "Found canary %q in header %q", canary, key)
}) })
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
req, err := http.NewRequest("POST", ws.URL+"/something", tc.body) req, err := http.NewRequest("POST", ws.URL+"/something", tc.body)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
req.Header.Set("Content-Type", tc.contentType) req.Header.Set("Content-Type", tc.contentType)
req.Header.Set(upload.RewrittenFieldsHeader, canary) req.Header.Set(upload.RewrittenFieldsHeader, canary)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { require.NoError(t, err)
t.Error(err)
}
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("%s: expected HTTP 200, got %d", tc.desc, resp.StatusCode)
}
require.Equal(t, 200, resp.StatusCode, "status code")
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment