diff --git a/internal/badgateway/roundtripper.go b/internal/badgateway/roundtripper.go new file mode 100644 index 0000000000000000000000000000000000000000..3d5c7b337ef3e616b9b01d3f0a66926b69ffbd30 --- /dev/null +++ b/internal/badgateway/roundtripper.go @@ -0,0 +1,84 @@ +package badgateway + +import ( + "../helper" + "bytes" + "fmt" + "io/ioutil" + "net" + "net/http" + "sync" + "time" +) + +// Values from http.DefaultTransport +var DefaultDialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, +} + +var DefaultTransport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport + Dial: DefaultDialer.Dial, // from http.DefaultTransport + ResponseHeaderTimeout: time.Minute, // custom + TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport +} + +type RoundTripper struct { + Socket string + ResponseHeaderTimeout time.Duration + Transport *http.Transport + configureRoundTripperOnce sync.Once +} + +func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { + t.configureRoundTripperOnce.Do(t.configureRoundTripper) + + res, err = t.Transport.RoundTrip(r) + + // httputil.ReverseProxy translates all errors from this + // RoundTrip function into 500 errors. But the most likely error + // is that the Rails app is not responding, in which case users + // and administrators expect to see a 502 error. To show 502s + // instead of 500s we catch the RoundTrip error here and inject a + // 502 response. + if err != nil { + helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err)) + + res = &http.Response{ + StatusCode: http.StatusBadGateway, + Status: http.StatusText(http.StatusBadGateway), + + Request: r, + ProtoMajor: r.ProtoMajor, + ProtoMinor: r.ProtoMinor, + Proto: r.Proto, + Header: make(http.Header), + Trailer: make(http.Header), + Body: ioutil.NopCloser(bytes.NewBufferString(err.Error())), + } + res.Header.Set("Content-Type", "text/plain") + err = nil + } + return +} + +func (t *RoundTripper) configureRoundTripper() { + if t.Transport != nil { + return + } + + tr := *DefaultTransport + + if t.ResponseHeaderTimeout != 0 { + tr.ResponseHeaderTimeout = t.ResponseHeaderTimeout + } + + if t.Socket != "" { + tr.Dial = func(_, _ string) (net.Conn, error) { + return DefaultDialer.Dial("unix", t.Socket) + } + } + + t.Transport = &tr +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 19e3e61fd72ba4ef08b1569af996cfaaba83aa84..c07f342fa91d17280bfd4000a051e75c0374c228 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -1,10 +1,7 @@ package proxy import ( - "../helper" - "bytes" - "fmt" - "io/ioutil" + "../badgateway" "net/http" "net/http/httputil" "net/url" @@ -14,55 +11,23 @@ import ( type Proxy struct { URL *url.URL Version string - Transport http.RoundTripper + RoundTripper *badgateway.RoundTripper _reverseProxy *httputil.ReverseProxy configureReverseProxyOnce sync.Once } func (p *Proxy) reverseProxy() *httputil.ReverseProxy { - p.configureReverseProxyOnce.Do(p.configureReverseProxy) - return p._reverseProxy -} - -func (p *Proxy) configureReverseProxy() { - u := *p.URL // Make a copy of p.URL - u.Path = "" - p._reverseProxy = httputil.NewSingleHostReverseProxy(&u) - p._reverseProxy.Transport = p.Transport -} - -type RoundTripper struct { - Transport http.RoundTripper -} - -func (rt *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { - res, err = rt.Transport.RoundTrip(r) - - // httputil.ReverseProxy translates all errors from this - // RoundTrip function into 500 errors. But the most likely error - // is that the Rails app is not responding, in which case users - // and administrators expect to see a 502 error. To show 502s - // instead of 500s we catch the RoundTrip error here and inject a - // 502 response. - if err != nil { - helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err)) - - res = &http.Response{ - StatusCode: http.StatusBadGateway, - Status: http.StatusText(http.StatusBadGateway), - - Request: r, - ProtoMajor: r.ProtoMajor, - ProtoMinor: r.ProtoMinor, - Proto: r.Proto, - Header: make(http.Header), - Trailer: make(http.Header), - Body: ioutil.NopCloser(bytes.NewBufferString(err.Error())), + p.configureReverseProxyOnce.Do(func() { + u := *p.URL // Make a copy of p.URL + u.Path = "" + p._reverseProxy = httputil.NewSingleHostReverseProxy(&u) + if p.RoundTripper != nil { + p._reverseProxy.Transport = p.RoundTripper + } else { + p._reverseProxy.Transport = &badgateway.RoundTripper{} } - res.Header.Set("Content-Type", "text/plain") - err = nil - } - return + }) + return p._reverseProxy } func HeaderClone(h http.Header) http.Header { diff --git a/internal/upstream/routes.go b/internal/upstream/routes.go index e6ec30962091d1cafcc099e15a97f8cf311a9120..b0c84d9c34ef3a5a4664d3fbddc6c18ef063b716 100644 --- a/internal/upstream/routes.go +++ b/internal/upstream/routes.go @@ -3,6 +3,7 @@ package upstream import ( "../git" "../lfs" + pr "../proxy" "../staticpages" "../upload" "net/http" @@ -34,12 +35,13 @@ func (u *Upstream) Routes() []route { func (u *Upstream) configureRoutes() { static := &staticpages.Static{u.DocumentRoot} + proxy := &pr.Proxy{URL: u.Backend, Version: u.Version, RoundTripper: u.RoundTripper()} u.routes = []route{ // Git Clone route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API())}, route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))}, route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))}, - route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API(), u.Proxy())}, + route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API(), proxy)}, // Repository Archive route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(u.API())}, @@ -56,17 +58,17 @@ func (u *Upstream) configureRoutes() { route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API())}, // CI Artifacts API - route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API(), u.Proxy()))}, + route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API(), proxy))}, // Explicitly proxy API requests - route{"", regexp.MustCompile(apiPattern), u.Proxy()}, - route{"", regexp.MustCompile(ciAPIPattern), u.Proxy()}, + route{"", regexp.MustCompile(apiPattern), proxy}, + route{"", regexp.MustCompile(ciAPIPattern), proxy}, // Serve assets route{"", regexp.MustCompile(`^/assets/`), static.ServeExisting(u.URLPrefix(), staticpages.CacheExpireMax, NotFoundUnless(u.DevelopmentMode, - u.Proxy(), + proxy, ), ), }, @@ -76,7 +78,7 @@ func (u *Upstream) configureRoutes() { static.ServeExisting(u.URLPrefix(), staticpages.CacheDisabled, static.DeployPage( static.ErrorPages( - u.Proxy(), + proxy, ), ), ), diff --git a/internal/upstream/transport.go b/internal/upstream/transport.go deleted file mode 100644 index 34e36a0a0c6ea7d1642e63afcd4016fbf6e3a801..0000000000000000000000000000000000000000 --- a/internal/upstream/transport.go +++ /dev/null @@ -1,42 +0,0 @@ -package upstream - -import ( - "../proxy" - "net" - "net/http" - "time" -) - -// Values from http.DefaultTransport -var DefaultDialer = &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, -} - -var DefaultTransport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport - Dial: DefaultDialer.Dial, // from http.DefaultTransport - ResponseHeaderTimeout: time.Minute, // custom - TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport -} - -func (u *Upstream) Transport() http.RoundTripper { - u.configureTransportOnce.Do(u.configureTransport) - return u.transport -} - -func (u *Upstream) configureTransport() { - t := *DefaultTransport - - if u.ResponseHeaderTimeout != 0 { - t.ResponseHeaderTimeout = u.ResponseHeaderTimeout - } - - if u.Socket != "" { - t.Dial = func(_, _ string) (net.Conn, error) { - return DefaultDialer.Dial("unix", u.Socket) - } - } - - u.transport = &proxy.RoundTripper{&t} -} diff --git a/internal/upstream/upstream.go b/internal/upstream/upstream.go index 1dbd8263b3f08a07e644904565a91aa928cbaba4..72b87b9bc2ac51ecfb4ae82d23a5a2c51f81b98c 100644 --- a/internal/upstream/upstream.go +++ b/internal/upstream/upstream.go @@ -8,6 +8,7 @@ package upstream import ( "../api" + "../badgateway" "../helper" "../proxy" "../staticpages" @@ -42,22 +43,13 @@ type Upstream struct { routes []route configureRoutesOnce sync.Once - transport http.RoundTripper - configureTransportOnce sync.Once + roundtripper *badgateway.RoundTripper + configureRoundTripperOnce sync.Once _static *staticpages.Static configureStaticOnce sync.Once } -func (u *Upstream) Proxy() *proxy.Proxy { - u.configureProxyOnce.Do(u.configureProxy) - return u._proxy -} - -func (u *Upstream) configureProxy() { - u._proxy = &proxy.Proxy{URL: u.Backend, Transport: u.Transport(), Version: u.Version} -} - func (u *Upstream) API() *api.API { u.configureAPIOnce.Do(u.configureAPI) return u._api @@ -65,7 +57,7 @@ func (u *Upstream) API() *api.API { func (u *Upstream) configureAPI() { u._api = &api.API{ - Client: &http.Client{Transport: u.Transport()}, + Client: &http.Client{Transport: u.RoundTripper()}, URL: u.Backend, Version: u.Version, } @@ -87,12 +79,15 @@ func (u *Upstream) configureURLPrefix() { u.urlPrefix = urlprefix.Prefix(relativeURLRoot) } -// func (u *Upstream) Static() *static.Static { -// u.configureStaticOnce.Do(func() { -// u._static = &static.Static{u.DocumentRoot} -// }) -// return u._static -// } +func (u *Upstream) RoundTripper() *badgateway.RoundTripper { + u.configureRoundTripperOnce.Do(func() { + u.roundtripper = &badgateway.RoundTripper{ + Socket: u.Socket, + ResponseHeaderTimeout: u.ResponseHeaderTimeout, + } + }) + return u.roundtripper +} func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { w := newLoggingResponseWriter(ow) diff --git a/main.go b/main.go index 2f36a7008c9e655220168f2506cb86773586ce38..8fa94c29f60c08364589d3e175266798607a3e95 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type. package main import ( + "./internal/badgateway" "./internal/upstream" "flag" "fmt" @@ -36,7 +37,7 @@ var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authenticatio var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var documentRoot = flag.String("documentRoot", "public", "Path to static files content") -var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", upstream.DefaultTransport.ResponseHeaderTimeout, "How long to wait for response headers when proxying the request") +var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", badgateway.DefaultTransport.ResponseHeaderTimeout, "How long to wait for response headers when proxying the request") var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app") func main() { diff --git a/proxy_test.go b/proxy_test.go index ca2be4841bf12de7305871762b5500b83ab9b219..9658e63852be540f4c352f3b4fbf943b450b9dc9 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -1,9 +1,9 @@ package main import ( + "./internal/badgateway" "./internal/helper" "./internal/proxy" - "./internal/upstream" "bytes" "fmt" "io" @@ -15,8 +15,8 @@ import ( "time" ) -func newUpstream(url string) *upstream.Upstream { - return &upstream.Upstream{Backend: helper.URLMustParse(url), Version: "123"} +func newProxy(url string) *proxy.Proxy { + return &proxy.Proxy{URL: helper.URLMustParse(url), Version: "123"} } func TestProxyRequest(t *testing.T) { @@ -46,9 +46,8 @@ func TestProxyRequest(t *testing.T) { } httpRequest.Header.Set("Custom-Header", "test") - u := newUpstream(ts.URL) w := httptest.NewRecorder() - u.Proxy().ServeHTTP(w, httpRequest) + newProxy(ts.URL).ServeHTTP(w, httpRequest) helper.AssertResponseCode(t, w, 202) helper.AssertResponseBody(t, w, "RESPONSE") @@ -64,9 +63,8 @@ func TestProxyError(t *testing.T) { } httpRequest.Header.Set("Custom-Header", "test") - u := newUpstream("http://localhost:655575/") w := httptest.NewRecorder() - u.Proxy().ServeHTTP(w, httpRequest) + newProxy("http://localhost:655575/").ServeHTTP(w, httpRequest) helper.AssertResponseCode(t, w, 502) helper.AssertResponseBody(t, w, "dial tcp: invalid port 655575") } @@ -81,8 +79,8 @@ func TestProxyReadTimeout(t *testing.T) { t.Fatal(err) } - transport := &proxy.RoundTripper{ - &http.Transport{ + rt := &badgateway.RoundTripper{ + Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, Dial: (&net.Dialer{ Timeout: 30 * time.Second, @@ -93,8 +91,8 @@ func TestProxyReadTimeout(t *testing.T) { }, } - p := &proxy.Proxy{URL: helper.URLMustParse(ts.URL), Transport: transport, Version: "123"} - + p := newProxy(ts.URL) + p.RoundTripper = rt w := httptest.NewRecorder() p.ServeHTTP(w, httpRequest) helper.AssertResponseCode(t, w, 502) @@ -113,10 +111,8 @@ func TestProxyHandlerTimeout(t *testing.T) { t.Fatal(err) } - u := newUpstream(ts.URL) - w := httptest.NewRecorder() - u.Proxy().ServeHTTP(w, httpRequest) + newProxy(ts.URL).ServeHTTP(w, httpRequest) helper.AssertResponseCode(t, w, 503) helper.AssertResponseBody(t, w, "Request took too long") }