Commit 57f989e2 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'cat-workhorse-proxy-host-rewrite' into 'master'

Force Host header rewrite in Workhorse for Geo proxying

See merge request gitlab-org/gitlab!83550
parents a93b446b 131dd305
...@@ -19,6 +19,7 @@ type Proxy struct { ...@@ -19,6 +19,7 @@ type Proxy struct {
reverseProxy *httputil.ReverseProxy reverseProxy *httputil.ReverseProxy
AllowResponseBuffering bool AllowResponseBuffering bool
customHeaders map[string]string customHeaders map[string]string
forceTargetHostHeader bool
} }
func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) { func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) {
...@@ -27,6 +28,12 @@ func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) { ...@@ -27,6 +28,12 @@ func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) {
} }
} }
func WithForcedTargetHostHeader() func(*Proxy) {
return func(proxy *Proxy) {
proxy.forceTargetHostHeader = true
}
}
func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, options ...func(*Proxy)) *Proxy { func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, options ...func(*Proxy)) *Proxy {
p := Proxy{Version: version, AllowResponseBuffering: true, customHeaders: make(map[string]string)} p := Proxy{Version: version, AllowResponseBuffering: true, customHeaders: make(map[string]string)}
...@@ -43,6 +50,17 @@ func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, op ...@@ -43,6 +50,17 @@ func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, op
option(&p) option(&p)
} }
if p.forceTargetHostHeader {
// because of https://github.com/golang/go/issues/28168, the
// upstream won't receive the expected Host header unless this
// is forced in the Director func here
previousDirector := p.reverseProxy.Director
p.reverseProxy.Director = func(request *http.Request) {
previousDirector(request)
request.Host = request.URL.Host
}
}
return &p return &p
} }
......
...@@ -243,6 +243,7 @@ func (u *upstream) updateGeoProxyFields(geoProxyURL *url.URL) { ...@@ -243,6 +243,7 @@ func (u *upstream) updateGeoProxyFields(geoProxyURL *url.URL) {
u.Version, u.Version,
geoProxyRoundTripper, geoProxyRoundTripper,
proxypkg.WithCustomHeaders(geoProxyWorkhorseHeaders), proxypkg.WithCustomHeaders(geoProxyWorkhorseHeaders),
proxypkg.WithForcedTargetHostHeader(),
) )
u.geoProxyCableRoute = u.wsRoute(`^/-/cable\z`, geoProxyUpstream) u.geoProxyCableRoute = u.wsRoute(`^/-/cable\z`, geoProxyUpstream)
u.geoProxyRoute = u.route("", "", geoProxyUpstream, withGeoProxy()) u.geoProxyRoute = u.route("", "", geoProxyUpstream, withGeoProxy())
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"regexp" "regexp"
"testing" "testing"
"time" "time"
...@@ -31,10 +32,15 @@ func newProxy(url string, rt http.RoundTripper, opts ...func(*proxy.Proxy)) *pro ...@@ -31,10 +32,15 @@ func newProxy(url string, rt http.RoundTripper, opts ...func(*proxy.Proxy)) *pro
} }
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { inboundURL, err := url.Parse("https://explicitly.set.host/url/path")
require.NoError(t, err, "parse inbound url")
urlRegexp := regexp.MustCompile(fmt.Sprintf(`%s\z`, inboundURL.Path))
ts := testhelper.TestServerWithHandler(urlRegexp, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "POST", r.Method, "method") require.Equal(t, "POST", r.Method, "method")
require.Equal(t, "test", r.Header.Get("Custom-Header"), "custom header") require.Equal(t, "test", r.Header.Get("Custom-Header"), "custom header")
require.Equal(t, testVersion, r.Header.Get("Gitlab-Workhorse"), "version header") require.Equal(t, testVersion, r.Header.Get("Gitlab-Workhorse"), "version header")
require.Equal(t, inboundURL.Host, r.Host, "sent host header")
require.Regexp( require.Regexp(
t, t,
...@@ -52,7 +58,7 @@ func TestProxyRequest(t *testing.T) { ...@@ -52,7 +58,7 @@ func TestProxyRequest(t *testing.T) {
fmt.Fprint(w, "RESPONSE") fmt.Fprint(w, "RESPONSE")
}) })
httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", bytes.NewBufferString("REQUEST")) httpRequest, err := http.NewRequest("POST", inboundURL.String(), bytes.NewBufferString("REQUEST"))
require.NoError(t, err) require.NoError(t, err)
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
...@@ -64,6 +70,30 @@ func TestProxyRequest(t *testing.T) { ...@@ -64,6 +70,30 @@ func TestProxyRequest(t *testing.T) {
require.Equal(t, "test", w.Header().Get("Custom-Response-Header"), "custom response header") require.Equal(t, "test", w.Header().Get("Custom-Response-Header"), "custom response header")
} }
func TestProxyWithForcedTargetHostHeader(t *testing.T) {
var tsUrl *url.URL
inboundURL, err := url.Parse("https://explicitly.set.host/url/path")
require.NoError(t, err, "parse upstream url")
urlRegexp := regexp.MustCompile(fmt.Sprintf(`%s\z`, inboundURL.Path))
ts := testhelper.TestServerWithHandler(urlRegexp, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, tsUrl.Host, r.Host, "upstream host header")
_, err := w.Write([]byte(`ok`))
require.NoError(t, err, "write ok response")
})
tsUrl, err = url.Parse(ts.URL)
require.NoError(t, err, "parse testserver URL")
httpRequest, err := http.NewRequest("POST", inboundURL.String(), nil)
require.NoError(t, err)
w := httptest.NewRecorder()
testProxy := newProxy(ts.URL, nil, proxy.WithForcedTargetHostHeader())
testProxy.ServeHTTP(w, httpRequest)
testhelper.RequireResponseBody(t, w, "ok")
}
func TestProxyWithCustomHeaders(t *testing.T) { func TestProxyWithCustomHeaders(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "value", r.Header.Get("Custom-Header"), "custom proxy header") require.Equal(t, "value", r.Header.Get("Custom-Header"), "custom proxy header")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment