Commit e381c7aa authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'propagate-correlationid-option' into 'master'

Added configuration option PropagateCorrelationID

See merge request gitlab-org/gitlab-workhorse!529
parents 9c7f15cb 51c7f313
...@@ -74,6 +74,7 @@ type Config struct { ...@@ -74,6 +74,7 @@ type Config struct {
APIQueueTimeout time.Duration `toml:"-"` APIQueueTimeout time.Duration `toml:"-"`
APICILongPollingDuration time.Duration `toml:"-"` APICILongPollingDuration time.Duration `toml:"-"`
ObjectStorageCredentials *ObjectStorageCredentials `toml:"object_storage"` ObjectStorageCredentials *ObjectStorageCredentials `toml:"object_storage"`
PropagateCorrelationID bool `toml:"-"`
} }
// LoadConfig from a file // LoadConfig from a file
......
...@@ -56,8 +56,13 @@ func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler { ...@@ -56,8 +56,13 @@ func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler {
up.configureURLPrefix() up.configureURLPrefix()
up.configureRoutes() up.configureRoutes()
var correlationOpts []correlation.InboundHandlerOption
if cfg.PropagateCorrelationID {
correlationOpts = append(correlationOpts, correlation.WithPropagation())
}
handler := log.AccessLogger(&up, log.WithAccessLogger(accessLogger)) handler := log.AccessLogger(&up, log.WithAccessLogger(accessLogger))
handler = correlation.InjectCorrelationID(handler) handler = correlation.InjectCorrelationID(handler, correlationOpts...)
return handler return handler
} }
......
...@@ -57,6 +57,7 @@ var apiLimit = flag.Uint("apiLimit", 0, "Number of API requests allowed at singl ...@@ -57,6 +57,7 @@ var apiLimit = flag.Uint("apiLimit", 0, "Number of API requests allowed at singl
var apiQueueLimit = flag.Uint("apiQueueLimit", 0, "Number of API requests allowed to be queued") var apiQueueLimit = flag.Uint("apiQueueLimit", 0, "Number of API requests allowed to be queued")
var apiQueueTimeout = flag.Duration("apiQueueDuration", queueing.DefaultTimeout, "Maximum queueing duration of requests") var apiQueueTimeout = flag.Duration("apiQueueDuration", queueing.DefaultTimeout, "Maximum queueing duration of requests")
var apiCiLongPollingDuration = flag.Duration("apiCiLongPollingDuration", 50, "Long polling duration for job requesting for runners (default 50s - enabled)") var apiCiLongPollingDuration = flag.Duration("apiCiLongPollingDuration", 50, "Long polling duration for job requesting for runners (default 50s - enabled)")
var propagateCorrelationID = flag.Bool("propagateCorrelationID", false, "Reuse existing Correlation-ID from the incoming request header `X-Request-ID` if present")
var prometheusListenAddr = flag.String("prometheusListenAddr", "", "Prometheus listening address, e.g. 'localhost:9229'") var prometheusListenAddr = flag.String("prometheusListenAddr", "", "Prometheus listening address, e.g. 'localhost:9229'")
...@@ -155,6 +156,7 @@ func main() { ...@@ -155,6 +156,7 @@ func main() {
APIQueueLimit: *apiQueueLimit, APIQueueLimit: *apiQueueLimit,
APIQueueTimeout: *apiQueueTimeout, APIQueueTimeout: *apiQueueTimeout,
APICILongPollingDuration: *apiCiLongPollingDuration, APICILongPollingDuration: *apiCiLongPollingDuration,
PropagateCorrelationID: *propagateCorrelationID,
} }
if *configFile != "" { if *configFile != "" {
......
...@@ -510,6 +510,52 @@ func TestCorrelationIdHeader(t *testing.T) { ...@@ -510,6 +510,52 @@ func TestCorrelationIdHeader(t *testing.T) {
} }
} }
func TestPropagateCorrelationIdHeader(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-Request-Id", r.Header.Get("X-Request-Id"))
w.WriteHeader(200)
})
defer ts.Close()
testCases := []struct {
desc string
propagateCorrelationID bool
}{
{
desc: "propagateCorrelatedId is true",
propagateCorrelationID: true,
},
{
desc: "propagateCorrelatedId is false",
propagateCorrelationID: false,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
upstreamConfig := newUpstreamConfig(ts.URL)
upstreamConfig.PropagateCorrelationID = tc.propagateCorrelationID
ws := startWorkhorseServerWithConfig(upstreamConfig)
defer ws.Close()
resource := "/api/v3/projects/123/repository/not/special"
propagatedRequestId := "Propagated-RequestId-12345678"
resp, _ := httpGet(t, ws.URL+resource, map[string]string{"X-Request-Id": propagatedRequestId})
requestIds := resp.Header["X-Request-Id"]
assert.Equal(t, 200, resp.StatusCode, "GET %q: status code", resource)
assert.Equal(t, 1, len(requestIds), "GET %q: One X-Request-Id present", resource)
if tc.propagateCorrelationID {
assert.Contains(t, requestIds, propagatedRequestId, "GET %q: Has X-Request-Id %s present", resource, propagatedRequestId)
} else {
assert.NotContains(t, requestIds, propagatedRequestId, "GET %q: X-Request-Id not propagated")
}
})
}
}
func setupStaticFile(fpath, content string) error { func setupStaticFile(fpath, content string) error {
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err != nil { if err != nil {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment