Commit a58ceca7 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'refactor-upstream' into 'master'

Refactor upstream

Rationale: the code has become a tangled mess of global variables and
types that hang together when they need not. For example: every HTTP
handler uses a 'gitRequest'?? I want to clean this up and see if I can
move some things into internal packages.

Apart from using internal packages we now use http.Handler where we can,
and fewer global variables.

See merge request !20
parents 64270207 9e7b612b
test/data
test/scratch
gitlab-workhorse
test/public
testdata/data
testdata/scratch
testdata/public
PREFIX=/usr/local
VERSION=$(shell git describe)-$(shell date -u +%Y%m%d.%H%M%S)
gitlab-workhorse: $(wildcard *.go)
gitlab-workhorse: $(shell find . -name '*.go')
go build -ldflags "-X main.Version=${VERSION}" -o gitlab-workhorse
install: gitlab-workhorse
install gitlab-workhorse ${PREFIX}/bin/
.PHONY: test
test: test/data/group/test.git clean-workhorse gitlab-workhorse
go fmt | awk '{ print "Please run go fmt"; exit 1 }'
go test
test: testdata/data/group/test.git clean-workhorse gitlab-workhorse
go fmt ./... | awk '{ print } END { if (NR > 0) { print "Please run go fmt"; exit 1 } }'
go test ./...
@echo SUCCESS
coverage: test/data/group/test.git
coverage: testdata/data/group/test.git
go test -cover -coverprofile=test.coverage
go tool cover -html=test.coverage -o coverage.html
rm -f test.coverage
test/data/group/test.git: test/data
git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git test/data/group/test.git
testdata/data/group/test.git: testdata/data
git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git $@
test/data:
mkdir -p test/data
testdata/data:
mkdir -p $@
.PHONY: clean
clean: clean-workhorse
rm -rf test/data test/scratch
rm -rf testdata/data testdata/scratch
.PHONY: clean-workhorse
clean-workhorse:
......
package main
func artifactsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(handleFunc, "/authorize")
}
package main
import (
"./internal/api"
"./internal/helper"
"fmt"
"net/http"
"net/http/httptest"
......@@ -8,14 +10,14 @@ import (
"testing"
)
func okHandler(w http.ResponseWriter, r *gitRequest) {
func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
w.WriteHeader(201)
fmt.Fprint(w, "{\"status\":\"ok\"}")
}
func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, authorizationResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder {
func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, apiResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder {
// Prepare test server and backend
ts := testAuthServer(url, returnCode, authorizationResponse)
ts := testAuthServer(url, returnCode, apiResponse)
defer ts.Close()
// Create http request
......@@ -23,15 +25,11 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, aut
if err != nil {
t.Fatal(err)
}
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, nil),
}
a := api.NewAPI(helper.URLMustParse(ts.URL), "123", nil)
response := httptest.NewRecorder()
preAuthorizeHandler(okHandler, suffix)(response, &request)
assertResponseCode(t, response, expectedCode)
a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest)
helper.AssertResponseCode(t, response, expectedCode)
return response
}
......@@ -39,7 +37,7 @@ func TestPreAuthorizeHappyPath(t *testing.T) {
runPreAuthorizeHandler(
t, "/authorize",
regexp.MustCompile(`/authorize\z`),
&authorizationResponse{},
&api.Response{},
200, 201)
}
......@@ -47,7 +45,7 @@ func TestPreAuthorizeSuffix(t *testing.T) {
runPreAuthorizeHandler(
t, "/different-authorize",
regexp.MustCompile(`/authorize\z`),
&authorizationResponse{},
&api.Response{},
200, 404)
}
......
package main
import "net/http"
func handleDevelopmentMode(developmentMode *bool, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
if !*developmentMode {
http.NotFound(w, r.Request)
return
}
handler(w, r)
}
}
package main
package api
import (
"../badgateway"
"../helper"
"../proxy"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) {
url := u.authBackend + "/" + strings.TrimPrefix(r.URL.RequestURI(), u.relativeURLRoot) + suffix
authReq, err := http.NewRequest(r.Method, url, body)
if err != nil {
return nil, err
type API struct {
Client *http.Client
URL *url.URL
Version string
}
func NewAPI(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *API {
if roundTripper == nil {
roundTripper = badgateway.NewRoundTripper("", 0)
}
// Forward all headers from our client to the auth backend. This includes
// HTTP Basic authentication credentials (the 'Authorization' header).
for k, v := range r.Header {
authReq.Header[k] = v
return &API{
Client: &http.Client{Transport: roundTripper},
URL: myURL,
Version: version,
}
}
type HandleFunc func(http.ResponseWriter, *http.Request, *Response)
type Response struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull'
GL_ID string
// RepoPath is the full path on disk to the Git repository the request is
// about
RepoPath string
// ArchivePath is the full path where we should find/create a cached copy
// of a requested archive
ArchivePath string
// ArchivePrefix is used to put extracted archive contents in a
// subdirectory
ArchivePrefix string
// CommitId is used do prevent race conditions between the 'time of check'
// in the GitLab Rails app and the 'time of use' in gitlab-workhorse.
CommitId string
// StoreLFSPath is provided by the GitLab Rails application
// to mark where the tmp file should be placed
StoreLFSPath string
// LFS object id
LfsOid string
// LFS object size
LfsSize int64
// TmpPath is the path where we should store temporary files
// This is set by authorization middleware
TempPath string
}
// singleJoiningSlash is taken from reverseproxy.go:NewSingleHostReverseProxy
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
// rebaseUrl is taken from reverseproxy.go:NewSingleHostReverseProxy
func rebaseUrl(url *url.URL, onto *url.URL, suffix string) *url.URL {
newUrl := *url
newUrl.Scheme = onto.Scheme
newUrl.Host = onto.Host
if suffix != "" {
newUrl.Path = singleJoiningSlash(url.Path, suffix)
}
if onto.RawQuery == "" || newUrl.RawQuery == "" {
newUrl.RawQuery = onto.RawQuery + newUrl.RawQuery
} else {
newUrl.RawQuery = onto.RawQuery + "&" + newUrl.RawQuery
}
return &newUrl
}
func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) {
authReq := &http.Request{
Method: r.Method,
URL: rebaseUrl(r.URL, api.URL, suffix),
Header: proxy.HeaderClone(r.Header),
}
if body != nil {
authReq.Body = ioutil.NopCloser(body)
}
// Clean some headers when issuing a new request without body
......@@ -46,22 +125,22 @@ func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix st
authReq.Host = r.Host
// Set a custom header for the request. This can be used in some
// configurations (Passenger) to solve auth request routing problems.
authReq.Header.Set("Gitlab-Workhorse", Version)
authReq.Header.Set("Gitlab-Workhorse", api.Version)
return authReq, nil
}
func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
authReq, err := r.u.newUpstreamRequest(r.Request, nil, suffix)
func (api *API) PreAuthorizeHandler(h HandleFunc, suffix string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authReq, err := api.newRequest(r, nil, suffix)
if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err))
helper.Fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err))
return
}
authResponse, err := r.u.httpClient.Do(authReq)
authResponse, err := api.Client.Do(authReq)
if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: do %v: %v", authReq.URL.Path, err))
helper.Fail500(w, fmt.Errorf("preAuthorizeHandler: do %v: %v", authReq.URL.Path, err))
return
}
defer authResponse.Body.Close()
......@@ -85,11 +164,12 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan
return
}
a := &Response{}
// The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth
// response body.
if err := json.NewDecoder(authResponse.Body).Decode(&r.authorizationResponse); err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err))
if err := json.NewDecoder(authResponse.Body).Decode(a); err != nil {
helper.Fail500(w, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err))
return
}
// Don't hog a TCP connection in CLOSE_WAIT, we can already close it now
......@@ -104,6 +184,6 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan
}
}
handleFunc(w, r)
}
h(w, r, a)
})
}
package main
package badgateway
import (
"../helper"
"bytes"
"fmt"
"io/ioutil"
"net"
"net/http"
"time"
)
type proxyRoundTripper struct {
transport http.RoundTripper
// Values from http.DefaultTransport
var DefaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = p.transport.RoundTrip(r)
var DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport
Dial: DefaultDialer.Dial, // from http.DefaultTransport
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
type RoundTripper struct {
Transport *http.Transport
}
func NewRoundTripper(socket string, proxyHeadersTimeout time.Duration) *RoundTripper {
tr := *DefaultTransport
tr.ResponseHeaderTimeout = proxyHeadersTimeout
if socket != "" {
tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", socket)
}
}
return &RoundTripper{Transport: &tr}
}
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = t.Transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error
......@@ -21,7 +48,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
// instead of 500s we catch the RoundTrip error here and inject a
// 502 response.
if err != nil {
logError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
res = &http.Response{
StatusCode: http.StatusBadGateway,
......@@ -40,26 +67,3 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
}
return
}
func headerClone(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
func proxyRequest(w http.ResponseWriter, r *gitRequest) {
// Clone request
req := *r.Request
req.Header = headerClone(r.Header)
// Set Workhorse version
req.Header.Set("Gitlab-Workhorse", Version)
rw := newSendFileResponseWriter(w, &req)
defer rw.Flush()
r.u.httpProxy.ServeHTTP(&rw, &req)
}
......@@ -2,9 +2,11 @@
In this file we handle 'git archive' downloads
*/
package main
package git
import (
"../api"
"../helper"
"fmt"
"io"
"io/ioutil"
......@@ -18,7 +20,10 @@ import (
"time"
)
func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
func GetArchive(a *api.API) http.Handler {
return repoPreAuthorizeHandler(a, handleGetArchive)
}
func handleGetArchive(w http.ResponseWriter, r *http.Request, a *api.Response) {
var format string
urlPath := r.URL.Path
switch filepath.Base(urlPath) {
......@@ -31,20 +36,20 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
case "archive.tar.bz2":
format = "tar.bz2"
default:
fail500(w, fmt.Errorf("handleGetArchive: invalid format: %s", urlPath))
helper.Fail500(w, fmt.Errorf("handleGetArchive: invalid format: %s", urlPath))
return
}
archiveFilename := path.Base(r.ArchivePath)
archiveFilename := path.Base(a.ArchivePath)
if cachedArchive, err := os.Open(r.ArchivePath); err == nil {
if cachedArchive, err := os.Open(a.ArchivePath); err == nil {
defer cachedArchive.Close()
log.Printf("Serving cached file %q", r.ArchivePath)
log.Printf("Serving cached file %q", a.ArchivePath)
setArchiveHeaders(w, format, archiveFilename)
// Even if somebody deleted the cachedArchive from disk since we opened
// the file, Unix file semantics guarantee we can still read from the
// open file in this process.
http.ServeContent(w, r.Request, "", time.Unix(0, 0), cachedArchive)
http.ServeContent(w, r, "", time.Unix(0, 0), cachedArchive)
return
}
......@@ -52,9 +57,9 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
// safe. We create the tempfile in the same directory as the final cached
// archive we want to create so that we can use an atomic link(2) operation
// to finalize the cached archive.
tempFile, err := prepareArchiveTempfile(path.Dir(r.ArchivePath), archiveFilename)
tempFile, err := prepareArchiveTempfile(path.Dir(a.ArchivePath), archiveFilename)
if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: create tempfile: %v", err))
helper.Fail500(w, fmt.Errorf("handleGetArchive: create tempfile: %v", err))
return
}
defer tempFile.Close()
......@@ -62,15 +67,15 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
compressCmd, archiveFormat := parseArchiveFormat(format)
archiveCmd := gitCommand("", "git", "--git-dir="+r.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+r.ArchivePrefix+"/", r.CommitId)
archiveCmd := gitCommand("", "git", "--git-dir="+a.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+a.ArchivePrefix+"/", a.CommitId)
archiveStdout, err := archiveCmd.StdoutPipe()
if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: archive stdout: %v", err))
helper.Fail500(w, fmt.Errorf("handleGetArchive: archive stdout: %v", err))
return
}
defer archiveStdout.Close()
if err := archiveCmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", archiveCmd.Args, err))
helper.Fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", archiveCmd.Args, err))
return
}
defer cleanUpProcessGroup(archiveCmd) // Ensure brute force subprocess clean-up
......@@ -84,13 +89,13 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
stdout, err = compressCmd.StdoutPipe()
if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: compress stdout: %v", err))
helper.Fail500(w, fmt.Errorf("handleGetArchive: compress stdout: %v", err))
return
}
defer stdout.Close()
if err := compressCmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", compressCmd.Args, err))
helper.Fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", compressCmd.Args, err))
return
}
defer cleanUpProcessGroup(compressCmd)
......@@ -105,22 +110,22 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
setArchiveHeaders(w, format, archiveFilename)
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if _, err := io.Copy(w, archiveReader); err != nil {
logError(fmt.Errorf("handleGetArchive: read: %v", err))
helper.LogError(fmt.Errorf("handleGetArchive: read: %v", err))
return
}
if err := archiveCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err))
helper.LogError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err))
return
}
if compressCmd != nil {
if err := compressCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: compressCmd: %v", err))
helper.LogError(fmt.Errorf("handleGetArchive: compressCmd: %v", err))
return
}
}
if err := finalizeCachedArchive(tempFile, r.ArchivePath); err != nil {
logError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err))
if err := finalizeCachedArchive(tempFile, a.ArchivePath); err != nil {
helper.LogError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err))
return
}
}
......
/*
Miscellaneous helpers: logging, errors, subprocesses
*/
package main
package git
import (
"errors"
"fmt"
"log"
"net/http"
"os"
"os/exec"
"path"
"syscall"
)
func fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500)
logError(err)
}
func logError(err error) {
log.Printf("error: %v", err)
}
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
// Git subprocess helpers
func gitCommand(gl_id string, name string, args ...string) *exec.Cmd {
cmd := exec.Command(name, args...)
......@@ -64,57 +38,3 @@ func cleanUpProcessGroup(cmd *exec.Cmd) {
// reap our child process
cmd.Wait()
}
func setNoCacheHeaders(header http.Header) {
header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
header.Set("Pragma", "no-cache")
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
}
func openFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path)
if err != nil {
return
}
defer func() {
if err != nil {
file.Close()
}
}()
fi, err = file.Stat()
if err != nil {
return
}
// The os.Open can also open directories
if fi.IsDir() {
err = &os.PathError{
Op: "open",
Path: path,
Err: errors.New("path is directory"),
}
return
}
return
}
// Borrowed from: net/http/server.go
// Return the canonical path for p, eliminating . and .. elements.
func cleanURIPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
......@@ -2,9 +2,11 @@
In this file we handle the Git 'smart HTTP' protocol
*/
package main
package git
import (
"../api"
"../helper"
"errors"
"fmt"
"io"
......@@ -16,6 +18,14 @@ import (
"strings"
)
func GetInfoRefs(a *api.API) http.Handler {
return repoPreAuthorizeHandler(a, handleGetInfoRefs)
}
func PostRPC(a *api.API) http.Handler {
return repoPreAuthorizeHandler(a, handlePostRPC)
}
func looksLikeRepo(p string) bool {
// If /path/to/foo.git/objects exists then let's assume it is a valid Git
// repository.
......@@ -26,23 +36,23 @@ func looksLikeRepo(p string) bool {
return true
}
func repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.RepoPath == "" {
fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty"))
func repoPreAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if a.RepoPath == "" {
helper.Fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty"))
return
}
if !looksLikeRepo(r.RepoPath) {
if !looksLikeRepo(a.RepoPath) {
http.Error(w, "Not Found", 404)
return
}
handleFunc(w, r)
handleFunc(w, r, a)
}, "")
}
func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) {
func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response) {
rpc := r.URL.Query().Get("service")
if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported
......@@ -51,15 +61,15 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) {
}
// Prepare our Git subprocess
cmd := gitCommand(r.GL_ID, "git", subCommand(rpc), "--stateless-rpc", "--advertise-refs", r.RepoPath)
cmd := gitCommand(a.GL_ID, "git", subCommand(rpc), "--stateless-rpc", "--advertise-refs", a.RepoPath)
stdout, err := cmd.StdoutPipe()
if err != nil {
fail500(w, fmt.Errorf("handleGetInfoRefs: stdout: %v", err))
helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: stdout: %v", err))
return
}
defer stdout.Close()
if err := cmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetInfoRefs: start %v: %v", cmd.Args, err))
helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: start %v: %v", cmd.Args, err))
return
}
defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
......@@ -69,57 +79,57 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) {
w.Header().Add("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
helper.LogError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
return
}
if err := pktFlush(w); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err))
helper.LogError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err))
return
}
if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err))
helper.LogError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err))
return
}
if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err))
helper.LogError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err))
return
}
}
func handlePostRPC(w http.ResponseWriter, r *gitRequest) {
func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) {
var err error
// Get Git action from URL
action := filepath.Base(r.URL.Path)
if !(action == "git-upload-pack" || action == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported
fail500(w, fmt.Errorf("handlePostRPC: unsupported action: %s", r.URL.Path))
helper.Fail500(w, fmt.Errorf("handlePostRPC: unsupported action: %s", r.URL.Path))
return
}
// Prepare our Git subprocess
cmd := gitCommand(r.GL_ID, "git", subCommand(action), "--stateless-rpc", r.RepoPath)
cmd := gitCommand(a.GL_ID, "git", subCommand(action), "--stateless-rpc", a.RepoPath)
stdout, err := cmd.StdoutPipe()
if err != nil {
fail500(w, fmt.Errorf("handlePostRPC: stdout: %v", err))
helper.Fail500(w, fmt.Errorf("handlePostRPC: stdout: %v", err))
return
}
defer stdout.Close()
stdin, err := cmd.StdinPipe()
if err != nil {
fail500(w, fmt.Errorf("handlePostRPC: stdin: %v", err))
helper.Fail500(w, fmt.Errorf("handlePostRPC: stdin: %v", err))
return
}
defer stdin.Close()
if err := cmd.Start(); err != nil {
fail500(w, fmt.Errorf("handlePostRPC: start %v: %v", cmd.Args, err))
helper.Fail500(w, fmt.Errorf("handlePostRPC: start %v: %v", cmd.Args, err))
return
}
defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
// Write the client request body to Git's standard input
if _, err := io.Copy(stdin, r.Body); err != nil {
fail500(w, fmt.Errorf("handlePostRPC write to %v: %v", cmd.Args, err))
helper.Fail500(w, fmt.Errorf("handlePostRPC write to %v: %v", cmd.Args, err))
return
}
// Signal to the Git subprocess that no more data is coming
......@@ -136,11 +146,11 @@ func handlePostRPC(w http.ResponseWriter, r *gitRequest) {
// This io.Copy may take a long time, both for Git push and pull.
if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err))
helper.LogError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err))
return
}
if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err))
helper.LogError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err))
return
}
}
......
package helper
import (
"errors"
"log"
"net/http"
"net/url"
"os"
)
func Fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500)
LogError(err)
}
func LogError(err error) {
log.Printf("error: %v", err)
}
func SetNoCacheHeaders(header http.Header) {
header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
header.Set("Pragma", "no-cache")
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
}
func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path)
if err != nil {
return
}
defer func() {
if err != nil {
file.Close()
}
}()
fi, err = file.Stat()
if err != nil {
return
}
// The os.Open can also open directories
if fi.IsDir() {
err = &os.PathError{
Op: "open",
Path: path,
Err: errors.New("path is directory"),
}
return
}
return
}
func URLMustParse(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
log.Fatalf("urlMustParse: %q %v", s, err)
}
return u
}
func HTTPError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
package main
package helper
import (
"fmt"
......@@ -6,25 +6,25 @@ import (
"time"
)
type loggingResponseWriter struct {
type LoggingResponseWriter struct {
rw http.ResponseWriter
status int
written int64
started time.Time
}
func newLoggingResponseWriter(rw http.ResponseWriter) loggingResponseWriter {
return loggingResponseWriter{
func NewLoggingResponseWriter(rw http.ResponseWriter) LoggingResponseWriter {
return LoggingResponseWriter{
rw: rw,
started: time.Now(),
}
}
func (l *loggingResponseWriter) Header() http.Header {
func (l *LoggingResponseWriter) Header() http.Header {
return l.rw.Header()
}
func (l *loggingResponseWriter) Write(data []byte) (n int, err error) {
func (l *LoggingResponseWriter) Write(data []byte) (n int, err error) {
if l.status == 0 {
l.WriteHeader(http.StatusOK)
}
......@@ -33,7 +33,7 @@ func (l *loggingResponseWriter) Write(data []byte) (n int, err error) {
return
}
func (l *loggingResponseWriter) WriteHeader(status int) {
func (l *LoggingResponseWriter) WriteHeader(status int) {
if l.status != 0 {
return
}
......@@ -42,7 +42,7 @@ func (l *loggingResponseWriter) WriteHeader(status int) {
l.rw.WriteHeader(status)
}
func (l *loggingResponseWriter) Log(r *http.Request) {
func (l *LoggingResponseWriter) Log(r *http.Request) {
duration := time.Since(l.started)
fmt.Printf("%s %s - - [%s] %q %d %d %q %q %f\n",
r.Host, r.RemoteAddr, l.started,
......
package main
package helper
import (
"log"
"net/http"
"net/http/httptest"
"regexp"
"testing"
)
func assertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expectedCode int) {
func AssertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expectedCode int) {
if response.Code != expectedCode {
t.Fatalf("for HTTP request expected to get %d, got %d instead", expectedCode, response.Code)
}
}
func assertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) {
func AssertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) {
if response.Body.String() != expectedBody {
t.Fatalf("for HTTP request expected to receive %q, got %q instead as body", expectedBody, response.Body.String())
}
}
func assertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) {
func AssertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) {
if response.Header().Get(header) != expectedValue {
t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header))
}
}
func TestServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if url != nil && !url.MatchString(r.URL.Path) {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(404)
return
}
if version := r.Header.Get("Gitlab-Workhorse"); version == "" {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(403)
return
}
handler(w, r)
}))
}
/*
In this file we handle git lfs objects downloads and uploads
*/
package lfs
import (
"../api"
"../helper"
"../proxy"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
)
func PutStore(a *api.API, p *proxy.Proxy) http.Handler {
return lfsAuthorizeHandler(a, handleStoreLfsObject(p))
}
func lfsAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if a.StoreLFSPath == "" {
helper.Fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty"))
return
}
if a.LfsOid == "" {
helper.Fail500(w, errors.New("lfsAuthorizeHandler: LfsOid empty"))
return
}
if err := os.MkdirAll(a.StoreLFSPath, 0700); err != nil {
helper.Fail500(w, fmt.Errorf("lfsAuthorizeHandler: mkdir StoreLFSPath: %v", err))
return
}
handleFunc(w, r, a)
}, "/authorize")
}
func handleStoreLfsObject(h http.Handler) api.HandleFunc {
return func(w http.ResponseWriter, r *http.Request, a *api.Response) {
file, err := ioutil.TempFile(a.StoreLFSPath, a.LfsOid)
if err != nil {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
return
}
defer os.Remove(file.Name())
defer file.Close()
hash := sha256.New()
hw := io.MultiWriter(hash, file)
written, err := io.Copy(hw, r.Body)
if err != nil {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: write tempfile: %v", err))
return
}
file.Close()
if written != a.LfsSize {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: expected size %d, wrote %d", a.LfsSize, written))
return
}
shaStr := hex.EncodeToString(hash.Sum(nil))
if shaStr != a.LfsOid {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", a.LfsOid, shaStr))
return
}
// Inject header and body
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name()))
r.Body = ioutil.NopCloser(&bytes.Buffer{})
r.ContentLength = 0
// And proxy the request
h.ServeHTTP(w, r)
}
}
package proxy
import (
"../badgateway"
"net/http"
"net/http/httputil"
"net/url"
)
type Proxy struct {
Version string
reverseProxy *httputil.ReverseProxy
}
func NewProxy(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *Proxy {
p := Proxy{Version: version}
u := *myURL // Make a copy of p.URL
u.Path = ""
p.reverseProxy = httputil.NewSingleHostReverseProxy(&u)
if roundTripper != nil {
p.reverseProxy.Transport = roundTripper
} else {
p.reverseProxy.Transport = badgateway.NewRoundTripper("", 0)
}
return &p
}
func HeaderClone(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Clone request
req := *r
req.Header = HeaderClone(r.Header)
// Set Workhorse version
req.Header.Set("Gitlab-Workhorse", p.Version)
rw := newSendFileResponseWriter(w, &req)
defer rw.Flush()
p.reverseProxy.ServeHTTP(&rw, &req)
}
......@@ -4,9 +4,10 @@ via the X-Sendfile mechanism. All that is needed in the Rails code is the
'send_file' method.
*/
package main
package proxy
import (
"../helper"
"log"
"net/http"
)
......@@ -63,7 +64,7 @@ func (s *sendFileResponseWriter) WriteHeader(status int) {
// Serve the file
log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI)
content, fi, err := openFile(file)
content, fi, err := helper.OpenFile(file)
if err != nil {
http.NotFound(s.rw, s.req)
return
......
package main
package staticpages
import (
"../helper"
"io/ioutil"
"net/http"
"path/filepath"
)
func handleDeployPage(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
deployPage := filepath.Join(*documentRoot, "index.html")
func (s *Static) DeployPage(handler http.Handler) http.Handler {
deployPage := filepath.Join(s.DocumentRoot, "index.html")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := ioutil.ReadFile(deployPage)
if err != nil {
handler(w, r)
handler.ServeHTTP(w, r)
return
}
setNoCacheHeaders(w.Header())
helper.SetNoCacheHeaders(w.Header())
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write(data)
}
})
}
package main
package staticpages
import (
"../helper"
"io/ioutil"
"net/http"
"net/http/httptest"
......@@ -19,9 +20,10 @@ func TestIfNoDeployPageExist(t *testing.T) {
w := httptest.NewRecorder()
executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) {
st := &Static{dir}
st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true
})(w, nil)
})).ServeHTTP(w, nil)
if !executed {
t.Error("The handler should get executed")
}
......@@ -40,14 +42,15 @@ func TestIfDeployPageExist(t *testing.T) {
w := httptest.NewRecorder()
executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) {
st := &Static{dir}
st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true
})(w, nil)
})).ServeHTTP(w, nil)
if executed {
t.Error("The handler should not get executed")
}
w.Flush()
assertResponseCode(t, w, 200)
assertResponseBody(t, w, deployPage)
helper.AssertResponseCode(t, w, 200)
helper.AssertResponseBody(t, w, deployPage)
}
package main
package staticpages
import (
"../helper"
"fmt"
"io/ioutil"
"log"
......@@ -12,7 +13,7 @@ type errorPageResponseWriter struct {
rw http.ResponseWriter
status int
hijacked bool
path *string
path string
}
func (s *errorPageResponseWriter) Header() http.Header {
......@@ -37,14 +38,14 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.status = status
if 400 <= s.status && s.status <= 599 {
errorPageFile := filepath.Join(*s.path, fmt.Sprintf("%d.html", s.status))
errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status))
// check if custom error page exists, serve this page instead
if data, err := ioutil.ReadFile(errorPageFile); err == nil {
s.hijacked = true
log.Printf("ErrorPage: serving predefined error page: %d", s.status)
setNoCacheHeaders(s.rw.Header())
helper.SetNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", "text/html; charset=utf-8")
s.rw.WriteHeader(s.status)
s.rw.Write(data)
......@@ -59,16 +60,16 @@ func (s *errorPageResponseWriter) Flush() {
s.WriteHeader(http.StatusOK)
}
func handleRailsError(documentRoot *string, enabled *bool, handler serviceHandleFunc) serviceHandleFunc {
if !*enabled {
func (st *Static) ErrorPages(enabled bool, handler http.Handler) http.Handler {
if !enabled {
return handler
}
return func(w http.ResponseWriter, r *gitRequest) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := errorPageResponseWriter{
rw: w,
path: documentRoot,
path: st.DocumentRoot,
}
defer rw.Flush()
handler(&rw, r)
}
handler.ServeHTTP(&rw, r)
})
}
package main
package staticpages
import (
"../helper"
"fmt"
"io/ioutil"
"net/http"
......@@ -21,16 +22,16 @@ func TestIfErrorPageIsPresented(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600)
w := httptest.NewRecorder()
enabled := true
handleRailsError(&dir, &enabled, func(w http.ResponseWriter, r *gitRequest) {
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404)
fmt.Fprint(w, "Not Found")
})(w, nil)
})
st := &Static{dir}
st.ErrorPages(true, h).ServeHTTP(w, nil)
w.Flush()
assertResponseCode(t, w, 404)
assertResponseBody(t, w, errorPage)
helper.AssertResponseCode(t, w, 404)
helper.AssertResponseBody(t, w, errorPage)
}
func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
......@@ -42,16 +43,16 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
w := httptest.NewRecorder()
errorResponse := "ERROR"
enabled := true
handleRailsError(&dir, &enabled, func(w http.ResponseWriter, r *gitRequest) {
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404)
fmt.Fprint(w, errorResponse)
})(w, nil)
})
st := &Static{dir}
st.ErrorPages(true, h).ServeHTTP(w, nil)
w.Flush()
assertResponseCode(t, w, 404)
assertResponseBody(t, w, errorResponse)
helper.AssertResponseCode(t, w, 404)
helper.AssertResponseBody(t, w, errorResponse)
}
func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
......@@ -65,15 +66,14 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600)
w := httptest.NewRecorder()
enabled := false
serverError := "Interesting Server Error"
handleRailsError(&dir, &enabled, func(w http.ResponseWriter, r *gitRequest) {
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(500)
fmt.Fprint(w, serverError)
})(w, nil)
})
st := &Static{dir}
st.ErrorPages(false, h).ServeHTTP(w, nil)
w.Flush()
assertResponseCode(t, w, 500)
assertResponseBody(t, w, serverError)
helper.AssertResponseCode(t, w, 500)
helper.AssertResponseBody(t, w, serverError)
}
package main
package staticpages
import (
"../helper"
"../urlprefix"
"log"
"net/http"
"os"
......@@ -19,13 +21,13 @@ const (
// BUG/QUIRK: If a client requests 'foo%2Fbar' and 'foo/bar' exists,
// handleServeFile will serve foo/bar instead of passing the request
// upstream.
func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
file := filepath.Join(*documentRoot, r.relativeURIPath)
func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoundHandler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
file := filepath.Join(s.DocumentRoot, prefix.Strip(r.URL.Path))
// The filepath.Join does Clean traversing directories up
if !strings.HasPrefix(file, *documentRoot) {
fail500(w, &os.PathError{
if !strings.HasPrefix(file, s.DocumentRoot) {
helper.Fail500(w, &os.PathError{
Op: "open",
Path: file,
Err: os.ErrInvalid,
......@@ -39,7 +41,7 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv
// Serve pre-gzipped assets
if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") {
content, fi, err = openFile(file + ".gz")
content, fi, err = helper.OpenFile(file + ".gz")
if err == nil {
w.Header().Set("Content-Encoding", "gzip")
}
......@@ -47,13 +49,13 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv
// If not found, open the original file
if content == nil || err != nil {
content, fi, err = openFile(file)
content, fi, err = helper.OpenFile(file)
}
if err != nil {
if notFoundHandler != nil {
notFoundHandler(w, r)
notFoundHandler.ServeHTTP(w, r)
} else {
http.NotFound(w, r.Request)
http.NotFound(w, r)
}
return
}
......@@ -68,6 +70,6 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv
}
log.Printf("Send static file %q (%q) for %s %q", file, w.Header().Get("Content-Encoding"), r.Method, r.RequestURI)
http.ServeContent(w, r.Request, filepath.Base(file), fi.ModTime(), content)
}
http.ServeContent(w, r, filepath.Base(file), fi.ModTime(), content)
})
}
package main
package staticpages
import (
"../helper"
"bytes"
"compress/gzip"
"io/ioutil"
......@@ -14,14 +15,11 @@ import (
func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/static/file",
}
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 404)
st := &Static{dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 404)
}
func TestServingDirectory(t *testing.T) {
......@@ -32,41 +30,31 @@ func TestServingDirectory(t *testing.T) {
defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/",
}
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 404)
st := &Static{dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 404)
}
func TestServingMalformedUri(t *testing.T) {
dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/../../../static/file",
}
httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil)
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 500)
st := &Static{dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 404)
}
func TestExecutingHandlerWhenNoFileFound(t *testing.T) {
dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/static/file",
}
executed := false
handleServeFile(&dir, CacheDisabled, func(w http.ResponseWriter, r *gitRequest) {
executed = (r == request)
})(nil, request)
st := &Static{dir}
st.ServeExisting("/", CacheDisabled, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
executed = (r == httpRequest)
})).ServeHTTP(nil, httpRequest)
if !executed {
t.Error("The handler should get executed")
}
......@@ -80,17 +68,14 @@ func TestServingTheActualFile(t *testing.T) {
defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/file",
}
fileContent := "STATIC"
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 200)
st := &Static{dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 200)
if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String())
}
......@@ -104,10 +89,6 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/file",
}
if enableGzip {
httpRequest.Header.Set("Accept-Encoding", "gzip, deflate")
......@@ -124,16 +105,17 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 200)
st := &Static{dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 200)
if enableGzip {
assertResponseHeader(t, w, "Content-Encoding", "gzip")
helper.AssertResponseHeader(t, w, "Content-Encoding", "gzip")
if bytes.Compare(w.Body.Bytes(), fileGzipContent.Bytes()) != 0 {
t.Error("We should serve the pregzipped file")
}
} else {
assertResponseCode(t, w, 200)
assertResponseHeader(t, w, "Content-Encoding", "")
helper.AssertResponseCode(t, w, 200)
helper.AssertResponseHeader(t, w, "Content-Encoding", "")
if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String())
}
......
package staticpages
type Static struct {
DocumentRoot string
}
package upload
import (
"../api"
"net/http"
)
func Artifacts(myAPI *api.API, h http.Handler) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
r.Header.Set(tempPathHeader, a.TempPath)
handleFileUploads(h).ServeHTTP(w, r)
}, "/authorize")
}
package main
package upload
import (
"../helper"
"bytes"
"errors"
"fmt"
......@@ -11,7 +12,9 @@ import (
"os"
)
func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cleanup func(), err error) {
const tempPathHeader = "Gitlab-Workhorse-Temp-Path"
func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string) (cleanup func(), err error) {
// Create multipart reader
reader, err := r.MultipartReader()
if err != nil {
......@@ -47,12 +50,12 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
// Copy form field
if filename := p.FileName(); filename != "" {
// Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(r.TempPath, 0700); err != nil {
if err := os.MkdirAll(tempPath, 0700); err != nil {
return cleanup, err
}
// Create temporary file in path returned by Authorization filter
file, err := ioutil.TempFile(r.TempPath, "upload_")
file, err := ioutil.TempFile(tempPath, "upload_")
if err != nil {
return cleanup, err
}
......@@ -83,39 +86,43 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
return cleanup, nil
}
func handleFileUploads(w http.ResponseWriter, r *gitRequest) {
if r.TempPath == "" {
fail500(w, errors.New("handleFileUploads: TempPath empty"))
return
}
func handleFileUploads(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tempPath := r.Header.Get(tempPathHeader)
if tempPath == "" {
helper.Fail500(w, errors.New("handleFileUploads: TempPath empty"))
return
}
r.Header.Del(tempPathHeader)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
defer writer.Close()
var body bytes.Buffer
writer := multipart.NewWriter(&body)
defer writer.Close()
// Rewrite multipart form data
cleanup, err := rewriteFormFilesFromMultipart(r, writer)
if err != nil {
if err == http.ErrNotMultipart {
proxyRequest(w, r)
} else {
fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err))
// Rewrite multipart form data
cleanup, err := rewriteFormFilesFromMultipart(r, writer, tempPath)
if err != nil {
if err == http.ErrNotMultipart {
h.ServeHTTP(w, r)
} else {
helper.Fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err))
}
return
}
return
}
if cleanup != nil {
defer cleanup()
}
if cleanup != nil {
defer cleanup()
}
// Close writer
writer.Close()
// Close writer
writer.Close()
// Hijack the request
r.Body = ioutil.NopCloser(&body)
r.ContentLength = int64(body.Len())
r.Header.Set("Content-Type", writer.FormDataContentType())
// Hijack the request
r.Body = ioutil.NopCloser(&body)
r.ContentLength = int64(body.Len())
r.Header.Set("Content-Type", writer.FormDataContentType())
// Proxy the request
proxyRequest(w, r)
// Proxy the request
h.ServeHTTP(w, r)
})
}
package main
package upload
import (
"../helper"
"../proxy"
"bytes"
"fmt"
"io"
......@@ -14,19 +16,17 @@ import (
"testing"
)
var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder()
request := gitRequest{
authorizationResponse: authorizationResponse{
TempPath: "",
},
}
handleFileUploads(response, &request)
assertResponseCode(t, response, 500)
request := &http.Request{}
handleFileUploads(nilHandler).ServeHTTP(response, request)
helper.AssertResponseCode(t, response, 500)
}
func TestUploadHandlerForwardingRawData(t *testing.T) {
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
ts := helper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" {
t.Fatal("Expected PATCH request")
}
......@@ -40,6 +40,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE")
})
defer ts.Close()
httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST"))
if err != nil {
......@@ -53,15 +54,11 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
defer os.RemoveAll(tempPath)
response := httptest.NewRecorder()
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, nil),
authorizationResponse: authorizationResponse{
TempPath: tempPath,
},
}
handleFileUploads(response, &request)
assertResponseCode(t, response, 202)
httpRequest.Header.Set(tempPathHeader, tempPath)
handleFileUploads(proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil)).ServeHTTP(response, httpRequest)
helper.AssertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" {
t.Fatal("Expected RESPONSE in response body")
}
......@@ -76,7 +73,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
}
defer os.RemoveAll(tempPath)
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
ts := helper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
t.Fatal("Expected PUT request")
}
......@@ -131,17 +128,11 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Body = ioutil.NopCloser(&buffer)
httpRequest.ContentLength = int64(buffer.Len())
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
httpRequest.Header.Set(tempPathHeader, tempPath)
response := httptest.NewRecorder()
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, nil),
authorizationResponse: authorizationResponse{
TempPath: tempPath,
},
}
handleFileUploads(response, &request)
assertResponseCode(t, response, 202)
handleFileUploads(proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil)).ServeHTTP(response, httpRequest)
helper.AssertResponseCode(t, response, 202)
if _, err := os.Stat(filePath); !os.IsNotExist(err) {
t.Fatal("expected the file to be deleted")
......
package main
package upstream
import (
"../helper"
"net/http"
"net/http/httptest"
"testing"
......@@ -13,9 +14,9 @@ func TestDevelopmentModeEnabled(t *testing.T) {
w := httptest.NewRecorder()
executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) {
NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true
})(w, &gitRequest{Request: r})
})).ServeHTTP(w, r)
if !executed {
t.Error("The handler should get executed")
}
......@@ -28,11 +29,11 @@ func TestDevelopmentModeDisabled(t *testing.T) {
w := httptest.NewRecorder()
executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) {
NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true
})(w, &gitRequest{Request: r})
})).ServeHTTP(w, r)
if executed {
t.Error("The handler should not get executed")
}
assertResponseCode(t, w, 404)
helper.AssertResponseCode(t, w, 404)
}
package main
package upstream
import (
"../helper"
"compress/gzip"
"fmt"
"io"
"net/http"
)
func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
func contentEncodingHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body io.ReadCloser
var err error
......@@ -24,7 +25,7 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
}
if err != nil {
fail500(w, fmt.Errorf("contentEncodingHandler: %v", err))
helper.Fail500(w, fmt.Errorf("contentEncodingHandler: %v", err))
return
}
defer body.Close()
......@@ -32,6 +33,6 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
r.Body = body
r.Header.Del("Content-Encoding")
handleFunc(w, r)
}
h.ServeHTTP(w, r)
})
}
package main
package upstream
import (
"../helper"
"bytes"
"compress/gzip"
"fmt"
......@@ -27,17 +28,16 @@ func TestGzipEncoding(t *testing.T) {
}
req.Header.Set("Content-Encoding", "gzip")
request := gitRequest{Request: req}
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if _, ok := r.Body.(*gzip.Reader); !ok {
t.Fatal("Expected gzip reader for body, but it's:", reflect.TypeOf(r.Body))
}
if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted")
}
})(resp, &request)
})).ServeHTTP(resp, req)
assertResponseCode(t, resp, 200)
helper.AssertResponseCode(t, resp, 200)
}
func TestNoEncoding(t *testing.T) {
......@@ -52,17 +52,16 @@ func TestNoEncoding(t *testing.T) {
}
req.Header.Set("Content-Encoding", "")
request := gitRequest{Request: req}
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if r.Body != body {
t.Fatal("Expected the same body")
}
if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted")
}
})(resp, &request)
})).ServeHTTP(resp, req)
assertResponseCode(t, resp, 200)
helper.AssertResponseCode(t, resp, 200)
}
func TestInvalidEncoding(t *testing.T) {
......@@ -74,10 +73,9 @@ func TestInvalidEncoding(t *testing.T) {
}
req.Header.Set("Content-Encoding", "application/unknown")
request := gitRequest{Request: req}
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
t.Fatal("it shouldn't be executed")
})(resp, &request)
})).ServeHTTP(resp, req)
assertResponseCode(t, resp, 500)
helper.AssertResponseCode(t, resp, 500)
}
package upstream
import "net/http"
func NotFoundUnless(pass bool, handler http.Handler) http.Handler {
if pass {
return handler
} else {
return http.HandlerFunc(http.NotFound)
}
}
package upstream
import (
apipkg "../api"
"../git"
"../lfs"
proxypkg "../proxy"
"../staticpages"
"../upload"
"net/http"
"regexp"
)
type route struct {
method string
regex *regexp.Regexp
handler http.Handler
}
const projectPattern = `^/[^/]+/[^/]+/`
const gitProjectPattern = `^/[^/]+/[^/]+\.git/`
const apiPattern = `^/api/`
// A project ID in an API request is either a number or two strings 'namespace/project'
const projectsAPIPattern = `^/api/v3/projects/((\d+)|([^/]+/[^/]+))/`
const ciAPIPattern = `^/ci/api/`
// Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
func (u *Upstream) configureRoutes() {
api := apipkg.NewAPI(
u.Backend,
u.Version,
u.RoundTripper,
)
static := &staticpages.Static{u.DocumentRoot}
proxy := proxypkg.NewProxy(
u.Backend,
u.Version,
u.RoundTripper,
)
u.Routes = []route{
// Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(api)},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(api, proxy)},
// Repository Archive
route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), git.GetArchive(api)},
// Repository Archive API
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(api)},
// CI Artifacts API
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(api, proxy))},
// Explicitly proxy API requests
route{"", regexp.MustCompile(apiPattern), proxy},
route{"", regexp.MustCompile(ciAPIPattern), proxy},
// Serve assets
route{"", regexp.MustCompile(`^/assets/`),
static.ServeExisting(u.URLPrefix, staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode,
proxy,
),
),
},
// For legacy reasons, user uploads are stored under the document root.
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through static.ServeExisting.
route{"", regexp.MustCompile(`^/uploads/`), static.ErrorPages(u.DevelopmentMode, proxy)},
// Serve static files or forward the requests
route{"", nil,
static.ServeExisting(u.URLPrefix, staticpages.CacheDisabled,
static.DeployPage(
static.ErrorPages(u.DevelopmentMode,
proxy,
),
),
),
},
}
}
/*
The upstream type implements http.Handler.
In this file we handle request routing and interaction with the authBackend.
*/
package upstream
import (
"../badgateway"
"../helper"
"../urlprefix"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
var DefaultBackend = helper.URLMustParse("http://localhost:8080")
type Upstream struct {
Backend *url.URL
Version string
DocumentRoot string
DevelopmentMode bool
URLPrefix urlprefix.Prefix
Routes []route
RoundTripper *badgateway.RoundTripper
}
func NewUpstream(backend *url.URL, socket string, version string, documentRoot string, developmentMode bool, proxyHeadersTimeout time.Duration) *Upstream {
up := Upstream{
Backend: backend,
Version: version,
DocumentRoot: documentRoot,
DevelopmentMode: developmentMode,
RoundTripper: badgateway.NewRoundTripper(socket, proxyHeadersTimeout),
}
if backend == nil {
up.Backend = DefaultBackend
}
up.configureRoutes()
up.configureURLPrefix()
return &up
}
func (u *Upstream) configureURLPrefix() {
relativeURLRoot := u.Backend.Path
if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/"
}
u.URLPrefix = urlprefix.Prefix(relativeURLRoot)
}
func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
w := helper.NewLoggingResponseWriter(ow)
defer w.Log(r)
// Drop WebSocket connection and CONNECT method
if r.RequestURI == "*" {
helper.HTTPError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
helper.HTTPError(&w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
// Check URL Root
URIPath := urlprefix.CleanURIPath(r.URL.Path)
prefix := u.URLPrefix
if !prefix.Match(URIPath) {
helper.HTTPError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
// Look for a matching Git service
var ro route
foundService := false
for _, ro = range u.Routes {
if ro.method != "" && r.Method != ro.method {
continue
}
if ro.regex == nil || ro.regex.MatchString(prefix.Strip(URIPath)) {
foundService = true
break
}
}
if !foundService {
// The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found.
helper.HTTPError(&w, r, "Forbidden", http.StatusForbidden)
return
}
ro.handler.ServeHTTP(&w, r)
}
package urlprefix
import (
"path"
"strings"
)
type Prefix string
func (p Prefix) Strip(path string) string {
return CleanURIPath(strings.TrimPrefix(path, string(p)))
}
func (p Prefix) Match(path string) bool {
pre := string(p)
return strings.HasPrefix(path, pre) || path+"/" == pre
}
// Borrowed from: net/http/server.go
// Return the canonical path for p, eliminating . and .. elements.
func CleanURIPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
/*
In this file we handle git lfs objects downloads and uploads
*/
package main
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
)
func lfsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.StoreLFSPath == "" {
fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty"))
return
}
if r.LfsOid == "" {
fail500(w, errors.New("lfsAuthorizeHandler: LfsOid empty"))
return
}
if err := os.MkdirAll(r.StoreLFSPath, 0700); err != nil {
fail500(w, fmt.Errorf("lfsAuthorizeHandler: mkdir StoreLFSPath: %v", err))
return
}
handleFunc(w, r)
}, "/authorize")
}
func handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) {
file, err := ioutil.TempFile(r.StoreLFSPath, r.LfsOid)
if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
return
}
defer os.Remove(file.Name())
defer file.Close()
hash := sha256.New()
hw := io.MultiWriter(hash, file)
written, err := io.Copy(hw, r.Body)
if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: write tempfile: %v", err))
return
}
file.Close()
if written != r.LfsSize {
fail500(w, fmt.Errorf("handleStoreLfsObject: expected size %d, wrote %d", r.LfsSize, written))
return
}
shaStr := hex.EncodeToString(hash.Sum(nil))
if shaStr != r.LfsOid {
fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", r.LfsOid, shaStr))
return
}
// Inject header and body
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name()))
r.Body = ioutil.NopCloser(&bytes.Buffer{})
r.ContentLength = 0
// And proxy the request
proxyRequest(w, r)
}
......@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type.
package main
import (
"./internal/upstream"
"flag"
"fmt"
"log"
......@@ -21,7 +22,6 @@ import (
"net/http"
_ "net/http/pprof"
"os"
"regexp"
"syscall"
"time"
)
......@@ -33,95 +33,13 @@ var printVersion = flag.Bool("version", false, "Print version and exit")
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022")
var authBackend = flag.String("authBackend", "http://localhost:8080", "Authentication/authorization backend")
var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request")
var proxyHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request")
var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app")
type httpRoute struct {
method string
regex *regexp.Regexp
handleFunc serviceHandleFunc
}
const projectPattern = `^/[^/]+/[^/]+/`
const gitProjectPattern = `^/[^/]+/[^/]+\.git/`
const apiPattern = `^/api/`
// A project ID in an API request is either a number or two strings 'namespace/project'
const projectsAPIPattern = `^/api/v3/projects/((\d+)|([^/]+/[^/]+))/`
const ciAPIPattern = `^/ci/api/`
// Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
var httpRoutes = [...]httpRoute{
// Git Clone
httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)},
// Repository Archive
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
// Repository Archive API
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
// CI Artifacts API
httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))},
// Explicitly proxy API requests
httpRoute{"", regexp.MustCompile(apiPattern), proxyRequest},
httpRoute{"", regexp.MustCompile(ciAPIPattern), proxyRequest},
// Serve assets
httpRoute{"", regexp.MustCompile(`^/assets/`),
handleServeFile(documentRoot, CacheExpireMax,
handleDevelopmentMode(developmentMode,
handleDeployPage(documentRoot,
handleRailsError(documentRoot, developmentMode,
proxyRequest,
),
),
),
),
},
// For legacy reasons, user uploads are stored under the document root.
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through handleServeFile.
httpRoute{"", regexp.MustCompile(`^/uploads/`),
handleRailsError(documentRoot, developmentMode,
proxyRequest,
),
},
// Serve static files or forward the requests
httpRoute{"", nil,
handleServeFile(documentRoot, CacheDisabled,
handleDeployPage(documentRoot,
handleRailsError(documentRoot, developmentMode,
proxyRequest,
),
),
),
},
}
func main() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
......@@ -153,23 +71,6 @@ func main() {
log.Fatal(err)
}
// Create Proxy Transport
authTransport := http.DefaultTransport
if *authSocket != "" {
dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
authTransport = &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", *authSocket)
},
ResponseHeaderTimeout: *responseHeadersTimeout,
}
}
proxyTransport := &proxyRoundTripper{transport: authTransport}
// The profiler will only be activated by HTTP requests. HTTP
// requests can only reach the profiler if we start a listener. So by
// having no profiler HTTP listener by default, the profiler is
......@@ -180,6 +81,14 @@ func main() {
}()
}
upstream := newUpstream(*authBackend, proxyTransport)
log.Fatal(http.Serve(listener, upstream))
up := upstream.NewUpstream(
*authBackend,
*authSocket,
Version,
*documentRoot,
*developmentMode,
*proxyHeadersTimeout,
)
log.Fatal(http.Serve(listener, up))
}
package main
import (
"./internal/api"
"./internal/helper"
"./internal/upstream"
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
......@@ -18,9 +22,9 @@ import (
"time"
)
const scratchDir = "test/scratch"
const testRepoRoot = "test/data"
const testDocumentRoot = "test/public"
const scratchDir = "testdata/scratch"
const testRepoRoot = "testdata/data"
const testDocumentRoot = "testdata/public"
const testRepo = "group/test.git"
const testProject = "group/test"
......@@ -325,7 +329,7 @@ func TestAllowedStaticFile(t *testing.T) {
}
proxied := false
ts := testServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
ts := helper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
proxied = true
w.WriteHeader(404)
})
......@@ -339,21 +343,21 @@ func TestAllowedStaticFile(t *testing.T) {
} {
resp, err := http.Get(ws.URL + resource)
if err != nil {
t.Fatal(err)
t.Error(err)
}
defer resp.Body.Close()
buf := &bytes.Buffer{}
if _, err := io.Copy(buf, resp.Body); err != nil {
t.Fatal(err)
t.Error(err)
}
if buf.String() != content {
t.Fatalf("GET %q: Expected %q, got %q", resource, content, buf.String())
t.Errorf("GET %q: Expected %q, got %q", resource, content, buf.String())
}
if resp.StatusCode != 200 {
t.Fatalf("GET %q: expected 200, got %d", resource, resp.StatusCode)
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
if proxied {
t.Fatalf("GET %q: should not have made it to backend", resource)
t.Errorf("GET %q: should not have made it to backend", resource)
}
}
}
......@@ -365,7 +369,7 @@ func TestAllowedPublicUploadsFile(t *testing.T) {
}
proxied := false
ts := testServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
ts := helper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
proxied = true
w.Header().Add("X-Sendfile", *documentRoot+r.URL.Path)
w.WriteHeader(200)
......@@ -406,7 +410,7 @@ func TestDeniedPublicUploadsFile(t *testing.T) {
}
proxied := false
ts := testServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
ts := helper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
proxied = true
w.WriteHeader(404)
})
......@@ -439,6 +443,50 @@ func TestDeniedPublicUploadsFile(t *testing.T) {
}
}
func TestArtifactsUpload(t *testing.T) {
reqBody := &bytes.Buffer{}
writer := multipart.NewWriter(reqBody)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART")
writer.Close()
ts := helper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/authorize") {
if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil {
t.Fatal(err)
}
return
}
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
if len(r.MultipartForm.Value) != 2 { // 1 file name, 1 file path
t.Error("Expected to receive exactly 2 values")
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
w.WriteHeader(200)
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := `/ci/api/v1/builds/123/artifacts`
resp, err := http.Post(ws.URL+resource, writer.FormDataContentType(), reqBody)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
func setupStaticFile(fpath, content string) error {
cwd, err := os.Getwd()
if err != nil {
......@@ -476,26 +524,8 @@ func newBranch() string {
return fmt.Sprintf("branch-%d", time.Now().UnixNano())
}
func testServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if url != nil && !url.MatchString(r.URL.Path) {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(404)
return
}
if version := r.Header.Get("Gitlab-Workhorse"); version == "" {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(403)
return
}
handler(w, r)
}))
}
func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Server {
return testServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) {
return helper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) {
// Write pure string
if data, ok := body.(string); ok {
log.Println("UPSTREAM", r.Method, r.URL, code)
......@@ -520,7 +550,15 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
}
func startWorkhorseServer(authBackend string) *httptest.Server {
return httptest.NewServer(newUpstream(authBackend, nil))
u := upstream.NewUpstream(
helper.URLMustParse(authBackend),
"",
"123",
testDocumentRoot,
false,
0,
)
return httptest.NewServer(u)
}
func runOrFail(t *testing.T, cmd *exec.Cmd) {
......@@ -532,7 +570,7 @@ func runOrFail(t *testing.T, cmd *exec.Cmd) {
}
func gitOkBody(t *testing.T) interface{} {
return &authorizationResponse{
return &api.Response{
GL_ID: "user-123",
RepoPath: repoPath(t),
}
......@@ -545,7 +583,7 @@ func archiveOkBody(t *testing.T, archiveName string) interface{} {
}
archivePath := path.Join(cwd, cacheDir, archiveName)
return &authorizationResponse{
return &api.Response{
RepoPath: repoPath(t),
ArchivePath: archivePath,
CommitId: "c7fbe50c7c7419d9701eebe64b1fdacc3df5b9dd",
......
package main
import (
"./internal/badgateway"
"./internal/helper"
"./internal/proxy"
"bytes"
"fmt"
"io"
......@@ -12,8 +15,12 @@ import (
"time"
)
func newProxy(url string, rt *badgateway.RoundTripper) *proxy.Proxy {
return proxy.NewProxy(helper.URLMustParse(url), "123", rt)
}
func TestProxyRequest(t *testing.T) {
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
ts := helper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Fatal("Expected POST request")
}
......@@ -39,15 +46,10 @@ func TestProxyRequest(t *testing.T) {
}
httpRequest.Header.Set("Custom-Header", "test")
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, nil),
}
w := httptest.NewRecorder()
proxyRequest(w, &request)
assertResponseCode(t, w, 202)
assertResponseBody(t, w, "RESPONSE")
newProxy(ts.URL, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 202)
helper.AssertResponseBody(t, w, "RESPONSE")
if w.Header().Get("Custom-Response-Header") != "test" {
t.Fatal("Expected custom response header")
......@@ -61,23 +63,14 @@ func TestProxyError(t *testing.T) {
}
httpRequest.Header.Set("Custom-Header", "test")
transport := proxyRoundTripper{
transport: http.DefaultTransport,
}
request := gitRequest{
Request: httpRequest,
u: newUpstream("http://localhost:655575/", &transport),
}
w := httptest.NewRecorder()
proxyRequest(w, &request)
assertResponseCode(t, w, 502)
assertResponseBody(t, w, "dial tcp: invalid port 655575")
newProxy("http://localhost:655575/", nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 502)
helper.AssertResponseBody(t, w, "dial tcp: invalid port 655575")
}
func TestProxyReadTimeout(t *testing.T) {
ts := testServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) {
ts := helper.TestServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Minute)
})
......@@ -86,8 +79,8 @@ func TestProxyReadTimeout(t *testing.T) {
t.Fatal(err)
}
transport := &proxyRoundTripper{
transport: &http.Transport{
rt := &badgateway.RoundTripper{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
......@@ -98,19 +91,15 @@ func TestProxyReadTimeout(t *testing.T) {
},
}
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, transport),
}
p := newProxy(ts.URL, rt)
w := httptest.NewRecorder()
proxyRequest(w, &request)
assertResponseCode(t, w, 502)
assertResponseBody(t, w, "net/http: timeout awaiting response headers")
p.ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 502)
helper.AssertResponseBody(t, w, "net/http: timeout awaiting response headers")
}
func TestProxyHandlerTimeout(t *testing.T) {
ts := testServerWithHandler(nil,
ts := helper.TestServerWithHandler(nil,
http.TimeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second)
}), time.Millisecond, "Request took too long").ServeHTTP,
......@@ -121,17 +110,8 @@ func TestProxyHandlerTimeout(t *testing.T) {
t.Fatal(err)
}
transport := &proxyRoundTripper{
transport: http.DefaultTransport,
}
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, transport),
}
w := httptest.NewRecorder()
proxyRequest(w, &request)
assertResponseCode(t, w, 503)
assertResponseBody(t, w, "Request took too long")
newProxy(ts.URL, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 503)
helper.AssertResponseBody(t, w, "Request took too long")
}
/*
The upstream type implements http.Handler.
In this file we handle request routing and interaction with the authBackend.
*/
package main
import (
"fmt"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
)
type serviceHandleFunc func(w http.ResponseWriter, r *gitRequest)
type upstream struct {
httpClient *http.Client
httpProxy *httputil.ReverseProxy
authBackend string
relativeURLRoot string
}
type authorizationResponse struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull'
GL_ID string
// RepoPath is the full path on disk to the Git repository the request is
// about
RepoPath string
// ArchivePath is the full path where we should find/create a cached copy
// of a requested archive
ArchivePath string
// ArchivePrefix is used to put extracted archive contents in a
// subdirectory
ArchivePrefix string
// CommitId is used do prevent race conditions between the 'time of check'
// in the GitLab Rails app and the 'time of use' in gitlab-workhorse.
CommitId string
// StoreLFSPath is provided by the GitLab Rails application
// to mark where the tmp file should be placed
StoreLFSPath string
// LFS object id
LfsOid string
// LFS object size
LfsSize int64
// TmpPath is the path where we should store temporary files
// This is set by authorization middleware
TempPath string
}
// A gitRequest is an *http.Request decorated with attributes returned by the
// GitLab Rails application.
type gitRequest struct {
*http.Request
authorizationResponse
u *upstream
// This field contains the URL.Path stripped from RelativeUrlRoot
relativeURIPath string
}
func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream {
gitlabURL, err := url.Parse(authBackend)
if err != nil {
log.Fatalln(err)
}
relativeURLRoot := gitlabURL.Path
if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/"
}
// If the relative URL is '/foobar' and we tell httputil.ReverseProxy to proxy
// to 'http://example.com/foobar' then we get a redirect loop, so we clear the
// Path field here.
gitlabURL.Path = ""
up := &upstream{
authBackend: authBackend,
httpClient: &http.Client{Transport: authTransport},
httpProxy: httputil.NewSingleHostReverseProxy(gitlabURL),
relativeURLRoot: relativeURLRoot,
}
up.httpProxy.Transport = authTransport
return up
}
func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
var g httpRoute
w := newLoggingResponseWriter(ow)
defer w.Log(r)
// Drop WebSocket connection and CONNECT method
if r.RequestURI == "*" {
httpError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
httpError(&w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
// Check URL Root
URIPath := cleanURIPath(r.URL.Path)
if !strings.HasPrefix(URIPath, u.relativeURLRoot) && URIPath+"/" != u.relativeURLRoot {
httpError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
// Strip prefix and add "/"
// To match against non-relative URL
// Making it simpler for our matcher
relativeURIPath := cleanURIPath(strings.TrimPrefix(URIPath, u.relativeURLRoot))
// Look for a matching Git service
foundService := false
for _, g = range httpRoutes {
if g.method != "" && r.Method != g.method {
continue
}
if g.regex == nil || g.regex.MatchString(relativeURIPath) {
foundService = true
break
}
}
if !foundService {
// The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found.
httpError(&w, r, "Forbidden", http.StatusForbidden)
return
}
request := gitRequest{
Request: r,
relativeURIPath: relativeURIPath,
u: u,
}
g.handleFunc(&w, &request)
}
package main
import (
"flag"
"net/url"
)
type urlFlag struct {
*url.URL
}
func (u *urlFlag) Set(s string) error {
myURL, err := url.Parse(s)
if err != nil {
return err
}
u.URL = myURL
return nil
}
func URLFlag(name string, value *url.URL, usage string) **url.URL {
f := &urlFlag{value}
flag.CommandLine.Var(f, name, usage)
return &f.URL
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment