Commit 09144383 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'sh-fix-correlation-id-for-preauth' into 'master'

Fix correlation IDs not being propagated in preauth check

See merge request gitlab-org/gitlab-workhorse!607
parents 8cc8b774 839a26be
package main package main
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"regexp" "regexp"
"testing" "testing"
"gitlab.com/gitlab-org/labkit/correlation"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -24,12 +27,13 @@ func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) { ...@@ -24,12 +27,13 @@ func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, url *regexp.Regexp, apiResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder { func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, url *regexp.Regexp, apiResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder {
if ts == nil { if ts == nil {
ts = testAuthServer(url, nil, returnCode, apiResponse) ts = testAuthServer(t, url, nil, returnCode, apiResponse)
defer ts.Close() defer ts.Close()
} }
// Create http request // Create http request
httpRequest, err := http.NewRequest("GET", "/address", nil) ctx := correlation.ContextWithCorrelation(context.Background(), "12345678")
httpRequest, err := http.NewRequestWithContext(ctx, "GET", "/address", nil)
require.NoError(t, err) require.NoError(t, err)
parsedURL := helper.URLMustParse(ts.URL) parsedURL := helper.URLMustParse(ts.URL)
testhelper.ConfigureSecret() testhelper.ConfigureSecret()
......
---
title: Fix correlation IDs not being propagated in preauth check
merge_request: 607
author:
type: fixed
...@@ -42,7 +42,7 @@ func TestChannelHappyPath(t *testing.T) { ...@@ -42,7 +42,7 @@ func TestChannelHappyPath(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
serverConns, clientURL, close := wireupChannel(test.channelPath, nil, "channel.k8s.io") serverConns, clientURL, close := wireupChannel(t, test.channelPath, nil, "channel.k8s.io")
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
...@@ -70,7 +70,7 @@ func TestChannelHappyPath(t *testing.T) { ...@@ -70,7 +70,7 @@ func TestChannelHappyPath(t *testing.T) {
} }
func TestChannelBadTLS(t *testing.T) { func TestChannelBadTLS(t *testing.T) {
_, clientURL, close := wireupChannel(envTerminalPath, badCA, "channel.k8s.io") _, clientURL, close := wireupChannel(t, envTerminalPath, badCA, "channel.k8s.io")
defer close() defer close()
_, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") _, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
...@@ -78,7 +78,7 @@ func TestChannelBadTLS(t *testing.T) { ...@@ -78,7 +78,7 @@ func TestChannelBadTLS(t *testing.T) {
} }
func TestChannelSessionTimeout(t *testing.T) { func TestChannelSessionTimeout(t *testing.T) {
serverConns, clientURL, close := wireupChannel(envTerminalPath, timeout, "channel.k8s.io") serverConns, clientURL, close := wireupChannel(t, envTerminalPath, timeout, "channel.k8s.io")
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
...@@ -96,7 +96,7 @@ func TestChannelSessionTimeout(t *testing.T) { ...@@ -96,7 +96,7 @@ func TestChannelSessionTimeout(t *testing.T) {
func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) { func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
hdr := make(http.Header) hdr := make(http.Header)
hdr.Set("Random-Header", "Value") hdr.Set("Random-Header", "Value")
serverConns, clientURL, close := wireupChannel(envTerminalPath, setHeader(hdr), "channel.k8s.io") serverConns, clientURL, close := wireupChannel(t, envTerminalPath, setHeader(hdr), "channel.k8s.io")
defer close() defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
...@@ -109,7 +109,7 @@ func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) { ...@@ -109,7 +109,7 @@ func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
} }
func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) { func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
serverConns, clientURL, close := wireupChannel(envTerminalPath, nil, "channel.k8s.io") serverConns, clientURL, close := wireupChannel(t, envTerminalPath, nil, "channel.k8s.io")
defer close() defer close()
hdr := make(http.Header) hdr := make(http.Header)
...@@ -127,13 +127,13 @@ func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) { ...@@ -127,13 +127,13 @@ func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote") require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote")
} }
func wireupChannel(channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) { func wireupChannel(t *testing.T, channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
serverConns, remote := startWebsocketServer(subprotocols...) serverConns, remote := startWebsocketServer(subprotocols...)
authResponse := channelOkBody(remote, nil, subprotocols...) authResponse := channelOkBody(remote, nil, subprotocols...)
if modifier != nil { if modifier != nil {
modifier(authResponse) modifier(authResponse)
} }
upstream := testAuthServer(nil, nil, 200, authResponse) upstream := testAuthServer(t, nil, nil, 200, authResponse)
workhorse := startWorkhorseServer(upstream.URL) workhorse := startWorkhorseServer(upstream.URL)
return serverConns, websocketURL(workhorse.URL, channelPath), func() { return serverConns, websocketURL(workhorse.URL, channelPath), func() {
......
...@@ -90,7 +90,7 @@ func TestAllowedClone(t *testing.T) { ...@@ -90,7 +90,7 @@ func TestAllowedClone(t *testing.T) {
require.NoError(t, ensureGitalyRepository(t, apiResponse)) require.NoError(t, ensureGitalyRepository(t, apiResponse))
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -114,7 +114,7 @@ func TestAllowedShallowClone(t *testing.T) { ...@@ -114,7 +114,7 @@ func TestAllowedShallowClone(t *testing.T) {
require.NoError(t, ensureGitalyRepository(t, apiResponse)) require.NoError(t, ensureGitalyRepository(t, apiResponse))
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -138,7 +138,7 @@ func TestAllowedPush(t *testing.T) { ...@@ -138,7 +138,7 @@ func TestAllowedPush(t *testing.T) {
require.NoError(t, ensureGitalyRepository(t, apiResponse)) require.NoError(t, ensureGitalyRepository(t, apiResponse))
// Prepare the test server and backend // Prepare the test server and backend
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
......
...@@ -43,7 +43,7 @@ func TestFailedCloneNoGitaly(t *testing.T) { ...@@ -43,7 +43,7 @@ func TestFailedCloneNoGitaly(t *testing.T) {
} }
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(nil, nil, 200, authBody) ts := testAuthServer(t, nil, nil, 200, authBody)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -95,7 +95,7 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) { ...@@ -95,7 +95,7 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) {
t.Run(fmt.Sprintf("ShowAllRefs=%v,gitRpc=%v", tc.showAllRefs, tc.gitRpc), func(t *testing.T) { t.Run(fmt.Sprintf("ShowAllRefs=%v,gitRpc=%v", tc.showAllRefs, tc.gitRpc), func(t *testing.T) {
apiResponse.ShowAllRefs = tc.showAllRefs apiResponse.ShowAllRefs = tc.showAllRefs
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -147,7 +147,7 @@ func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) { ...@@ -147,7 +147,7 @@ func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) {
gitalyAddress := "unix:" + socketPath gitalyAddress := "unix:" + socketPath
apiResponse.GitalyServer.Address = gitalyAddress apiResponse.GitalyServer.Address = gitalyAddress
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -187,7 +187,7 @@ func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) { ...@@ -187,7 +187,7 @@ func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) {
apiResponse.GitalyServer.Address = "unix:" + socketPath apiResponse.GitalyServer.Address = "unix:" + socketPath
apiResponse.GitConfigOptions = []string{"git-config-hello=world"} apiResponse.GitConfigOptions = []string{"git-config-hello=world"}
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -232,7 +232,7 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) { ...@@ -232,7 +232,7 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) {
defer gitalyServer.GracefulStop() defer gitalyServer.GracefulStop()
apiResponse.GitalyServer.Address = "unix:" + socketPath apiResponse.GitalyServer.Address = "unix:" + socketPath
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -282,7 +282,7 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) { ...@@ -282,7 +282,7 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) {
defer gitalyServer.GracefulStop() defer gitalyServer.GracefulStop()
apiResponse.GitalyServer.Address = "unix:" + socketPath apiResponse.GitalyServer.Address = "unix:" + socketPath
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
...@@ -349,7 +349,7 @@ func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) { ...@@ -349,7 +349,7 @@ func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) {
defer gitalyServer.GracefulStop() defer gitalyServer.GracefulStop()
apiResponse.GitalyServer.Address = "unix:" + socketPath apiResponse.GitalyServer.Address = "unix:" + socketPath
ts := testAuthServer(nil, nil, 200, apiResponse) ts := testAuthServer(t, nil, nil, 200, apiResponse)
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
......
...@@ -456,6 +456,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn ...@@ -456,6 +456,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
...@@ -837,11 +838,11 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh ...@@ -837,11 +838,11 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k=
honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k=
honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
...@@ -188,6 +188,8 @@ func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error ...@@ -188,6 +188,8 @@ func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error
Header: helper.HeaderClone(r.Header), Header: helper.HeaderClone(r.Header),
} }
authReq = authReq.WithContext(r.Context())
// Clean some headers when issuing a new request without body // Clean some headers when issuing a new request without body
authReq.Header.Del("Content-Type") authReq.Header.Del("Content-Type")
authReq.Header.Del("Content-Encoding") authReq.Header.Del("Content-Encoding")
......
...@@ -63,7 +63,7 @@ func TestDeniedClone(t *testing.T) { ...@@ -63,7 +63,7 @@ func TestDeniedClone(t *testing.T) {
require.NoError(t, os.RemoveAll(scratchDir)) require.NoError(t, os.RemoveAll(scratchDir))
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(nil, nil, 403, "Access denied") ts := testAuthServer(t, nil, nil, 403, "Access denied")
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -77,7 +77,7 @@ func TestDeniedClone(t *testing.T) { ...@@ -77,7 +77,7 @@ func TestDeniedClone(t *testing.T) {
func TestDeniedPush(t *testing.T) { func TestDeniedPush(t *testing.T) {
// Prepare the test server and backend // Prepare the test server and backend
ts := testAuthServer(nil, nil, 403, "Access denied") ts := testAuthServer(t, nil, nil, 403, "Access denied")
defer ts.Close() defer ts.Close()
ws := startWorkhorseServer(ts.URL) ws := startWorkhorseServer(ts.URL)
defer ws.Close() defer ws.Close()
...@@ -594,8 +594,10 @@ func newBranch() string { ...@@ -594,8 +594,10 @@ func newBranch() string {
return fmt.Sprintf("branch-%d", time.Now().UnixNano()) return fmt.Sprintf("branch-%d", time.Now().UnixNano())
} }
func testAuthServer(url *regexp.Regexp, params url.Values, code int, body interface{}) *httptest.Server { func testAuthServer(t *testing.T, url *regexp.Regexp, params url.Values, code int, body interface{}) *httptest.Server {
return testhelper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) { return testhelper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) {
require.NotEmpty(t, r.Header.Get("X-Request-Id"))
w.Header().Set("Content-Type", api.ResponseContentType) w.Header().Set("Content-Type", api.ResponseContentType)
logEntry := log.WithFields(log.Fields{ logEntry := log.WithFields(log.Fields{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment