Commit 2959b8ac authored by Jacob Vosmaer's avatar Jacob Vosmaer Committed by Nick Thomas

Use more require in internal/{upload,upstream}

parent e3246cbe
......@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
......@@ -53,9 +52,7 @@ func (a *testFormProcessor) Name() string {
func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder()
request, err := http.NewRequest("", "", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
apiResponse := &api.Response{}
preparer := &DefaultPreparer{}
opts, _, err := preparer.Prepare(apiResponse)
......@@ -67,15 +64,11 @@ func TestUploadTempPathRequirement(t *testing.T) {
func TestUploadHandlerForwardingRawData(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" {
t.Fatal("Expected PATCH request")
}
require.Equal(t, "PATCH", r.Method, "method")
var body bytes.Buffer
io.Copy(&body, r.Body)
if body.String() != "REQUEST" {
t.Fatal("Expected REQUEST in request body")
}
body, err := ioutil.ReadAll(r.Body)
require.NoError(t, err)
require.Equal(t, "REQUEST", string(body), "request body")
w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE")
......@@ -83,14 +76,10 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
defer ts.Close()
httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST"))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer os.RemoveAll(tempPath)
response := httptest.NewRecorder()
......@@ -104,58 +93,30 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
HandleFileUploads(response, httpRequest, handler, apiResponse, nil, opts)
testhelper.RequireResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" {
t.Fatal("Expected RESPONSE in response body")
}
require.Equal(t, "RESPONSE", response.Body.String(), "response body")
}
func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
var filePath string
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer os.RemoveAll(tempPath)
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
t.Fatal("Expected PUT request")
}
require.Equal(t, "PUT", r.Method, "method")
require.NoError(t, r.ParseMultipartForm(100000))
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
if r.FormValue("token") != "test" {
t.Error("Expected to receive token")
}
if r.FormValue("file.name") != "my.file" {
t.Error("Expected to receive a filename")
}
require.Empty(t, r.MultipartForm.File, "Expected to not receive any files")
require.Equal(t, "test", r.FormValue("token"), "Expected to receive token")
require.Equal(t, "my.file", r.FormValue("file.name"), "Expected to receive a filename")
filePath = r.FormValue("file.path")
if !strings.HasPrefix(filePath, tempPath) {
t.Error("Expected to the file to be in tempPath")
}
if r.FormValue("file.remote_url") != "" {
t.Error("Expected to receive empty remote_url")
}
if r.FormValue("file.remote_id") != "" {
t.Error("Expected to receive empty remote_id")
}
require.True(t, strings.HasPrefix(filePath, tempPath), "Expected to the file to be in tempPath")
if r.FormValue("file.size") != "4" {
t.Error("Expected to receive the file size")
}
require.Empty(t, r.FormValue("file.remote_url"), "Expected to receive empty remote_url")
require.Empty(t, r.FormValue("file.remote_id"), "Expected to receive empty remote_id")
require.Equal(t, "4", r.FormValue("file.size"), "Expected to receive the file size")
hashes := map[string]string{
"md5": "098f6bcd4621d373cade4e832627b4f6",
......@@ -165,14 +126,10 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
}
for algo, hash := range hashes {
if r.FormValue("file."+algo) != hash {
t.Errorf("Expected to receive file %s hash", algo)
}
require.Equal(t, hash, r.FormValue("file."+algo), "file hash %s", algo)
}
if valueCnt := len(r.MultipartForm.Value); valueCnt != 11 {
t.Fatal("Expected to receive exactly 11 values but got", valueCnt)
}
require.Len(t, r.MultipartForm.Value, 11, "multipart form values")
w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE")
......@@ -183,16 +140,12 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
writer := multipart.NewWriter(&buffer)
writer.WriteField("token", "test")
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
fmt.Fprint(file, "test")
writer.Close()
httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
httpRequest = httpRequest.WithContext(ctx)
......@@ -219,9 +172,7 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) {
var filePath string
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer os.RemoveAll(tempPath)
tests := []struct {
......@@ -249,9 +200,7 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
t.Fatal("Expected PUT request")
}
require.Equal(t, "PUT", r.Method, "method")
w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE")
......@@ -261,18 +210,14 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) {
writer := multipart.NewWriter(&buffer)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
fmt.Fprint(file, "test")
writer.WriteField(test.field, "value")
writer.Close()
httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", &buffer)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
httpRequest = httpRequest.WithContext(ctx)
......@@ -296,9 +241,7 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) {
func TestUploadProcessingField(t *testing.T) {
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer os.RemoveAll(tempPath)
var buffer bytes.Buffer
......@@ -308,9 +251,7 @@ func TestUploadProcessingField(t *testing.T) {
writer.Close()
httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder()
......@@ -326,9 +267,7 @@ func TestUploadProcessingField(t *testing.T) {
func TestUploadProcessingFile(t *testing.T) {
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer os.RemoveAll(tempPath)
_, testServer := test.StartObjectStore()
......@@ -362,16 +301,12 @@ func TestUploadProcessingFile(t *testing.T) {
var buffer bytes.Buffer
writer := multipart.NewWriter(&buffer)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
fmt.Fprint(file, "test")
writer.Close()
httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder()
......@@ -392,9 +327,7 @@ func TestInvalidFileNames(t *testing.T) {
testhelper.ConfigureSecret()
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer os.RemoveAll(tempPath)
for _, testCase := range []struct {
......@@ -411,16 +344,12 @@ func TestInvalidFileNames(t *testing.T) {
writer := multipart.NewWriter(buffer)
file, err := writer.CreateFormFile("file", testCase.filename)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
fmt.Fprint(file, "test")
writer.Close()
httpRequest, err := http.NewRequest("POST", "/example", buffer)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder()
......@@ -548,7 +477,5 @@ func waitUntilDeleted(t *testing.T, path string) {
time.Sleep(100 * time.Millisecond)
}
if !os.IsNotExist(err) {
t.Fatal("expected the file to be deleted")
}
require.True(t, os.IsNotExist(err), "expected the file to be deleted")
}
......@@ -6,6 +6,8 @@ import (
"testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
"github.com/stretchr/testify/require"
)
func TestDevelopmentModeEnabled(t *testing.T) {
......@@ -18,9 +20,8 @@ func TestDevelopmentModeEnabled(t *testing.T) {
NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true
})).ServeHTTP(w, r)
if !executed {
t.Error("The handler should get executed")
}
require.True(t, executed, "The handler should get executed")
}
func TestDevelopmentModeDisabled(t *testing.T) {
......@@ -33,8 +34,8 @@ func TestDevelopmentModeDisabled(t *testing.T) {
NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true
})).ServeHTTP(w, r)
if executed {
t.Error("The handler should not get executed")
}
require.False(t, executed, "The handler should not get executed")
testhelper.RequireResponseCode(t, w, 404)
}
......@@ -7,10 +7,11 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
"github.com/stretchr/testify/require"
)
func TestGzipEncoding(t *testing.T) {
......@@ -24,18 +25,12 @@ func TestGzipEncoding(t *testing.T) {
body := ioutil.NopCloser(&b)
req, err := http.NewRequest("POST", "http://address/test", body)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
req.Header.Set("Content-Encoding", "gzip")
contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if _, ok := r.Body.(*gzip.Reader); !ok {
t.Fatal("Expected gzip reader for body, but it's:", reflect.TypeOf(r.Body))
}
if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted")
}
require.IsType(t, &gzip.Reader{}, r.Body, "body type")
require.Empty(t, r.Header.Get("Content-Encoding"), "Content-Encoding should be deleted")
})).ServeHTTP(resp, req)
testhelper.RequireResponseCode(t, resp, 200)
......@@ -48,18 +43,12 @@ func TestNoEncoding(t *testing.T) {
body := ioutil.NopCloser(&b)
req, err := http.NewRequest("POST", "http://address/test", body)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
req.Header.Set("Content-Encoding", "")
contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if r.Body != body {
t.Fatal("Expected the same body")
}
if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted")
}
require.Equal(t, body, r.Body, "Expected the same body")
require.Empty(t, r.Header.Get("Content-Encoding"), "Content-Encoding should be deleted")
})).ServeHTTP(resp, req)
testhelper.RequireResponseCode(t, resp, 200)
......@@ -69,9 +58,7 @@ func TestInvalidEncoding(t *testing.T) {
resp := httptest.NewRecorder()
req, err := http.NewRequest("POST", "http://address/test", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
req.Header.Set("Content-Encoding", "application/unknown")
contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
......
package roundtripper
import (
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestMustParseAddress(t *testing.T) {
......@@ -10,26 +13,27 @@ func TestMustParseAddress(t *testing.T) {
{"[::1]:23", "http", "::1:23"},
{"4.5.6.7", "http", "4.5.6.7:http"},
}
for _, example := range successExamples {
result := mustParseAddress(example.address, example.scheme)
if example.expected != result {
t.Errorf("expected %q, got %q", example.expected, result)
}
for i, example := range successExamples {
t.Run(strconv.Itoa(i), func(t *testing.T) {
require.Equal(t, example.expected, mustParseAddress(example.address, example.scheme))
})
}
}
func TestMustParseAddressPanic(t *testing.T) {
panicExamples := []struct{ address, scheme string }{
{"1.2.3.4", ""},
{"1.2.3.4", "https"},
}
for _, panicExample := range panicExamples {
func() {
for i, panicExample := range panicExamples {
t.Run(strconv.Itoa(i), func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expected panic for %v but none occurred", panicExample)
t.Fatal("expected panic")
}
}()
t.Log(mustParseAddress(panicExample.address, panicExample.scheme))
}()
mustParseAddress(panicExample.address, panicExample.scheme)
})
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment