Commit e25a3364 authored by Andrew Newdigate's avatar Andrew Newdigate Committed by Jacob Vosmaer

Pass Correlation-Ids down to backend systems

parent 4e0984d6
......@@ -212,7 +212,7 @@ func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error
return authReq, nil
}
// Perform a pre-authorization check against the API for the given HTTP request
// PreAuthorize performs a pre-authorization check against the API for the given HTTP request
//
// If `outErr` is set, the other fields will be nil and it should be treated as
// a 500 error.
......
package correlation
import "bytes"
const base62Chars string = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
// encodeReverseBase62 encodes num into its Base62 reversed representation.
// The most significant value is at the end of the string.
//
// Appending is faster than prepending and this is enough for the purpose of a random ID
func encodeReverseBase62(num int64) string {
if num == 0 {
return "0"
}
encoded := bytes.Buffer{}
for q := num; q > 0; q /= 62 {
encoded.Write([]byte{base62Chars[q%62]})
}
return encoded.String()
}
package log
package correlation
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReverseBase62Conversion(t *testing.T) {
......@@ -22,7 +22,7 @@ func TestReverseBase62Conversion(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("%d_to_%s", test.n, test.expected), func(t *testing.T) {
assert.Equal(t, test.expected, encodeReverseBase62(test.n))
require.Equal(t, test.expected, encodeReverseBase62(test.n))
})
}
}
package correlation
import (
"context"
)
type ctxKey int
const keyCorrelationID ctxKey = iota
// ExtractFromContext extracts the CollectionID from the provided context
// Returns an empty string if it's unable to extract the CorrelationID for
// any reason.
func ExtractFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
id := ctx.Value(keyCorrelationID)
str, ok := id.(string)
if !ok {
return ""
}
return str
}
// ContextWithCorrelation will create a new context containing the Correlation-ID value
func ContextWithCorrelation(ctx context.Context, correlationID string) context.Context {
return context.WithValue(ctx, keyCorrelationID, correlationID)
}
package correlation
import (
"context"
"testing"
)
func TestExtractFromContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
want string
}{
{"nil", nil, ""},
{"missing", context.Background(), ""},
{"set", context.WithValue(context.Background(), keyCorrelationID, "CORRELATION_ID"), "CORRELATION_ID"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ExtractFromContext(tt.ctx); got != tt.want {
t.Errorf("ExtractFromContext() = %v, want %v", got, tt.want)
}
})
}
}
func TestContextWithCorrelation(t *testing.T) {
tests := []struct {
name string
ctx context.Context
correlationID string
wantValue string
}{
{
name: "nil with value",
ctx: nil,
correlationID: "CORRELATION_ID",
wantValue: "CORRELATION_ID",
},
{
name: "nil with empty string",
ctx: nil,
correlationID: "",
wantValue: "",
},
{
name: "value",
ctx: context.Background(),
correlationID: "CORRELATION_ID",
wantValue: "CORRELATION_ID",
},
{
name: "empty",
ctx: context.Background(),
correlationID: "",
wantValue: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ContextWithCorrelation(tt.ctx, tt.correlationID)
gotValue := got.Value(keyCorrelationID)
if gotValue != tt.wantValue {
t.Errorf("ContextWithCorrelation().Value() = %v, want %v", gotValue, tt.wantValue)
}
})
}
}
package correlation
import (
"crypto/rand"
"fmt"
"log"
"math"
"math/big"
"net/http"
"time"
)
var (
randMax = big.NewInt(math.MaxInt64)
randSource = rand.Reader
)
// generateRandomCorrelationID will attempt to generate a correlationid randomly
// or raise an error
func generateRandomCorrelationID() (string, error) {
id, err := rand.Int(randSource, randMax)
if err != nil {
return "", err
}
base62 := encodeReverseBase62(id.Int64())
return base62, nil
}
func generatePseudorandomCorrelationID(req *http.Request) string {
return fmt.Sprintf("E:%s:%s", req.RemoteAddr, encodeReverseBase62(time.Now().UnixNano()))
}
// generateRandomCorrelationID will attempt to generate a correlationid randomly
// if this fails, will log a message and fallback to a pseudorandom approach
func generateRandomCorrelationIDWithFallback(req *http.Request) string {
correlationID, err := generateRandomCorrelationID()
if err == nil {
return correlationID
}
log.Printf("can't generate random correlation-id: %v", err)
return generatePseudorandomCorrelationID(req)
}
package correlation
import (
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func Test_generateRandomCorrelationID(t *testing.T) {
require := require.New(t)
got, err := generateRandomCorrelationID()
require.NoError(err)
require.NotEqual(got, "", "Expected a non-empty string response")
}
func Test_generatePseudorandomCorrelationID(t *testing.T) {
require := require.New(t)
req, err := http.NewRequest("GET", "http://example.com", nil)
require.NoError(err)
got := generatePseudorandomCorrelationID(req)
require.NotEqual(got, "", "Expected a non-empty string response")
require.True(strings.HasPrefix(got, "E:"), "Expected the psuedorandom correlator to have an `E:` prefix")
}
func Test_generateRandomCorrelationIDWithFallback(t *testing.T) {
require := require.New(t)
req, err := http.NewRequest("GET", "http://example.com", nil)
require.NoError(err)
got := generateRandomCorrelationIDWithFallback(req)
require.NotEqual(got, "", "Expected a non-empty string response")
require.False(strings.HasPrefix(got, "E:"), "Not expecting fallback to pseudorandom correlationID")
}
package grpccorrelation
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
)
const metadataCorrelatorKey = "X-GitLab-Correlation-ID"
func injectFromContext(ctx context.Context) context.Context {
correlationID := correlation.ExtractFromContext(ctx)
if correlationID != "" {
ctx = metadata.AppendToOutgoingContext(ctx, metadataCorrelatorKey, correlationID)
}
return ctx
}
// UnaryClientCorrelationInterceptor propagates Correlation-IDs downstream
func UnaryClientCorrelationInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = injectFromContext(ctx)
return invoker(ctx, method, req, reply, cc, opts...)
}
// StreamClientCorrelationInterceptor propagates Correlation-IDs downstream
func StreamClientCorrelationInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
ctx = injectFromContext(ctx)
return streamer(ctx, desc, cc, method, opts...)
}
package correlation
import (
"net/http"
)
// InjectCorrelationID middleware will propagate or create a Correlation-ID for the incoming request
func InjectCorrelationID(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
parent := r.Context()
var correlationID = generateRandomCorrelationIDWithFallback(r)
ctx := ContextWithCorrelation(parent, correlationID)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
package log
package correlation
import (
"bytes"
......@@ -8,7 +8,6 @@ import (
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
......@@ -34,7 +33,7 @@ func TestInjectCorrelationID(t *testing.T) {
invoked = true
ctx := r.Context()
correlationID := ctx.Value(KeyCorrelationID)
correlationID := ExtractFromContext(ctx)
require.NotNil(t, correlationID, "CorrelationID is missing")
require.NotEmpty(t, correlationID, "CorrelationID is missing")
}))
......@@ -42,7 +41,7 @@ func TestInjectCorrelationID(t *testing.T) {
r := httptest.NewRequest("GET", "http://example.com", nil)
h.ServeHTTP(nil, r)
assert.True(t, invoked, "handler not executed")
require.True(t, invoked, "handler not executed")
})
}
}
package correlation
import (
"net/http"
)
const propagationHeader = "X-Request-ID"
// injectRequest will pass the CorrelationId through to a downstream http request
// for propagation
func injectRequest(req *http.Request) {
correlationID := ExtractFromContext(req.Context())
if correlationID != "" {
req.Header.Set(propagationHeader, correlationID)
}
}
type instrumentedRoundTripper struct {
delegate http.RoundTripper
}
func (c instrumentedRoundTripper) RoundTrip(req *http.Request) (res *http.Response, e error) {
injectRequest(req)
return c.delegate.RoundTrip(req)
}
// NewInstrumentedRoundTripper acts as a "client-middleware" for outbound http requests
// adding instrumentation to the outbound request and then delegating to the underlying
// transport
func NewInstrumentedRoundTripper(delegate http.RoundTripper) http.RoundTripper {
return &instrumentedRoundTripper{delegate: delegate}
}
package correlation
import (
"context"
"errors"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
var httpCorrelationTests = []struct {
name string
ctx context.Context
correlationID string
hasHeader bool
}{
{
name: "nil with value",
ctx: nil,
correlationID: "CORRELATION_ID",
hasHeader: true,
},
{
name: "nil without value",
ctx: nil,
correlationID: "",
hasHeader: false,
},
{
name: "context with value",
ctx: context.Background(),
correlationID: "CORRELATION_ID",
hasHeader: true,
},
{
name: "context without value",
ctx: context.Background(),
correlationID: "",
hasHeader: false,
},
}
func Test_injectRequest(t *testing.T) {
for _, tt := range httpCorrelationTests {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
ctx := context.WithValue(tt.ctx, keyCorrelationID, tt.correlationID)
req, err := http.NewRequest("GET", "http://example.com", nil)
require.NoError(err)
req = req.WithContext(ctx)
injectRequest(req)
value := req.Header.Get(propagationHeader)
require.True(tt.hasHeader == (value != ""), "Expected header existence %v. Instead got header %v", tt.hasHeader, value)
if tt.hasHeader {
require.Equal(tt.correlationID, value, "Expected header value %v, got %v", tt.correlationID, value)
}
})
}
}
type delegatedRoundTripper struct {
delegate func(req *http.Request) (*http.Response, error)
}
func (c delegatedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return c.delegate(req)
}
func roundTripperFunc(delegate func(req *http.Request) (*http.Response, error)) http.RoundTripper {
return &delegatedRoundTripper{delegate}
}
func TestInstrumentedRoundTripper(t *testing.T) {
for _, tt := range httpCorrelationTests {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
response := &http.Response{}
mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
value := req.Header.Get(propagationHeader)
require.True(tt.hasHeader == (value != ""), "Expected header existence %v. Instead got header %v", tt.hasHeader, value)
if tt.hasHeader {
require.Equal(tt.correlationID, value, "Expected header value %v, got %v", tt.correlationID, value)
}
return response, nil
})
client := &http.Client{
Transport: NewInstrumentedRoundTripper(mockTransport),
}
ctx := context.WithValue(tt.ctx, keyCorrelationID, tt.correlationID)
req, err := http.NewRequest("GET", "http://example.com", nil)
require.NoError(err)
req = req.WithContext(ctx)
res, err := client.Do(req)
require.NoError(err)
require.Equal(response, res)
})
}
}
func TestInstrumentedRoundTripperFailures(t *testing.T) {
for _, tt := range httpCorrelationTests {
t.Run(tt.name+" - with errors", func(t *testing.T) {
require := require.New(t)
testErr := errors.New("test")
mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
value := req.Header.Get(propagationHeader)
require.True(tt.hasHeader == (value != ""), "Expected header existence %v. Instead got header %v", tt.hasHeader, value)
if tt.hasHeader {
require.Equal(tt.correlationID, value, "Expected header value %v, got %v", tt.correlationID, value)
}
return nil, testErr
})
client := &http.Client{
Transport: NewInstrumentedRoundTripper(mockTransport),
}
ctx := context.WithValue(tt.ctx, keyCorrelationID, tt.correlationID)
req, err := http.NewRequest("GET", "http://example.com", nil)
require.NoError(err)
req = req.WithContext(ctx)
res, err := client.Do(req)
require.Error(err)
require.Nil(res)
})
}
}
func TestInstrumentedRoundTripperWithoutContext(t *testing.T) {
require := require.New(t)
response := &http.Response{}
mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
return response, nil
})
client := &http.Client{
Transport: NewInstrumentedRoundTripper(mockTransport),
}
req, err := http.NewRequest("GET", "http://example.com", nil)
require.NoError(err)
res, err := client.Do(req)
require.NoError(err)
require.Equal(response, res)
}
package raven
import (
"context"
raven "github.com/getsentry/raven-go"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
)
const ravenSentryExtraKey = "gitlab.CorrelationID"
// SetExtra will augment a raven message with the CorrelationID
// An existing `extra` can be passed in, but if it's nil
// a new one will be created
func SetExtra(ctx context.Context, extra raven.Extra) raven.Extra {
if extra == nil {
extra = raven.Extra{}
}
correlationID := correlation.ExtractFromContext(ctx)
if correlationID != "" {
extra[ravenSentryExtraKey] = correlationID
}
return extra
}
package raven
import (
"context"
"reflect"
"testing"
raven "github.com/getsentry/raven-go"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
)
func TestSetExtra(t *testing.T) {
tests := []struct {
name string
ctx context.Context
extra raven.Extra
want raven.Extra
}{
{
name: "trivial",
ctx: nil,
extra: nil,
want: raven.Extra{},
},
{
name: "no_context",
ctx: nil,
extra: map[string]interface{}{
"key": "value",
},
want: map[string]interface{}{
"key": "value",
},
},
{
name: "context",
ctx: correlation.ContextWithCorrelation(context.Background(), "C001"),
extra: map[string]interface{}{
"key": "value",
},
want: map[string]interface{}{
"key": "value",
ravenSentryExtraKey: "C001",
},
},
{
name: "no_injected_extras",
ctx: correlation.ContextWithCorrelation(context.Background(), "C001"),
extra: nil,
want: map[string]interface{}{
ravenSentryExtraKey: "C001",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := SetExtra(tt.ctx, tt.extra); !reflect.DeepEqual(got, tt.want) {
t.Errorf("SetExtra() = %v, want %v", got, tt.want)
}
})
}
}
......@@ -3,11 +3,14 @@ package gitaly
import (
"sync"
"github.com/grpc-ecosystem/go-grpc-middleware"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
"gitlab.com/gitlab-org/gitaly/auth"
gitalyclient "gitlab.com/gitlab-org/gitaly/client"
"google.golang.org/grpc"
grpccorrelation "gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation/grpc"
)
type Server struct {
......@@ -108,8 +111,19 @@ func CloseConnections() {
func newConnection(server Server) (*grpc.ClientConn, error) {
connOpts := append(gitalyclient.DefaultDialOpts,
grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(server.Token)),
grpc.WithStreamInterceptor(grpc_prometheus.StreamClientInterceptor),
grpc.WithUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor),
grpc.WithStreamInterceptor(
grpc_middleware.ChainStreamClient(
grpc_prometheus.StreamClientInterceptor,
grpccorrelation.StreamClientCorrelationInterceptor,
),
),
grpc.WithUnaryInterceptor(
grpc_middleware.ChainUnaryClient(
grpc_prometheus.UnaryClientInterceptor,
grpccorrelation.UnaryClientCorrelationInterceptor,
),
),
)
return gitalyclient.Dial(server.Address, connOpts)
......
......@@ -5,6 +5,8 @@ import (
"reflect"
"github.com/getsentry/raven-go"
correlation "gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation/raven"
)
var ravenHeaderBlacklist = []string{
......@@ -14,11 +16,14 @@ var ravenHeaderBlacklist = []string{
func captureRavenError(r *http.Request, err error) {
client := raven.DefaultClient
extra := raven.Extra{}
interfaces := []raven.Interface{}
if r != nil {
CleanHeadersForRaven(r)
interfaces = append(interfaces, raven.NewHttp(r))
extra = correlation.SetExtra(r.Context(), extra)
}
exception := &raven.Exception{
......@@ -28,7 +33,7 @@ func captureRavenError(r *http.Request, err error) {
}
interfaces = append(interfaces, exception)
packet := raven.NewPacket(err.Error(), interfaces...)
packet := raven.NewPacketWithExtra(err.Error(), extra, interfaces...)
client.Capture(packet, nil)
}
......
package log
import (
"bytes"
"context"
"crypto/rand"
"fmt"
"math"
"math/big"
"net/http"
"time"
)
type ctxKey string
const (
// KeyCorrelationID const is the context key for Correlation ID
KeyCorrelationID ctxKey = "X-Correlation-ID"
base62Chars string = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
var (
randMax = big.NewInt(math.MaxInt64)
randSource = rand.Reader
)
func InjectCorrelationID(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
parent := r.Context()
correlationID, err := generateRandomCorrelationID()
if err != nil {
correlationID = fmt.Sprintf("E:%s:%s", r.RemoteAddr, encodeReverseBase62(time.Now().UnixNano()))
NoContext().WithError(err).Warning("Can't generate random correlation-id")
}
ctx := context.WithValue(parent, KeyCorrelationID, correlationID)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
func generateRandomCorrelationID() (string, error) {
id, err := rand.Int(randSource, randMax)
if err != nil {
return "", err
}
base62 := encodeReverseBase62(id.Int64())
return base62, nil
}
// encodeReverseBase62 encodes num into its Base62 reversed representation.
// The most significant value is at the end of the string.
//
// Appending is faster than prepending and this is enough for the purpose of a random ID
func encodeReverseBase62(num int64) string {
if num == 0 {
return "0"
}
encoded := bytes.Buffer{}
for q := num; q > 0; q /= 62 {
encoded.Write([]byte{base62Chars[q%62]})
}
return encoded.String()
}
......@@ -4,6 +4,8 @@ import (
"context"
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
)
// Fields type, an helper to avoid importing logrus.Fields
......@@ -19,19 +21,11 @@ func toLogrusFields(f Fields) logrus.Fields {
}
func getCorrelationID(ctx context.Context) string {
noID := "[MISSING]"
if ctx == nil {
return noID
correlationID := correlation.ExtractFromContext(ctx)
if correlationID == "" {
return "[MISSING]"
}
id := ctx.Value(KeyCorrelationID)
str, ok := id.(string)
if !ok {
return noID
}
return str
return correlationID
}
// WithContext provides a *logrus.Entry with the proper "correlation-id" field.
......
......@@ -8,13 +8,13 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/log"
)
func requireCorrelationID(t *testing.T, getEntry func(ctx context.Context) *logrus.Entry) *logrus.Entry {
id := "test-id"
ctx := context.WithValue(context.Background(), log.KeyCorrelationID, id)
ctx := correlation.ContextWithCorrelation(context.Background(), id)
e := getEntry(ctx)
require.NotNil(t, e)
......
......@@ -9,6 +9,7 @@ import (
"net/http"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)
......@@ -16,7 +17,7 @@ import (
// that are more restrictive than for http.DefaultTransport,
// they define shorter TLS Handshake, and more agressive connection closing
// to prevent the connection hanging and reduce FD usage
var httpTransport = &http.Transport{
var httpTransport = correlation.NewInstrumentedRoundTripper(&http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
......@@ -27,7 +28,7 @@ var httpTransport = &http.Transport{
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
}
})
var httpClient = &http.Client{
Transport: httpTransport,
......
......@@ -79,6 +79,7 @@ func (u *uploader) syncAndDelete(url string) {
log.WithError(err).WithField("object", helper.ScrubURLParams(url)).Warning("Delete failed")
return
}
// TODO: consider adding the context to the outgoing request for better instrumentation
// here we are not using u.ctx because we must perform cleanup regardless of parent context
resp, err := httpClient.Do(req)
......
......@@ -9,6 +9,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/log"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
......@@ -36,7 +37,7 @@ var rangeHeaderKeys = []string{
// that are more restrictive than for http.DefaultTransport,
// they define shorter TLS Handshake, and more agressive connection closing
// to prevent the connection hanging and reduce FD usage
var httpTransport = &http.Transport{
var httpTransport = correlation.NewInstrumentedRoundTripper(&http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
......@@ -47,7 +48,7 @@ var httpTransport = &http.Transport{
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
}
})
var httpClient = &http.Client{
Transport: httpTransport,
......@@ -115,6 +116,7 @@ func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string)
helper.Fail500(w, r, fmt.Errorf("SendURL: NewRequest: %v", err))
return
}
newReq = newReq.WithContext(r.Context())
for _, header := range rangeHeaderKeys {
newReq.Header[header] = r.Header[header]
......
......@@ -9,6 +9,7 @@ import (
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
)
func mustParseAddress(address, scheme string) string {
......@@ -45,7 +46,7 @@ func NewBackendRoundTripper(backend *url.URL, socket string, proxyHeadersTimeout
panic("backend is nil and socket is empty")
}
return badgateway.NewRoundTripper(developmentMode, transport)
return correlation.NewInstrumentedRoundTripper(badgateway.NewRoundTripper(developmentMode, transport))
}
// NewTestBackendRoundTripper sets up a RoundTripper for testing purposes
......
......@@ -34,6 +34,11 @@ type routeEntry struct {
matchers []matcherFunc
}
type routeOptions struct {
tracing bool
matchers []matcherFunc
}
const (
apiPattern = `^/api/`
ciAPIPattern = `^/ci/api/`
......@@ -49,12 +54,36 @@ func compileRegexp(regexpStr string) *regexp.Regexp {
return regexp.MustCompile(regexpStr)
}
func route(method, regexpStr string, handler http.Handler, matchers ...matcherFunc) routeEntry {
func withMatcher(f matcherFunc) func(*routeOptions) {
return func(options *routeOptions) {
options.matchers = append(options.matchers, f)
}
}
func withoutTracing() func(*routeOptions) {
return func(options *routeOptions) {
options.tracing = false
}
}
func route(method, regexpStr string, handler http.Handler, opts ...func(*routeOptions)) routeEntry {
// Instantiate a route with the defaults
options := routeOptions{
tracing: true,
}
for _, f := range opts {
f(&options)
}
handler = denyWebsocket(handler) // Disallow websockets
handler = instrumentRoute(handler, method, regexpStr) // Add prometheus metrics
return routeEntry{
method: method,
regex: compileRegexp(regexpStr),
handler: instrumentRoute(denyWebsocket(handler), method, regexpStr),
matchers: matchers,
handler: handler,
matchers: options.matchers,
}
}
......@@ -130,9 +159,9 @@ func (u *upstream) configureRoutes() {
u.Routes = []routeEntry{
// Git Clone
route("GET", gitProjectPattern+`info/refs\z`, git.GetInfoRefsHandler(api)),
route("POST", gitProjectPattern+`git-upload-pack\z`, contentEncodingHandler(git.UploadPack(api)), isContentType("application/x-git-upload-pack-request")),
route("POST", gitProjectPattern+`git-receive-pack\z`, contentEncodingHandler(git.ReceivePack(api)), isContentType("application/x-git-receive-pack-request")),
route("PUT", gitProjectPattern+`gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`, lfs.PutStore(api, proxy), isContentType("application/octet-stream")),
route("POST", gitProjectPattern+`git-upload-pack\z`, contentEncodingHandler(git.UploadPack(api)), withMatcher(isContentType("application/x-git-upload-pack-request"))),
route("POST", gitProjectPattern+`git-receive-pack\z`, contentEncodingHandler(git.ReceivePack(api)), withMatcher(isContentType("application/x-git-receive-pack-request"))),
route("PUT", gitProjectPattern+`gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`, lfs.PutStore(api, proxy), withMatcher(isContentType("application/octet-stream"))),
// CI Artifacts
route("POST", apiPattern+`v4/jobs/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, proxy))),
......@@ -161,6 +190,7 @@ func (u *upstream) configureRoutes() {
staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode, proxy),
),
withoutTracing(), // Tracing on assets is very noisy
),
// Uploads
......
......@@ -13,6 +13,7 @@ import (
"github.com/jfbus/httprs"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)
......@@ -23,7 +24,7 @@ var ErrNotAZip = errors.New("not a zip")
var ErrArchiveNotFound = errors.New("archive not found")
var httpClient = &http.Client{
Transport: &http.Transport{
Transport: correlation.NewInstrumentedRoundTripper(&http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
......@@ -33,7 +34,7 @@ var httpClient = &http.Client{
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
},
}),
}
// OpenArchive will open a zip.Reader from a local path or a remote object store URL
......@@ -58,6 +59,7 @@ func openHTTPArchive(ctx context.Context, archivePath string) (*zip.Reader, erro
if err != nil {
return nil, fmt.Errorf("Can't create HTTP GET %q: %v", scrubbedArchivePath, err)
}
req = req.WithContext(ctx)
resp, err := httpClient.Do(req.WithContext(ctx))
if err != nil {
......
......@@ -26,6 +26,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/correlation"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/log"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/redis"
......@@ -147,7 +148,7 @@ func main() {
}
}
up := wrapRaven(log.InjectCorrelationID(upstream.NewUpstream(cfg)))
up := wrapRaven(correlation.InjectCorrelationID(upstream.NewUpstream(cfg)))
logger.Fatal(http.Serve(listener, up))
}
# grpc_ctxtags
`import "github.com/grpc-ecosystem/go-grpc-middleware/tags"`
* [Overview](#pkg-overview)
* [Imported Packages](#pkg-imports)
* [Index](#pkg-index)
* [Examples](#pkg-examples)
## <a name="pkg-overview">Overview</a>
`grpc_ctxtags` adds a Tag object to the context that can be used by other middleware to add context about a request.
### Request Context Tags
Tags describe information about the request, and can be set and used by other middleware, or handlers. Tags are used
for logging and tracing of requests. Tags are populated both upwards, *and* downwards in the interceptor-handler stack.
You can automatically extract tags (in `grpc.request.<field_name>`) from request payloads.
For unary and server-streaming methods, pass in the `WithFieldExtractor` option. For client-streams and bidirectional-streams, you can
use `WithFieldExtractorForInitialReq` which will extract the tags from the first message passed from client to server.
Note the tags will not be modified for subsequent requests, so this option only makes sense when the initial message
establishes the meta-data for the stream.
If a user doesn't use the interceptors that initialize the `Tags` object, all operations following from an `Extract(ctx)`
will be no-ops. This is to ensure that code doesn't panic if the interceptors weren't used.
Tags fields are typed, and shallow and should follow the OpenTracing semantics convention:
<a href="https://github.com/opentracing/specification/blob/master/semantic_conventions.md">https://github.com/opentracing/specification/blob/master/semantic_conventions.md</a>
#### Example:
<details>
<summary>Click to expand code.</summary>
```go
opts := []grpc_ctxtags.Option{
grpc_ctxtags.WithFieldExtractorForInitialReq(grpc_ctxtags.TagBasedRequestFieldExtractor("log_fields")),
}
_ = grpc.NewServer(
grpc.StreamInterceptor(grpc_ctxtags.StreamServerInterceptor(opts...)),
grpc.UnaryInterceptor(grpc_ctxtags.UnaryServerInterceptor(opts...)),
)
```
</details>
#### Example:
<details>
<summary>Click to expand code.</summary>
```go
opts := []grpc_ctxtags.Option{
grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.TagBasedRequestFieldExtractor("log_fields")),
}
_ = grpc.NewServer(
grpc.StreamInterceptor(grpc_ctxtags.StreamServerInterceptor(opts...)),
grpc.UnaryInterceptor(grpc_ctxtags.UnaryServerInterceptor(opts...)),
)
```
</details>
## <a name="pkg-imports">Imported Packages</a>
- [github.com/grpc-ecosystem/go-grpc-middleware](./..)
- [golang.org/x/net/context](https://godoc.org/golang.org/x/net/context)
- [google.golang.org/grpc](https://godoc.org/google.golang.org/grpc)
- [google.golang.org/grpc/peer](https://godoc.org/google.golang.org/grpc/peer)
## <a name="pkg-index">Index</a>
* [Variables](#pkg-variables)
* [func CodeGenRequestFieldExtractor(fullMethod string, req interface{}) map[string]interface{}](#CodeGenRequestFieldExtractor)
* [func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor](#StreamServerInterceptor)
* [func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor](#UnaryServerInterceptor)
* [type Option](#Option)
* [func WithFieldExtractor(f RequestFieldExtractorFunc) Option](#WithFieldExtractor)
* [func WithFieldExtractorForInitialReq(f RequestFieldExtractorFunc) Option](#WithFieldExtractorForInitialReq)
* [type RequestFieldExtractorFunc](#RequestFieldExtractorFunc)
* [func TagBasedRequestFieldExtractor(tagName string) RequestFieldExtractorFunc](#TagBasedRequestFieldExtractor)
* [type Tags](#Tags)
* [func Extract(ctx context.Context) Tags](#Extract)
#### <a name="pkg-examples">Examples</a>
* [Package (InitialisationWithOptions)](#example__initialisationWithOptions)
* [Package (Initialization)](#example__initialization)
#### <a name="pkg-files">Package files</a>
[context.go](./context.go) [doc.go](./doc.go) [fieldextractor.go](./fieldextractor.go) [interceptors.go](./interceptors.go) [options.go](./options.go)
## <a name="pkg-variables">Variables</a>
``` go
var (
// NoopTags is a trivial, minimum overhead implementation of Tags for which all operations are no-ops.
NoopTags = &noopTags{}
)
```
## <a name="CodeGenRequestFieldExtractor">func</a> [CodeGenRequestFieldExtractor](./fieldextractor.go#L23)
``` go
func CodeGenRequestFieldExtractor(fullMethod string, req interface{}) map[string]interface{}
```
CodeGenRequestFieldExtractor is a function that relies on code-generated functions that export log fields from requests.
These are usually coming from a protoc-plugin that generates additional information based on custom field options.
## <a name="StreamServerInterceptor">func</a> [StreamServerInterceptor](./interceptors.go#L26)
``` go
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor
```
StreamServerInterceptor returns a new streaming server interceptor that sets the values for request tags.
## <a name="UnaryServerInterceptor">func</a> [UnaryServerInterceptor](./interceptors.go#L14)
``` go
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor
```
UnaryServerInterceptor returns a new unary server interceptors that sets the values for request tags.
## <a name="Option">type</a> [Option](./options.go#L26)
``` go
type Option func(*options)
```
### <a name="WithFieldExtractor">func</a> [WithFieldExtractor](./options.go#L30)
``` go
func WithFieldExtractor(f RequestFieldExtractorFunc) Option
```
WithFieldExtractor customizes the function for extracting log fields from protobuf messages, for
unary and server-streamed methods only.
### <a name="WithFieldExtractorForInitialReq">func</a> [WithFieldExtractorForInitialReq](./options.go#L39)
``` go
func WithFieldExtractorForInitialReq(f RequestFieldExtractorFunc) Option
```
WithFieldExtractorForInitialReq customizes the function for extracting log fields from protobuf messages,
for all unary and streaming methods. For client-streams and bidirectional-streams, the tags will be
extracted from the first message from the client.
## <a name="RequestFieldExtractorFunc">type</a> [RequestFieldExtractorFunc](./fieldextractor.go#L13)
``` go
type RequestFieldExtractorFunc func(fullMethod string, req interface{}) map[string]interface{}
```
RequestFieldExtractorFunc is a user-provided function that extracts field information from a gRPC request.
It is called from tags middleware on arrival of unary request or a server-stream request.
Keys and values will be added to the context tags of the request. If there are no fields, you should return a nil.
### <a name="TagBasedRequestFieldExtractor">func</a> [TagBasedRequestFieldExtractor](./fieldextractor.go#L43)
``` go
func TagBasedRequestFieldExtractor(tagName string) RequestFieldExtractorFunc
```
TagBasedRequestFieldExtractor is a function that relies on Go struct tags to export log fields from requests.
These are usually coming from a protoc-plugin, such as Gogo protobuf.
message Metadata {
repeated string tags = 1 [ (gogoproto.moretags) = "log_field:\"meta_tags\"" ];
}
The tagName is configurable using the tagName variable. Here it would be "log_field".
## <a name="Tags">type</a> [Tags](./context.go#L19-L27)
``` go
type Tags interface {
// Set sets the given key in the metadata tags.
Set(key string, value interface{}) Tags
// Has checks if the given key exists.
Has(key string) bool
// Values returns a map of key to values.
// Do not modify the underlying map, please use Set instead.
Values() map[string]interface{}
}
```
Tags is the interface used for storing request tags between Context calls.
The default implementation is *not* thread safe, and should be handled only in the context of the request.
### <a name="Extract">func</a> [Extract](./context.go#L63)
``` go
func Extract(ctx context.Context) Tags
```
Extracts returns a pre-existing Tags object in the Context.
If the context wasn't set in a tag interceptor, a no-op Tag storage is returned that will *not* be propagated in context.
- - -
Generated by [godoc2ghmd](https://github.com/GandalfUK/godoc2ghmd)
\ No newline at end of file
# grpc_ctxtags
`import "github.com/grpc-ecosystem/go-grpc-middleware/tags"`
* [Overview](#pkg-overview)
* [Imported Packages](#pkg-imports)
* [Index](#pkg-index)
* [Examples](#pkg-examples)
## <a name="pkg-overview">Overview</a>
`grpc_ctxtags` adds a Tag object to the context that can be used by other middleware to add context about a request.
### Request Context Tags
Tags describe information about the request, and can be set and used by other middleware, or handlers. Tags are used
for logging and tracing of requests. Tags are populated both upwards, *and* downwards in the interceptor-handler stack.
You can automatically extract tags (in `grpc.request.<field_name>`) from request payloads.
For unary and server-streaming methods, pass in the `WithFieldExtractor` option. For client-streams and bidirectional-streams, you can
use `WithFieldExtractorForInitialReq` which will extract the tags from the first message passed from client to server.
Note the tags will not be modified for subsequent requests, so this option only makes sense when the initial message
establishes the meta-data for the stream.
If a user doesn't use the interceptors that initialize the `Tags` object, all operations following from an `Extract(ctx)`
will be no-ops. This is to ensure that code doesn't panic if the interceptors weren't used.
Tags fields are typed, and shallow and should follow the OpenTracing semantics convention:
<a href="https://github.com/opentracing/specification/blob/master/semantic_conventions.md">https://github.com/opentracing/specification/blob/master/semantic_conventions.md</a>
#### Example:
<details>
<summary>Click to expand code.</summary>
```go
opts := []grpc_ctxtags.Option{
grpc_ctxtags.WithFieldExtractorForInitialReq(grpc_ctxtags.TagBasedRequestFieldExtractor("log_fields")),
}
_ = grpc.NewServer(
grpc.StreamInterceptor(grpc_ctxtags.StreamServerInterceptor(opts...)),
grpc.UnaryInterceptor(grpc_ctxtags.UnaryServerInterceptor(opts...)),
)
```
</details>
#### Example:
<details>
<summary>Click to expand code.</summary>
```go
opts := []grpc_ctxtags.Option{
grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.TagBasedRequestFieldExtractor("log_fields")),
}
_ = grpc.NewServer(
grpc.StreamInterceptor(grpc_ctxtags.StreamServerInterceptor(opts...)),
grpc.UnaryInterceptor(grpc_ctxtags.UnaryServerInterceptor(opts...)),
)
```
</details>
## <a name="pkg-imports">Imported Packages</a>
- [github.com/grpc-ecosystem/go-grpc-middleware](./..)
- [golang.org/x/net/context](https://godoc.org/golang.org/x/net/context)
- [google.golang.org/grpc](https://godoc.org/google.golang.org/grpc)
- [google.golang.org/grpc/peer](https://godoc.org/google.golang.org/grpc/peer)
## <a name="pkg-index">Index</a>
* [Variables](#pkg-variables)
* [func CodeGenRequestFieldExtractor(fullMethod string, req interface{}) map[string]interface{}](#CodeGenRequestFieldExtractor)
* [func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor](#StreamServerInterceptor)
* [func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor](#UnaryServerInterceptor)
* [type Option](#Option)
* [func WithFieldExtractor(f RequestFieldExtractorFunc) Option](#WithFieldExtractor)
* [func WithFieldExtractorForInitialReq(f RequestFieldExtractorFunc) Option](#WithFieldExtractorForInitialReq)
* [type RequestFieldExtractorFunc](#RequestFieldExtractorFunc)
* [func TagBasedRequestFieldExtractor(tagName string) RequestFieldExtractorFunc](#TagBasedRequestFieldExtractor)
* [type Tags](#Tags)
* [func Extract(ctx context.Context) Tags](#Extract)
#### <a name="pkg-examples">Examples</a>
* [Package (InitialisationWithOptions)](#example__initialisationWithOptions)
* [Package (Initialization)](#example__initialization)
#### <a name="pkg-files">Package files</a>
[context.go](./context.go) [doc.go](./doc.go) [fieldextractor.go](./fieldextractor.go) [interceptors.go](./interceptors.go) [options.go](./options.go)
## <a name="pkg-variables">Variables</a>
``` go
var (
// NoopTags is a trivial, minimum overhead implementation of Tags for which all operations are no-ops.
NoopTags = &noopTags{}
)
```
## <a name="CodeGenRequestFieldExtractor">func</a> [CodeGenRequestFieldExtractor](./fieldextractor.go#L23)
``` go
func CodeGenRequestFieldExtractor(fullMethod string, req interface{}) map[string]interface{}
```
CodeGenRequestFieldExtractor is a function that relies on code-generated functions that export log fields from requests.
These are usually coming from a protoc-plugin that generates additional information based on custom field options.
## <a name="StreamServerInterceptor">func</a> [StreamServerInterceptor](./interceptors.go#L26)
``` go
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor
```
StreamServerInterceptor returns a new streaming server interceptor that sets the values for request tags.
## <a name="UnaryServerInterceptor">func</a> [UnaryServerInterceptor](./interceptors.go#L14)
``` go
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor
```
UnaryServerInterceptor returns a new unary server interceptors that sets the values for request tags.
## <a name="Option">type</a> [Option](./options.go#L26)
``` go
type Option func(*options)
```
### <a name="WithFieldExtractor">func</a> [WithFieldExtractor](./options.go#L30)
``` go
func WithFieldExtractor(f RequestFieldExtractorFunc) Option
```
WithFieldExtractor customizes the function for extracting log fields from protobuf messages, for
unary and server-streamed methods only.
### <a name="WithFieldExtractorForInitialReq">func</a> [WithFieldExtractorForInitialReq](./options.go#L39)
``` go
func WithFieldExtractorForInitialReq(f RequestFieldExtractorFunc) Option
```
WithFieldExtractorForInitialReq customizes the function for extracting log fields from protobuf messages,
for all unary and streaming methods. For client-streams and bidirectional-streams, the tags will be
extracted from the first message from the client.
## <a name="RequestFieldExtractorFunc">type</a> [RequestFieldExtractorFunc](./fieldextractor.go#L13)
``` go
type RequestFieldExtractorFunc func(fullMethod string, req interface{}) map[string]interface{}
```
RequestFieldExtractorFunc is a user-provided function that extracts field information from a gRPC request.
It is called from tags middleware on arrival of unary request or a server-stream request.
Keys and values will be added to the context tags of the request. If there are no fields, you should return a nil.
### <a name="TagBasedRequestFieldExtractor">func</a> [TagBasedRequestFieldExtractor](./fieldextractor.go#L43)
``` go
func TagBasedRequestFieldExtractor(tagName string) RequestFieldExtractorFunc
```
TagBasedRequestFieldExtractor is a function that relies on Go struct tags to export log fields from requests.
These are usually coming from a protoc-plugin, such as Gogo protobuf.
message Metadata {
repeated string tags = 1 [ (gogoproto.moretags) = "log_field:\"meta_tags\"" ];
}
The tagName is configurable using the tagName variable. Here it would be "log_field".
## <a name="Tags">type</a> [Tags](./context.go#L19-L27)
``` go
type Tags interface {
// Set sets the given key in the metadata tags.
Set(key string, value interface{}) Tags
// Has checks if the given key exists.
Has(key string) bool
// Values returns a map of key to values.
// Do not modify the underlying map, please use Set instead.
Values() map[string]interface{}
}
```
Tags is the interface used for storing request tags between Context calls.
The default implementation is *not* thread safe, and should be handled only in the context of the request.
### <a name="Extract">func</a> [Extract](./context.go#L63)
``` go
func Extract(ctx context.Context) Tags
```
Extracts returns a pre-existing Tags object in the Context.
If the context wasn't set in a tag interceptor, a no-op Tag storage is returned that will *not* be propagated in context.
- - -
Generated by [godoc2ghmd](https://github.com/GandalfUK/godoc2ghmd)
\ No newline at end of file
package grpc_ctxtags
import (
"context"
)
type ctxMarker struct{}
var (
// ctxMarkerKey is the Context value marker used by *all* logging middleware.
// The logging middleware object must interf
ctxMarkerKey = &ctxMarker{}
// NoopTags is a trivial, minimum overhead implementation of Tags for which all operations are no-ops.
NoopTags = &noopTags{}
)
// Tags is the interface used for storing request tags between Context calls.
// The default implementation is *not* thread safe, and should be handled only in the context of the request.
type Tags interface {
// Set sets the given key in the metadata tags.
Set(key string, value interface{}) Tags
// Has checks if the given key exists.
Has(key string) bool
// Values returns a map of key to values.
// Do not modify the underlying map, please use Set instead.
Values() map[string]interface{}
}
type mapTags struct {
values map[string]interface{}
}
func (t *mapTags) Set(key string, value interface{}) Tags {
t.values[key] = value
return t
}
func (t *mapTags) Has(key string) bool {
_, ok := t.values[key]
return ok
}
func (t *mapTags) Values() map[string]interface{} {
return t.values
}
type noopTags struct{}
func (t *noopTags) Set(key string, value interface{}) Tags {
return t
}
func (t *noopTags) Has(key string) bool {
return false
}
func (t *noopTags) Values() map[string]interface{} {
return nil
}
// Extracts returns a pre-existing Tags object in the Context.
// If the context wasn't set in a tag interceptor, a no-op Tag storage is returned that will *not* be propagated in context.
func Extract(ctx context.Context) Tags {
t, ok := ctx.Value(ctxMarkerKey).(Tags)
if !ok {
return NoopTags
}
return t
}
func setInContext(ctx context.Context, tags Tags) context.Context {
return context.WithValue(ctx, ctxMarkerKey, tags)
}
func newTags() Tags {
return &mapTags{values: make(map[string]interface{})}
}
/*
`grpc_ctxtags` adds a Tag object to the context that can be used by other middleware to add context about a request.
Request Context Tags
Tags describe information about the request, and can be set and used by other middleware, or handlers. Tags are used
for logging and tracing of requests. Tags are populated both upwards, *and* downwards in the interceptor-handler stack.
You can automatically extract tags (in `grpc.request.<field_name>`) from request payloads.
For unary and server-streaming methods, pass in the `WithFieldExtractor` option. For client-streams and bidirectional-streams, you can
use `WithFieldExtractorForInitialReq` which will extract the tags from the first message passed from client to server.
Note the tags will not be modified for subsequent requests, so this option only makes sense when the initial message
establishes the meta-data for the stream.
If a user doesn't use the interceptors that initialize the `Tags` object, all operations following from an `Extract(ctx)`
will be no-ops. This is to ensure that code doesn't panic if the interceptors weren't used.
Tags fields are typed, and shallow and should follow the OpenTracing semantics convention:
https://github.com/opentracing/specification/blob/master/semantic_conventions.md
*/
package grpc_ctxtags
// Copyright 2017 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_ctxtags
import (
"reflect"
)
// RequestFieldExtractorFunc is a user-provided function that extracts field information from a gRPC request.
// It is called from tags middleware on arrival of unary request or a server-stream request.
// Keys and values will be added to the context tags of the request. If there are no fields, you should return a nil.
type RequestFieldExtractorFunc func(fullMethod string, req interface{}) map[string]interface{}
type requestFieldsExtractor interface {
// ExtractRequestFields is a method declared on a Protobuf message that extracts fields from the interface.
// The values from the extracted fields should be set in the appendToMap, in order to avoid allocations.
ExtractRequestFields(appendToMap map[string]interface{})
}
// CodeGenRequestFieldExtractor is a function that relies on code-generated functions that export log fields from requests.
// These are usually coming from a protoc-plugin that generates additional information based on custom field options.
func CodeGenRequestFieldExtractor(fullMethod string, req interface{}) map[string]interface{} {
if ext, ok := req.(requestFieldsExtractor); ok {
retMap := make(map[string]interface{})
ext.ExtractRequestFields(retMap)
if len(retMap) == 0 {
return nil
}
return retMap
}
return nil
}
// TagBasedRequestFieldExtractor is a function that relies on Go struct tags to export log fields from requests.
// These are usually coming from a protoc-plugin, such as Gogo protobuf.
//
// message Metadata {
// repeated string tags = 1 [ (gogoproto.moretags) = "log_field:\"meta_tags\"" ];
// }
//
// The tagName is configurable using the tagName variable. Here it would be "log_field".
func TagBasedRequestFieldExtractor(tagName string) RequestFieldExtractorFunc {
return func(fullMethod string, req interface{}) map[string]interface{} {
retMap := make(map[string]interface{})
reflectMessageTags(req, retMap, tagName)
if len(retMap) == 0 {
return nil
}
return retMap
}
}
func reflectMessageTags(msg interface{}, existingMap map[string]interface{}, tagName string) {
v := reflect.ValueOf(msg)
// Only deal with pointers to structs.
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
return
}
// Deref the pointer get to the struct.
v = v.Elem()
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
kind := field.Kind()
// Only recurse down direct pointers, which should only be to nested structs.
if kind == reflect.Ptr {
reflectMessageTags(field.Interface(), existingMap, tagName)
}
// In case of arrays/splices (repeated fields) go down to the concrete type.
if kind == reflect.Array || kind == reflect.Slice {
if field.Len() == 0 {
continue
}
kind = field.Index(0).Kind()
}
// Only be interested in
if (kind >= reflect.Bool && kind <= reflect.Float64) || kind == reflect.String {
if tag := t.Field(i).Tag.Get(tagName); tag != "" {
existingMap[tag] = field.Interface()
}
}
}
return
}
// Copyright 2017 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_ctxtags
import (
"github.com/grpc-ecosystem/go-grpc-middleware"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)
// UnaryServerInterceptor returns a new unary server interceptors that sets the values for request tags.
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateOptions(opts)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
newCtx := newTagsForCtx(ctx)
if o.requestFieldsFunc != nil {
setRequestFieldTags(newCtx, o.requestFieldsFunc, info.FullMethod, req)
}
return handler(newCtx, req)
}
}
// StreamServerInterceptor returns a new streaming server interceptor that sets the values for request tags.
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
o := evaluateOptions(opts)
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
newCtx := newTagsForCtx(stream.Context())
if o.requestFieldsFunc == nil {
// Short-circuit, don't do the expensive bit of allocating a wrappedStream.
wrappedStream := grpc_middleware.WrapServerStream(stream)
wrappedStream.WrappedContext = newCtx
return handler(srv, wrappedStream)
}
wrapped := &wrappedStream{stream, info, o, newCtx, true}
err := handler(srv, wrapped)
return err
}
}
// wrappedStream is a thin wrapper around grpc.ServerStream that allows modifying context and extracts log fields from the initial message.
type wrappedStream struct {
grpc.ServerStream
info *grpc.StreamServerInfo
opts *options
// WrappedContext is the wrapper's own Context. You can assign it.
WrappedContext context.Context
initial bool
}
// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
func (w *wrappedStream) Context() context.Context {
return w.WrappedContext
}
func (w *wrappedStream) RecvMsg(m interface{}) error {
err := w.ServerStream.RecvMsg(m)
// We only do log fields extraction on the single-request of a server-side stream.
if !w.info.IsClientStream || w.opts.requestFieldsFromInitial && w.initial {
w.initial = false
setRequestFieldTags(w.Context(), w.opts.requestFieldsFunc, w.info.FullMethod, m)
}
return err
}
func newTagsForCtx(ctx context.Context) context.Context {
t := newTags()
if peer, ok := peer.FromContext(ctx); ok {
t.Set("peer.address", peer.Addr.String())
}
return setInContext(ctx, t)
}
func setRequestFieldTags(ctx context.Context, f RequestFieldExtractorFunc, fullMethodName string, req interface{}) {
if valMap := f(fullMethodName, req); valMap != nil {
t := Extract(ctx)
for k, v := range valMap {
t.Set("grpc.request."+k, v)
}
}
}
// Copyright 2017 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_ctxtags
var (
defaultOptions = &options{
requestFieldsFunc: nil,
}
)
type options struct {
requestFieldsFunc RequestFieldExtractorFunc
requestFieldsFromInitial bool
}
func evaluateOptions(opts []Option) *options {
optCopy := &options{}
*optCopy = *defaultOptions
for _, o := range opts {
o(optCopy)
}
return optCopy
}
type Option func(*options)
// WithFieldExtractor customizes the function for extracting log fields from protobuf messages, for
// unary and server-streamed methods only.
func WithFieldExtractor(f RequestFieldExtractorFunc) Option {
return func(o *options) {
o.requestFieldsFunc = f
}
}
// WithFieldExtractorForInitialReq customizes the function for extracting log fields from protobuf messages,
// for all unary and streaming methods. For client-streams and bidirectional-streams, the tags will be
// extracted from the first message from the client.
func WithFieldExtractorForInitialReq(f RequestFieldExtractorFunc) Option {
return func(o *options) {
o.requestFieldsFunc = f
o.requestFieldsFromInitial = true
}
}
......@@ -165,6 +165,14 @@
"version": "v1",
"versionExact": "v1.0.0"
},
{
"checksumSHA1": "Rf3QgJeAX2809t/DZvMjZbGHe9U=",
"path": "github.com/grpc-ecosystem/go-grpc-middleware/tags",
"revision": "c250d6563d4d4c20252cd865923440e829844f4e",
"revisionTime": "2018-05-02T09:16:42Z",
"version": "v1",
"versionExact": "v1.0.0"
},
{
"checksumSHA1": "L5z1C445GhhQmDKSisTFv754LdU=",
"path": "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment