Commit 0bae36ad authored by Jacob Vosmaer's avatar Jacob Vosmaer

Hard-code backend dialer for TCP too

Workhorse usually connects to Rails over a Unix socket. This makes it
impossible to accidentally follow redirects to another host. This
change applies the same strictness when connecting to Rails over TCP.
parent 7e6fdf11
......@@ -28,7 +28,8 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
if err != nil {
t.Fatal(err)
}
a := api.NewAPI(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper)
parsedURL := helper.URLMustParse(ts.URL)
a := api.NewAPI(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
response := httptest.NewRecorder()
a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest)
......
......@@ -91,8 +91,10 @@ func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *h
}
httpRequest.Header.Set("Content-Type", contentType)
response := httptest.NewRecorder()
apiClient := api.NewAPI(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper)
proxyClient := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper)
parsedURL := helper.URLMustParse(ts.URL)
roundTripper := badgateway.TestRoundTripper(parsedURL)
apiClient := api.NewAPI(parsedURL, "123", roundTripper)
proxyClient := proxy.NewProxy(parsedURL, "123", roundTripper)
UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest)
return response
}
......
......@@ -6,6 +6,7 @@ import (
"io/ioutil"
"net"
"net/http"
"net/url"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
......@@ -23,24 +24,47 @@ var DefaultTransport = &http.Transport{
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
var TestRoundTripper = NewRoundTripper("", 0)
type RoundTripper struct {
Transport *http.Transport
}
func NewRoundTripper(socket string, proxyHeadersTimeout time.Duration) *RoundTripper {
func TestRoundTripper(backend *url.URL) *RoundTripper {
return NewRoundTripper(backend, "", 0)
}
func NewRoundTripper(backend *url.URL, socket string, proxyHeadersTimeout time.Duration) *RoundTripper {
tr := *DefaultTransport
tr.ResponseHeaderTimeout = proxyHeadersTimeout
if socket != "" {
if backend != nil && socket == "" {
address := mustParseAddress(backend.Host, backend.Scheme)
tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("tcp", address)
}
} else if socket != "" {
tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", socket)
}
} else {
panic("backend is nil and socket is empty")
}
return &RoundTripper{Transport: &tr}
}
func mustParseAddress(address, scheme string) string {
if host, port, err := net.SplitHostPort(address); err == nil {
return host + ":" + port
}
address = fmt.Sprintf("%s:%s", address, scheme)
if host, port, err := net.SplitHostPort(address); err == nil {
return host + ":" + port
}
panic("could not parse host/port from addres / scheme")
}
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = t.Transport.RoundTrip(r)
......
......@@ -76,7 +76,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
response := httptest.NewRecorder()
handler := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper)
handler := newProxy(ts.URL)
HandleFileUploads(response, httpRequest, handler, tempPath, nil)
testhelper.AssertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" {
......@@ -150,7 +150,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder()
handler := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper)
handler := newProxy(ts.URL)
HandleFileUploads(response, httpRequest, handler, tempPath, &testFormProcessor{})
testhelper.AssertResponseCode(t, response, 202)
......@@ -210,3 +210,8 @@ func TestUploadProcessingFile(t *testing.T) {
HandleFileUploads(response, httpRequest, nilHandler, tempPath, &testFormProcessor{})
testhelper.AssertResponseCode(t, response, 500)
}
func newProxy(url string) *proxy.Proxy {
parsedURL := helper.URLMustParse(url)
return proxy.NewProxy(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
}
......@@ -37,11 +37,11 @@ func NewUpstream(backend *url.URL, socket string, version string, documentRoot s
Version: version,
DocumentRoot: documentRoot,
DevelopmentMode: developmentMode,
RoundTripper: badgateway.NewRoundTripper(socket, proxyHeadersTimeout),
}
if backend == nil {
up.Backend = DefaultBackend
}
up.RoundTripper = badgateway.NewRoundTripper(up.Backend, socket, proxyHeadersTimeout)
up.configureURLPrefix()
up.configureRoutes()
return &up
......
......@@ -21,10 +21,11 @@ import (
const testVersion = "123"
func newProxy(url string, rt *badgateway.RoundTripper) *proxy.Proxy {
parsedURL := helper.URLMustParse(url)
if rt == nil {
rt = badgateway.TestRoundTripper
rt = badgateway.TestRoundTripper(parsedURL)
}
return proxy.NewProxy(helper.URLMustParse(url), testVersion, rt)
return proxy.NewProxy(parsedURL, testVersion, rt)
}
func TestProxyRequest(t *testing.T) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment