Commit 839a26be authored by Stan Hu's avatar Stan Hu Committed by Jacob Vosmaer

Fix correlation IDs not being propagated in preauth check

Previously `info/refs`, `upload-pack`, `receive-pack`, and anything that
needed to make an authorization check with Rails would not forward along
the correlation ID injected in the request context. We solve this by
creating the HTTP request with the same context. This adds some
additional memory allocations for each outbound request.

Relates to https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/293
parent 8a20c4af
package main
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"testing"
"gitlab.com/gitlab-org/labkit/correlation"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/require"
......@@ -24,12 +27,13 @@ func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, url *regexp.Regexp, apiResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder {
if ts == nil {
ts = testAuthServer(url, nil, returnCode, apiResponse)
ts = testAuthServer(t, url, nil, returnCode, apiResponse)
defer ts.Close()
}
// Create http request
httpRequest, err := http.NewRequest("GET", "/address", nil)
ctx := correlation.ContextWithCorrelation(context.Background(), "12345678")
httpRequest, err := http.NewRequestWithContext(ctx, "GET", "/address", nil)
require.NoError(t, err)
parsedURL := helper.URLMustParse(ts.URL)
testhelper.ConfigureSecret()
......
---
title: Fix correlation IDs not being propagated in preauth check
merge_request: 607
author:
type: fixed
......@@ -42,7 +42,7 @@ func TestChannelHappyPath(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
serverConns, clientURL, close := wireupChannel(test.channelPath, nil, "channel.k8s.io")
serverConns, clientURL, close := wireupChannel(t, test.channelPath, nil, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
......@@ -70,7 +70,7 @@ func TestChannelHappyPath(t *testing.T) {
}
func TestChannelBadTLS(t *testing.T) {
_, clientURL, close := wireupChannel(envTerminalPath, badCA, "channel.k8s.io")
_, clientURL, close := wireupChannel(t, envTerminalPath, badCA, "channel.k8s.io")
defer close()
_, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
......@@ -78,7 +78,7 @@ func TestChannelBadTLS(t *testing.T) {
}
func TestChannelSessionTimeout(t *testing.T) {
serverConns, clientURL, close := wireupChannel(envTerminalPath, timeout, "channel.k8s.io")
serverConns, clientURL, close := wireupChannel(t, envTerminalPath, timeout, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
......@@ -96,7 +96,7 @@ func TestChannelSessionTimeout(t *testing.T) {
func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
hdr := make(http.Header)
hdr.Set("Random-Header", "Value")
serverConns, clientURL, close := wireupChannel(envTerminalPath, setHeader(hdr), "channel.k8s.io")
serverConns, clientURL, close := wireupChannel(t, envTerminalPath, setHeader(hdr), "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
......@@ -109,7 +109,7 @@ func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
}
func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
serverConns, clientURL, close := wireupChannel(envTerminalPath, nil, "channel.k8s.io")
serverConns, clientURL, close := wireupChannel(t, envTerminalPath, nil, "channel.k8s.io")
defer close()
hdr := make(http.Header)
......@@ -127,13 +127,13 @@ func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote")
}
func wireupChannel(channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
func wireupChannel(t *testing.T, channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
serverConns, remote := startWebsocketServer(subprotocols...)
authResponse := channelOkBody(remote, nil, subprotocols...)
if modifier != nil {
modifier(authResponse)
}
upstream := testAuthServer(nil, nil, 200, authResponse)
upstream := testAuthServer(t, nil, nil, 200, authResponse)
workhorse := startWorkhorseServer(upstream.URL)
return serverConns, websocketURL(workhorse.URL, channelPath), func() {
......
......@@ -90,7 +90,7 @@ func TestAllowedClone(t *testing.T) {
require.NoError(t, ensureGitalyRepository(t, apiResponse))
// Prepare test server and backend
ts := testAuthServer(nil, nil, 200, apiResponse)
ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
......@@ -114,7 +114,7 @@ func TestAllowedShallowClone(t *testing.T) {
require.NoError(t, ensureGitalyRepository(t, apiResponse))
// Prepare test server and backend
ts := testAuthServer(nil, nil, 200, apiResponse)
ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
......@@ -138,7 +138,7 @@ func TestAllowedPush(t *testing.T) {
require.NoError(t, ensureGitalyRepository(t, apiResponse))
// Prepare the test server and backend
ts := testAuthServer(nil, nil, 200, apiResponse)
ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
......
......@@ -43,7 +43,7 @@ func TestFailedCloneNoGitaly(t *testing.T) {
}
// Prepare test server and backend
ts := testAuthServer(nil, nil, 200, authBody)
ts := testAuthServer(t, nil, nil, 200, authBody)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
......@@ -95,7 +95,7 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) {
t.Run(fmt.Sprintf("ShowAllRefs=%v,gitRpc=%v", tc.showAllRefs, tc.gitRpc), func(t *testing.T) {
apiResponse.ShowAllRefs = tc.showAllRefs
ts := testAuthServer(nil, nil, 200, apiResponse)
ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
......@@ -147,7 +147,7 @@ func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) {
gitalyAddress := "unix:" + socketPath
apiResponse.GitalyServer.Address = gitalyAddress
ts := testAuthServer(nil, nil, 200, apiResponse)
ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
......@@ -187,7 +187,7 @@ func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) {
apiResponse.GitalyServer.Address = "unix:" + socketPath
apiResponse.GitConfigOptions = []string{"git-config-hello=world"}
ts := testAuthServer(nil, nil, 200, apiResponse)
ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
......@@ -232,7 +232,7 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) {
defer gitalyServer.GracefulStop()
apiResponse.GitalyServer.Address = "unix:" + socketPath
ts := testAuthServer(nil, nil, 200, apiResponse)
ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
......@@ -282,7 +282,7 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) {
defer gitalyServer.GracefulStop()
apiResponse.GitalyServer.Address = "unix:" + socketPath
ts := testAuthServer(nil, nil, 200, apiResponse)
ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
......@@ -349,7 +349,7 @@ func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) {
defer gitalyServer.GracefulStop()
apiResponse.GitalyServer.Address = "unix:" + socketPath
ts := testAuthServer(nil, nil, 200, apiResponse)
ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
......
......@@ -456,6 +456,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
......@@ -837,11 +838,11 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k=
honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k=
honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
......@@ -188,6 +188,8 @@ func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error
Header: helper.HeaderClone(r.Header),
}
authReq = authReq.WithContext(r.Context())
// Clean some headers when issuing a new request without body
authReq.Header.Del("Content-Type")
authReq.Header.Del("Content-Encoding")
......
......@@ -63,7 +63,7 @@ func TestDeniedClone(t *testing.T) {
require.NoError(t, os.RemoveAll(scratchDir))
// Prepare test server and backend
ts := testAuthServer(nil, nil, 403, "Access denied")
ts := testAuthServer(t, nil, nil, 403, "Access denied")
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
......@@ -77,7 +77,7 @@ func TestDeniedClone(t *testing.T) {
func TestDeniedPush(t *testing.T) {
// Prepare the test server and backend
ts := testAuthServer(nil, nil, 403, "Access denied")
ts := testAuthServer(t, nil, nil, 403, "Access denied")
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
......@@ -594,8 +594,10 @@ func newBranch() string {
return fmt.Sprintf("branch-%d", time.Now().UnixNano())
}
func testAuthServer(url *regexp.Regexp, params url.Values, code int, body interface{}) *httptest.Server {
func testAuthServer(t *testing.T, url *regexp.Regexp, params url.Values, code int, body interface{}) *httptest.Server {
return testhelper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) {
require.NotEmpty(t, r.Header.Get("X-Request-Id"))
w.Header().Set("Content-Type", api.ResponseContentType)
logEntry := log.WithFields(log.Fields{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment