Commit 839a26be authored by Stan Hu's avatar Stan Hu Committed by Jacob Vosmaer

Fix correlation IDs not being propagated in preauth check

Previously `info/refs`, `upload-pack`, `receive-pack`, and anything that
needed to make an authorization check with Rails would not forward along
the correlation ID injected in the request context. We solve this by
creating the HTTP request with the same context. This adds some
additional memory allocations for each outbound request.

Relates to https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/293
parent 8a20c4af
package main package main
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"regexp" "regexp"
"testing" "testing"
"gitlab.com/gitlab-org/labkit/correlation"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -24,12 +27,13 @@ func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) { ...@@ -24,12 +27,13 @@ func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, url *regexp.Regexp, apiResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder { func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, url *regexp.Regexp, apiResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder {
if ts == nil { if ts == nil {
ts = testAuthServer(url, nil, returnCode, apiResponse) ts = testAuthServer(t, url, nil, returnCode, apiResponse)
defer ts.Close() defer ts.Close()
} }
// Create http request // Create http request
httpRequest, err := http.NewRequest("GET", "/address", nil) ctx := correlation.ContextWithCorrelation(context.Background(), "12345678")
httpRequest, err := http.NewRequestWithContext(ctx, "GET", "/address", nil)
require.NoError(t, err) require.NoError(t, err)
parsedURL := helper.URLMustParse(ts.URL) parsedURL := helper.URLMustParse(ts.URL)
testhelper.ConfigureSecret() testhelper.ConfigureSecret()
......
---
title: Fix correlation IDs not being propagated in preauth check
merge_request: 607
author:
type: fixed
...@@ -42,7 +42,7 @@ func TestChannelHappyPath(t *testing.T) { ...@@ -42,7 +42,7 @@ func TestChannelHappyPath(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
serverConns, clientURL, close := wireupChannel(test.channelPath, nil, "channel.k8s.io") serverConns, clientURL, close := wireupChannel(t, test.channelPath, nil, "channel.k8s.io")
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
...@@ -70,7 +70,7 @@ func TestChannelHappyPath(t *testing.T) { ...@@ -70,7 +70,7 @@ func TestChannelHappyPath(t *testing.T) {
} }
func TestChannelBadTLS(t *testing.T) { func TestChannelBadTLS(t *testing.T) {
_, clientURL, close := wireupChannel(envTerminalPath, badCA, "channel.k8s.io") _, clientURL, close := wireupChannel(t, envTerminalPath, badCA, "channel.k8s.io")
defer close() defer close()
_, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") _, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
...@@ -78,7 +78,7 @@ func TestChannelBadTLS(t *testing.T) { ...@@ -78,7 +78,7 @@ func TestChannelBadTLS(t *testing.T) {
} }
func TestChannelSessionTimeout(t *testing.T) { func TestChannelSessionTimeout(t *testing.T) {
serverConns, clientURL, close := wireupChannel(envTerminalPath, timeout, "channel.k8s.io") serverConns, clientURL, close := wireupChannel(t, envTerminalPath, timeout, "channel.k8s.io")
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
...@@ -96,7 +96,7 @@ func TestChannelSessionTimeout(t *testing.T) { ...@@ -96,7 +96,7 @@ func TestChannelSessionTimeout(t *testing.T) {
func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) { func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
hdr := make(http.Header) hdr := make(http.Header)
hdr.Set("Random-Header", "Value") hdr.Set("Random-Header", "Value")
serverConns, clientURL, close := wireupChannel(envTerminalPath, setHeader(hdr), "channel.k8s.io") serverConns, clientURL, close := wireupChannel(t, envTerminalPath, setHeader(hdr), "channel.k8s.io")
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
...@@ -109,7 +109,7 @@ func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) { ...@@ -109,7 +109,7 @@ func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
} }
func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) { func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
serverConns, clientURL, close := wireupChannel(envTerminalPath, nil, "channel.k8s.io") serverConns, clientURL, close := wireupChannel(t, envTerminalPath, nil, "channel.k8s.io")
defer close() defer close()
hdr := make(http.Header) hdr := make(http.Header)
...@@ -127,13 +127,13 @@ func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) { ...@@ -127,13 +127,13 @@ func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote") require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote")
} }
func wireupChannel(channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) { func wireupChannel(t *testing.T, channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
serverConns, remote := startWebsocketServer(subprotocols...) serverConns, remote := startWebsocketServer(subprotocols...)
authResponse := channelOkBody(remote, nil, subprotocols...) authResponse := channelOkBody(remote, nil, subprotocols...)
if modifier != nil { if modifier != nil {
modifier(authResponse) modifier(authResponse)
} }
upstream := testAuthServer(nil, nil, 200, authResponse) upstream := testAuthServer(t, nil, nil, 200, authResponse)
workhorse := startWorkhorseServer(upstream.URL) workhorse := startWorkhorseServer(upstream.URL)
return serverConns, websocketURL(workhorse.URL, channelPath), func() { return serverConns, websocketURL(workhorse.URL, channelPath), func() {
......
...@@ -90,7 +90,7 @@ func TestAllowedClone(t *testing.T) { ...@@ -90,7 +90,7 @@ func TestAllowedClone(t *testing.T) {
require.NoError(t, ensureGitalyRepository(t, apiResponse)) require.NoError(t, ensureGitalyRepository(t, apiResponse))
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -114,7 +114,7 @@ func TestAllowedShallowClone(t *testing.T) { ...@@ -114,7 +114,7 @@ func TestAllowedShallowClone(t *testing.T) {
require.NoError(t, ensureGitalyRepository(t, apiResponse)) require.NoError(t, ensureGitalyRepository(t, apiResponse))
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -138,7 +138,7 @@ func TestAllowedPush(t *testing.T) { ...@@ -138,7 +138,7 @@ func TestAllowedPush(t *testing.T) {
require.NoError(t, ensureGitalyRepository(t, apiResponse)) require.NoError(t, ensureGitalyRepository(t, apiResponse))
// Prepare the test server and backend // Prepare the test server and backend
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
......
...@@ -43,7 +43,7 @@ func TestFailedCloneNoGitaly(t *testing.T) { ...@@ -43,7 +43,7 @@ func TestFailedCloneNoGitaly(t *testing.T) {
} }
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(nil, nil, 200, authBody) ts := testAuthServer(t, nil, nil, 200, authBody)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -95,7 +95,7 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) { ...@@ -95,7 +95,7 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) {
t.Run(fmt.Sprintf("ShowAllRefs=%v,gitRpc=%v", tc.showAllRefs, tc.gitRpc), func(t *testing.T) { t.Run(fmt.Sprintf("ShowAllRefs=%v,gitRpc=%v", tc.showAllRefs, tc.gitRpc), func(t *testing.T) {
apiResponse.ShowAllRefs = tc.showAllRefs apiResponse.ShowAllRefs = tc.showAllRefs
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -147,7 +147,7 @@ func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) { ...@@ -147,7 +147,7 @@ func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) {
gitalyAddress := "unix:" + socketPath gitalyAddress := "unix:" + socketPath
apiResponse.GitalyServer.Address = gitalyAddress apiResponse.GitalyServer.Address = gitalyAddress
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -187,7 +187,7 @@ func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) { ...@@ -187,7 +187,7 @@ func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) {
apiResponse.GitalyServer.Address = "unix:" + socketPath apiResponse.GitalyServer.Address = "unix:" + socketPath
apiResponse.GitConfigOptions = []string{"git-config-hello=world"} apiResponse.GitConfigOptions = []string{"git-config-hello=world"}
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -232,7 +232,7 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) { ...@@ -232,7 +232,7 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) {
defer gitalyServer.GracefulStop() defer gitalyServer.GracefulStop()
apiResponse.GitalyServer.Address = "unix:" + socketPath apiResponse.GitalyServer.Address = "unix:" + socketPath
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -282,7 +282,7 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) { ...@@ -282,7 +282,7 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) {
defer gitalyServer.GracefulStop() defer gitalyServer.GracefulStop()
apiResponse.GitalyServer.Address = "unix:" + socketPath apiResponse.GitalyServer.Address = "unix:" + socketPath
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -349,7 +349,7 @@ func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) { ...@@ -349,7 +349,7 @@ func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) {
defer gitalyServer.GracefulStop() defer gitalyServer.GracefulStop()
apiResponse.GitalyServer.Address = "unix:" + socketPath apiResponse.GitalyServer.Address = "unix:" + socketPath
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
......
...@@ -456,6 +456,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn ...@@ -456,6 +456,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
...@@ -837,11 +838,11 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh ...@@ -837,11 +838,11 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k=
honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k=
honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
...@@ -188,6 +188,8 @@ func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error ...@@ -188,6 +188,8 @@ func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error
Header: helper.HeaderClone(r.Header), Header: helper.HeaderClone(r.Header),
} }
authReq = authReq.WithContext(r.Context())
// Clean some headers when issuing a new request without body // Clean some headers when issuing a new request without body
authReq.Header.Del("Content-Type") authReq.Header.Del("Content-Type")
authReq.Header.Del("Content-Encoding") authReq.Header.Del("Content-Encoding")
......
...@@ -63,7 +63,7 @@ func TestDeniedClone(t *testing.T) { ...@@ -63,7 +63,7 @@ func TestDeniedClone(t *testing.T) {
require.NoError(t, os.RemoveAll(scratchDir)) require.NoError(t, os.RemoveAll(scratchDir))
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(nil, nil, 403, "Access denied") ts := testAuthServer(t, nil, nil, 403, "Access denied")
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -77,7 +77,7 @@ func TestDeniedClone(t *testing.T) { ...@@ -77,7 +77,7 @@ func TestDeniedClone(t *testing.T) {
func TestDeniedPush(t *testing.T) { func TestDeniedPush(t *testing.T) {
// Prepare the test server and backend // Prepare the test server and backend
ts := testAuthServer(nil, nil, 403, "Access denied") ts := testAuthServer(t, nil, nil, 403, "Access denied")
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -594,8 +594,10 @@ func newBranch() string { ...@@ -594,8 +594,10 @@ func newBranch() string {
return fmt.Sprintf("branch-%d", time.Now().UnixNano()) return fmt.Sprintf("branch-%d", time.Now().UnixNano())
} }
func testAuthServer(url *regexp.Regexp, params url.Values, code int, body interface{}) *httptest.Server { func testAuthServer(t *testing.T, url *regexp.Regexp, params url.Values, code int, body interface{}) *httptest.Server {
return testhelper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) { return testhelper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) {
require.NotEmpty(t, r.Header.Get("X-Request-Id"))
w.Header().Set("Content-Type", api.ResponseContentType) w.Header().Set("Content-Type", api.ResponseContentType)
logEntry := log.WithFields(log.Fields{ logEntry := log.WithFields(log.Fields{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment