Commit a58ceca7 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'refactor-upstream' into 'master'

Refactor upstream

Rationale: the code has become a tangled mess of global variables and
types that hang together when they need not. For example: every HTTP
handler uses a 'gitRequest'?? I want to clean this up and see if I can
move some things into internal packages.

Apart from using internal packages we now use http.Handler where we can,
and fewer global variables.

See merge request !20
parents 64270207 9e7b612b
test/data
test/scratch
gitlab-workhorse gitlab-workhorse
test/public testdata/data
testdata/scratch
testdata/public
PREFIX=/usr/local PREFIX=/usr/local
VERSION=$(shell git describe)-$(shell date -u +%Y%m%d.%H%M%S) VERSION=$(shell git describe)-$(shell date -u +%Y%m%d.%H%M%S)
gitlab-workhorse: $(wildcard *.go) gitlab-workhorse: $(shell find . -name '*.go')
go build -ldflags "-X main.Version=${VERSION}" -o gitlab-workhorse go build -ldflags "-X main.Version=${VERSION}" -o gitlab-workhorse
install: gitlab-workhorse install: gitlab-workhorse
install gitlab-workhorse ${PREFIX}/bin/ install gitlab-workhorse ${PREFIX}/bin/
.PHONY: test .PHONY: test
test: test/data/group/test.git clean-workhorse gitlab-workhorse test: testdata/data/group/test.git clean-workhorse gitlab-workhorse
go fmt | awk '{ print "Please run go fmt"; exit 1 }' go fmt ./... | awk '{ print } END { if (NR > 0) { print "Please run go fmt"; exit 1 } }'
go test go test ./...
@echo SUCCESS
coverage: test/data/group/test.git coverage: testdata/data/group/test.git
go test -cover -coverprofile=test.coverage go test -cover -coverprofile=test.coverage
go tool cover -html=test.coverage -o coverage.html go tool cover -html=test.coverage -o coverage.html
rm -f test.coverage rm -f test.coverage
test/data/group/test.git: test/data testdata/data/group/test.git: testdata/data
git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git test/data/group/test.git git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git $@
test/data: testdata/data:
mkdir -p test/data mkdir -p $@
.PHONY: clean .PHONY: clean
clean: clean-workhorse clean: clean-workhorse
rm -rf test/data test/scratch rm -rf testdata/data testdata/scratch
.PHONY: clean-workhorse .PHONY: clean-workhorse
clean-workhorse: clean-workhorse:
......
package main
func artifactsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(handleFunc, "/authorize")
}
package main package main
import ( import (
"./internal/api"
"./internal/helper"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -8,14 +10,14 @@ import ( ...@@ -8,14 +10,14 @@ import (
"testing" "testing"
) )
func okHandler(w http.ResponseWriter, r *gitRequest) { func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
w.WriteHeader(201) w.WriteHeader(201)
fmt.Fprint(w, "{\"status\":\"ok\"}") fmt.Fprint(w, "{\"status\":\"ok\"}")
} }
func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, authorizationResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder { func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, apiResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder {
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(url, returnCode, authorizationResponse) ts := testAuthServer(url, returnCode, apiResponse)
defer ts.Close() defer ts.Close()
// Create http request // Create http request
...@@ -23,15 +25,11 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, aut ...@@ -23,15 +25,11 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, aut
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a := api.NewAPI(helper.URLMustParse(ts.URL), "123", nil)
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, nil),
}
response := httptest.NewRecorder() response := httptest.NewRecorder()
preAuthorizeHandler(okHandler, suffix)(response, &request) a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest)
assertResponseCode(t, response, expectedCode) helper.AssertResponseCode(t, response, expectedCode)
return response return response
} }
...@@ -39,7 +37,7 @@ func TestPreAuthorizeHappyPath(t *testing.T) { ...@@ -39,7 +37,7 @@ func TestPreAuthorizeHappyPath(t *testing.T) {
runPreAuthorizeHandler( runPreAuthorizeHandler(
t, "/authorize", t, "/authorize",
regexp.MustCompile(`/authorize\z`), regexp.MustCompile(`/authorize\z`),
&authorizationResponse{}, &api.Response{},
200, 201) 200, 201)
} }
...@@ -47,7 +45,7 @@ func TestPreAuthorizeSuffix(t *testing.T) { ...@@ -47,7 +45,7 @@ func TestPreAuthorizeSuffix(t *testing.T) {
runPreAuthorizeHandler( runPreAuthorizeHandler(
t, "/different-authorize", t, "/different-authorize",
regexp.MustCompile(`/authorize\z`), regexp.MustCompile(`/authorize\z`),
&authorizationResponse{}, &api.Response{},
200, 404) 200, 404)
} }
......
package main
import "net/http"
func handleDevelopmentMode(developmentMode *bool, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
if !*developmentMode {
http.NotFound(w, r.Request)
return
}
handler(w, r)
}
}
package main package api
import ( import (
"../badgateway"
"../helper"
"../proxy"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
) )
func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) { type API struct {
url := u.authBackend + "/" + strings.TrimPrefix(r.URL.RequestURI(), u.relativeURLRoot) + suffix Client *http.Client
authReq, err := http.NewRequest(r.Method, url, body) URL *url.URL
if err != nil { Version string
return nil, err }
func NewAPI(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *API {
if roundTripper == nil {
roundTripper = badgateway.NewRoundTripper("", 0)
} }
// Forward all headers from our client to the auth backend. This includes return &API{
// HTTP Basic authentication credentials (the 'Authorization' header). Client: &http.Client{Transport: roundTripper},
for k, v := range r.Header { URL: myURL,
authReq.Header[k] = v Version: version,
}
}
type HandleFunc func(http.ResponseWriter, *http.Request, *Response)
type Response struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull'
GL_ID string
// RepoPath is the full path on disk to the Git repository the request is
// about
RepoPath string
// ArchivePath is the full path where we should find/create a cached copy
// of a requested archive
ArchivePath string
// ArchivePrefix is used to put extracted archive contents in a
// subdirectory
ArchivePrefix string
// CommitId is used do prevent race conditions between the 'time of check'
// in the GitLab Rails app and the 'time of use' in gitlab-workhorse.
CommitId string
// StoreLFSPath is provided by the GitLab Rails application
// to mark where the tmp file should be placed
StoreLFSPath string
// LFS object id
LfsOid string
// LFS object size
LfsSize int64
// TmpPath is the path where we should store temporary files
// This is set by authorization middleware
TempPath string
}
// singleJoiningSlash is taken from reverseproxy.go:NewSingleHostReverseProxy
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
// rebaseUrl is taken from reverseproxy.go:NewSingleHostReverseProxy
func rebaseUrl(url *url.URL, onto *url.URL, suffix string) *url.URL {
newUrl := *url
newUrl.Scheme = onto.Scheme
newUrl.Host = onto.Host
if suffix != "" {
newUrl.Path = singleJoiningSlash(url.Path, suffix)
}
if onto.RawQuery == "" || newUrl.RawQuery == "" {
newUrl.RawQuery = onto.RawQuery + newUrl.RawQuery
} else {
newUrl.RawQuery = onto.RawQuery + "&" + newUrl.RawQuery
}
return &newUrl
}
func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) {
authReq := &http.Request{
Method: r.Method,
URL: rebaseUrl(r.URL, api.URL, suffix),
Header: proxy.HeaderClone(r.Header),
}
if body != nil {
authReq.Body = ioutil.NopCloser(body)
} }
// Clean some headers when issuing a new request without body // Clean some headers when issuing a new request without body
...@@ -46,22 +125,22 @@ func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix st ...@@ -46,22 +125,22 @@ func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix st
authReq.Host = r.Host authReq.Host = r.Host
// Set a custom header for the request. This can be used in some // Set a custom header for the request. This can be used in some
// configurations (Passenger) to solve auth request routing problems. // configurations (Passenger) to solve auth request routing problems.
authReq.Header.Set("Gitlab-Workhorse", Version) authReq.Header.Set("Gitlab-Workhorse", api.Version)
return authReq, nil return authReq, nil
} }
func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc { func (api *API) PreAuthorizeHandler(h HandleFunc, suffix string) http.Handler {
return func(w http.ResponseWriter, r *gitRequest) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authReq, err := r.u.newUpstreamRequest(r.Request, nil, suffix) authReq, err := api.newRequest(r, nil, suffix)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err)) helper.Fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err))
return return
} }
authResponse, err := r.u.httpClient.Do(authReq) authResponse, err := api.Client.Do(authReq)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: do %v: %v", authReq.URL.Path, err)) helper.Fail500(w, fmt.Errorf("preAuthorizeHandler: do %v: %v", authReq.URL.Path, err))
return return
} }
defer authResponse.Body.Close() defer authResponse.Body.Close()
...@@ -85,11 +164,12 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan ...@@ -85,11 +164,12 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan
return return
} }
a := &Response{}
// The auth backend validated the client request and told us additional // The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth // request metadata. We must extract this information from the auth
// response body. // response body.
if err := json.NewDecoder(authResponse.Body).Decode(&r.authorizationResponse); err != nil { if err := json.NewDecoder(authResponse.Body).Decode(a); err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err)) helper.Fail500(w, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err))
return return
} }
// Don't hog a TCP connection in CLOSE_WAIT, we can already close it now // Don't hog a TCP connection in CLOSE_WAIT, we can already close it now
...@@ -104,6 +184,6 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan ...@@ -104,6 +184,6 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan
} }
} }
handleFunc(w, r) h(w, r, a)
} })
} }
package main package badgateway
import ( import (
"../helper"
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"time"
) )
type proxyRoundTripper struct { // Values from http.DefaultTransport
transport http.RoundTripper var DefaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
} }
func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { var DefaultTransport = &http.Transport{
res, err = p.transport.RoundTrip(r) Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport
Dial: DefaultDialer.Dial, // from http.DefaultTransport
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
type RoundTripper struct {
Transport *http.Transport
}
func NewRoundTripper(socket string, proxyHeadersTimeout time.Duration) *RoundTripper {
tr := *DefaultTransport
tr.ResponseHeaderTimeout = proxyHeadersTimeout
if socket != "" {
tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", socket)
}
}
return &RoundTripper{Transport: &tr}
}
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = t.Transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this // httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error // RoundTrip function into 500 errors. But the most likely error
...@@ -21,7 +48,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err ...@@ -21,7 +48,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
// instead of 500s we catch the RoundTrip error here and inject a // instead of 500s we catch the RoundTrip error here and inject a
// 502 response. // 502 response.
if err != nil { if err != nil {
logError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err)) helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
res = &http.Response{ res = &http.Response{
StatusCode: http.StatusBadGateway, StatusCode: http.StatusBadGateway,
...@@ -40,26 +67,3 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err ...@@ -40,26 +67,3 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
} }
return return
} }
func headerClone(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
func proxyRequest(w http.ResponseWriter, r *gitRequest) {
// Clone request
req := *r.Request
req.Header = headerClone(r.Header)
// Set Workhorse version
req.Header.Set("Gitlab-Workhorse", Version)
rw := newSendFileResponseWriter(w, &req)
defer rw.Flush()
r.u.httpProxy.ServeHTTP(&rw, &req)
}
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
In this file we handle 'git archive' downloads In this file we handle 'git archive' downloads
*/ */
package main package git
import ( import (
"../api"
"../helper"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -18,7 +20,10 @@ import ( ...@@ -18,7 +20,10 @@ import (
"time" "time"
) )
func handleGetArchive(w http.ResponseWriter, r *gitRequest) { func GetArchive(a *api.API) http.Handler {
return repoPreAuthorizeHandler(a, handleGetArchive)
}
func handleGetArchive(w http.ResponseWriter, r *http.Request, a *api.Response) {
var format string var format string
urlPath := r.URL.Path urlPath := r.URL.Path
switch filepath.Base(urlPath) { switch filepath.Base(urlPath) {
...@@ -31,20 +36,20 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -31,20 +36,20 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
case "archive.tar.bz2": case "archive.tar.bz2":
format = "tar.bz2" format = "tar.bz2"
default: default:
fail500(w, fmt.Errorf("handleGetArchive: invalid format: %s", urlPath)) helper.Fail500(w, fmt.Errorf("handleGetArchive: invalid format: %s", urlPath))
return return
} }
archiveFilename := path.Base(r.ArchivePath) archiveFilename := path.Base(a.ArchivePath)
if cachedArchive, err := os.Open(r.ArchivePath); err == nil { if cachedArchive, err := os.Open(a.ArchivePath); err == nil {
defer cachedArchive.Close() defer cachedArchive.Close()
log.Printf("Serving cached file %q", r.ArchivePath) log.Printf("Serving cached file %q", a.ArchivePath)
setArchiveHeaders(w, format, archiveFilename) setArchiveHeaders(w, format, archiveFilename)
// Even if somebody deleted the cachedArchive from disk since we opened // Even if somebody deleted the cachedArchive from disk since we opened
// the file, Unix file semantics guarantee we can still read from the // the file, Unix file semantics guarantee we can still read from the
// open file in this process. // open file in this process.
http.ServeContent(w, r.Request, "", time.Unix(0, 0), cachedArchive) http.ServeContent(w, r, "", time.Unix(0, 0), cachedArchive)
return return
} }
...@@ -52,9 +57,9 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -52,9 +57,9 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
// safe. We create the tempfile in the same directory as the final cached // safe. We create the tempfile in the same directory as the final cached
// archive we want to create so that we can use an atomic link(2) operation // archive we want to create so that we can use an atomic link(2) operation
// to finalize the cached archive. // to finalize the cached archive.
tempFile, err := prepareArchiveTempfile(path.Dir(r.ArchivePath), archiveFilename) tempFile, err := prepareArchiveTempfile(path.Dir(a.ArchivePath), archiveFilename)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: create tempfile: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: create tempfile: %v", err))
return return
} }
defer tempFile.Close() defer tempFile.Close()
...@@ -62,15 +67,15 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -62,15 +67,15 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
compressCmd, archiveFormat := parseArchiveFormat(format) compressCmd, archiveFormat := parseArchiveFormat(format)
archiveCmd := gitCommand("", "git", "--git-dir="+r.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+r.ArchivePrefix+"/", r.CommitId) archiveCmd := gitCommand("", "git", "--git-dir="+a.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+a.ArchivePrefix+"/", a.CommitId)
archiveStdout, err := archiveCmd.StdoutPipe() archiveStdout, err := archiveCmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: archive stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: archive stdout: %v", err))
return return
} }
defer archiveStdout.Close() defer archiveStdout.Close()
if err := archiveCmd.Start(); err != nil { if err := archiveCmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", archiveCmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", archiveCmd.Args, err))
return return
} }
defer cleanUpProcessGroup(archiveCmd) // Ensure brute force subprocess clean-up defer cleanUpProcessGroup(archiveCmd) // Ensure brute force subprocess clean-up
...@@ -84,13 +89,13 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -84,13 +89,13 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
stdout, err = compressCmd.StdoutPipe() stdout, err = compressCmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: compress stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: compress stdout: %v", err))
return return
} }
defer stdout.Close() defer stdout.Close()
if err := compressCmd.Start(); err != nil { if err := compressCmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", compressCmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", compressCmd.Args, err))
return return
} }
defer cleanUpProcessGroup(compressCmd) defer cleanUpProcessGroup(compressCmd)
...@@ -105,22 +110,22 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -105,22 +110,22 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
setArchiveHeaders(w, format, archiveFilename) setArchiveHeaders(w, format, archiveFilename)
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if _, err := io.Copy(w, archiveReader); err != nil { if _, err := io.Copy(w, archiveReader); err != nil {
logError(fmt.Errorf("handleGetArchive: read: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: read: %v", err))
return return
} }
if err := archiveCmd.Wait(); err != nil { if err := archiveCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err))
return return
} }
if compressCmd != nil { if compressCmd != nil {
if err := compressCmd.Wait(); err != nil { if err := compressCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: compressCmd: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: compressCmd: %v", err))
return return
} }
} }
if err := finalizeCachedArchive(tempFile, r.ArchivePath); err != nil { if err := finalizeCachedArchive(tempFile, a.ArchivePath); err != nil {
logError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err))
return return
} }
} }
......
/* package git
Miscellaneous helpers: logging, errors, subprocesses
*/
package main
import ( import (
"errors"
"fmt" "fmt"
"log"
"net/http"
"os" "os"
"os/exec" "os/exec"
"path"
"syscall" "syscall"
) )
func fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500)
logError(err)
}
func logError(err error) {
log.Printf("error: %v", err)
}
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
// Git subprocess helpers // Git subprocess helpers
func gitCommand(gl_id string, name string, args ...string) *exec.Cmd { func gitCommand(gl_id string, name string, args ...string) *exec.Cmd {
cmd := exec.Command(name, args...) cmd := exec.Command(name, args...)
...@@ -64,57 +38,3 @@ func cleanUpProcessGroup(cmd *exec.Cmd) { ...@@ -64,57 +38,3 @@ func cleanUpProcessGroup(cmd *exec.Cmd) {
// reap our child process // reap our child process
cmd.Wait() cmd.Wait()
} }
func setNoCacheHeaders(header http.Header) {
header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
header.Set("Pragma", "no-cache")
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
}
func openFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path)
if err != nil {
return
}
defer func() {
if err != nil {
file.Close()
}
}()
fi, err = file.Stat()
if err != nil {
return
}
// The os.Open can also open directories
if fi.IsDir() {
err = &os.PathError{
Op: "open",
Path: path,
Err: errors.New("path is directory"),
}
return
}
return
}
// Borrowed from: net/http/server.go
// Return the canonical path for p, eliminating . and .. elements.
func cleanURIPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
In this file we handle the Git 'smart HTTP' protocol In this file we handle the Git 'smart HTTP' protocol
*/ */
package main package git
import ( import (
"../api"
"../helper"
"errors" "errors"
"fmt" "fmt"
"io" "io"
...@@ -16,6 +18,14 @@ import ( ...@@ -16,6 +18,14 @@ import (
"strings" "strings"
) )
func GetInfoRefs(a *api.API) http.Handler {
return repoPreAuthorizeHandler(a, handleGetInfoRefs)
}
func PostRPC(a *api.API) http.Handler {
return repoPreAuthorizeHandler(a, handlePostRPC)
}
func looksLikeRepo(p string) bool { func looksLikeRepo(p string) bool {
// If /path/to/foo.git/objects exists then let's assume it is a valid Git // If /path/to/foo.git/objects exists then let's assume it is a valid Git
// repository. // repository.
...@@ -26,23 +36,23 @@ func looksLikeRepo(p string) bool { ...@@ -26,23 +36,23 @@ func looksLikeRepo(p string) bool {
return true return true
} }
func repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func repoPreAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Handler {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) { return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if r.RepoPath == "" { if a.RepoPath == "" {
fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty")) helper.Fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty"))
return return
} }
if !looksLikeRepo(r.RepoPath) { if !looksLikeRepo(a.RepoPath) {
http.Error(w, "Not Found", 404) http.Error(w, "Not Found", 404)
return return
} }
handleFunc(w, r) handleFunc(w, r, a)
}, "") }, "")
} }
func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) { func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response) {
rpc := r.URL.Query().Get("service") rpc := r.URL.Query().Get("service")
if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") { if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported // The 'dumb' Git HTTP protocol is not supported
...@@ -51,15 +61,15 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) { ...@@ -51,15 +61,15 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) {
} }
// Prepare our Git subprocess // Prepare our Git subprocess
cmd := gitCommand(r.GL_ID, "git", subCommand(rpc), "--stateless-rpc", "--advertise-refs", r.RepoPath) cmd := gitCommand(a.GL_ID, "git", subCommand(rpc), "--stateless-rpc", "--advertise-refs", a.RepoPath)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetInfoRefs: stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: stdout: %v", err))
return return
} }
defer stdout.Close() defer stdout.Close()
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetInfoRefs: start %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: start %v: %v", cmd.Args, err))
return return
} }
defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
...@@ -69,57 +79,57 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) { ...@@ -69,57 +79,57 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) {
w.Header().Add("Cache-Control", "no-cache") w.Header().Add("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil { if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
return return
} }
if err := pktFlush(w); err != nil { if err := pktFlush(w); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err))
return return
} }
if _, err := io.Copy(w, stdout); err != nil { if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err))
return return
} }
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err))
return return
} }
} }
func handlePostRPC(w http.ResponseWriter, r *gitRequest) { func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) {
var err error var err error
// Get Git action from URL // Get Git action from URL
action := filepath.Base(r.URL.Path) action := filepath.Base(r.URL.Path)
if !(action == "git-upload-pack" || action == "git-receive-pack") { if !(action == "git-upload-pack" || action == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported // The 'dumb' Git HTTP protocol is not supported
fail500(w, fmt.Errorf("handlePostRPC: unsupported action: %s", r.URL.Path)) helper.Fail500(w, fmt.Errorf("handlePostRPC: unsupported action: %s", r.URL.Path))
return return
} }
// Prepare our Git subprocess // Prepare our Git subprocess
cmd := gitCommand(r.GL_ID, "git", subCommand(action), "--stateless-rpc", r.RepoPath) cmd := gitCommand(a.GL_ID, "git", subCommand(action), "--stateless-rpc", a.RepoPath)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handlePostRPC: stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handlePostRPC: stdout: %v", err))
return return
} }
defer stdout.Close() defer stdout.Close()
stdin, err := cmd.StdinPipe() stdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handlePostRPC: stdin: %v", err)) helper.Fail500(w, fmt.Errorf("handlePostRPC: stdin: %v", err))
return return
} }
defer stdin.Close() defer stdin.Close()
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
fail500(w, fmt.Errorf("handlePostRPC: start %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handlePostRPC: start %v: %v", cmd.Args, err))
return return
} }
defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
// Write the client request body to Git's standard input // Write the client request body to Git's standard input
if _, err := io.Copy(stdin, r.Body); err != nil { if _, err := io.Copy(stdin, r.Body); err != nil {
fail500(w, fmt.Errorf("handlePostRPC write to %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handlePostRPC write to %v: %v", cmd.Args, err))
return return
} }
// Signal to the Git subprocess that no more data is coming // Signal to the Git subprocess that no more data is coming
...@@ -136,11 +146,11 @@ func handlePostRPC(w http.ResponseWriter, r *gitRequest) { ...@@ -136,11 +146,11 @@ func handlePostRPC(w http.ResponseWriter, r *gitRequest) {
// This io.Copy may take a long time, both for Git push and pull. // This io.Copy may take a long time, both for Git push and pull.
if _, err := io.Copy(w, stdout); err != nil { if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err))
return return
} }
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err))
return return
} }
} }
......
package helper
import (
"errors"
"log"
"net/http"
"net/url"
"os"
)
func Fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500)
LogError(err)
}
func LogError(err error) {
log.Printf("error: %v", err)
}
func SetNoCacheHeaders(header http.Header) {
header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
header.Set("Pragma", "no-cache")
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
}
func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path)
if err != nil {
return
}
defer func() {
if err != nil {
file.Close()
}
}()
fi, err = file.Stat()
if err != nil {
return
}
// The os.Open can also open directories
if fi.IsDir() {
err = &os.PathError{
Op: "open",
Path: path,
Err: errors.New("path is directory"),
}
return
}
return
}
func URLMustParse(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
log.Fatalf("urlMustParse: %q %v", s, err)
}
return u
}
func HTTPError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
package main package helper
import ( import (
"fmt" "fmt"
...@@ -6,25 +6,25 @@ import ( ...@@ -6,25 +6,25 @@ import (
"time" "time"
) )
type loggingResponseWriter struct { type LoggingResponseWriter struct {
rw http.ResponseWriter rw http.ResponseWriter
status int status int
written int64 written int64
started time.Time started time.Time
} }
func newLoggingResponseWriter(rw http.ResponseWriter) loggingResponseWriter { func NewLoggingResponseWriter(rw http.ResponseWriter) LoggingResponseWriter {
return loggingResponseWriter{ return LoggingResponseWriter{
rw: rw, rw: rw,
started: time.Now(), started: time.Now(),
} }
} }
func (l *loggingResponseWriter) Header() http.Header { func (l *LoggingResponseWriter) Header() http.Header {
return l.rw.Header() return l.rw.Header()
} }
func (l *loggingResponseWriter) Write(data []byte) (n int, err error) { func (l *LoggingResponseWriter) Write(data []byte) (n int, err error) {
if l.status == 0 { if l.status == 0 {
l.WriteHeader(http.StatusOK) l.WriteHeader(http.StatusOK)
} }
...@@ -33,7 +33,7 @@ func (l *loggingResponseWriter) Write(data []byte) (n int, err error) { ...@@ -33,7 +33,7 @@ func (l *loggingResponseWriter) Write(data []byte) (n int, err error) {
return return
} }
func (l *loggingResponseWriter) WriteHeader(status int) { func (l *LoggingResponseWriter) WriteHeader(status int) {
if l.status != 0 { if l.status != 0 {
return return
} }
...@@ -42,7 +42,7 @@ func (l *loggingResponseWriter) WriteHeader(status int) { ...@@ -42,7 +42,7 @@ func (l *loggingResponseWriter) WriteHeader(status int) {
l.rw.WriteHeader(status) l.rw.WriteHeader(status)
} }
func (l *loggingResponseWriter) Log(r *http.Request) { func (l *LoggingResponseWriter) Log(r *http.Request) {
duration := time.Since(l.started) duration := time.Since(l.started)
fmt.Printf("%s %s - - [%s] %q %d %d %q %q %f\n", fmt.Printf("%s %s - - [%s] %q %d %d %q %q %f\n",
r.Host, r.RemoteAddr, l.started, r.Host, r.RemoteAddr, l.started,
......
package main package helper
import ( import (
"log"
"net/http"
"net/http/httptest" "net/http/httptest"
"regexp"
"testing" "testing"
) )
func assertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expectedCode int) { func AssertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expectedCode int) {
if response.Code != expectedCode { if response.Code != expectedCode {
t.Fatalf("for HTTP request expected to get %d, got %d instead", expectedCode, response.Code) t.Fatalf("for HTTP request expected to get %d, got %d instead", expectedCode, response.Code)
} }
} }
func assertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) { func AssertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) {
if response.Body.String() != expectedBody { if response.Body.String() != expectedBody {
t.Fatalf("for HTTP request expected to receive %q, got %q instead as body", expectedBody, response.Body.String()) t.Fatalf("for HTTP request expected to receive %q, got %q instead as body", expectedBody, response.Body.String())
} }
} }
func assertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) { func AssertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) {
if response.Header().Get(header) != expectedValue { if response.Header().Get(header) != expectedValue {
t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header)) t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header))
} }
} }
func TestServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if url != nil && !url.MatchString(r.URL.Path) {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(404)
return
}
if version := r.Header.Get("Gitlab-Workhorse"); version == "" {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(403)
return
}
handler(w, r)
}))
}
/*
In this file we handle git lfs objects downloads and uploads
*/
package lfs
import (
"../api"
"../helper"
"../proxy"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
)
func PutStore(a *api.API, p *proxy.Proxy) http.Handler {
return lfsAuthorizeHandler(a, handleStoreLfsObject(p))
}
func lfsAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if a.StoreLFSPath == "" {
helper.Fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty"))
return
}
if a.LfsOid == "" {
helper.Fail500(w, errors.New("lfsAuthorizeHandler: LfsOid empty"))
return
}
if err := os.MkdirAll(a.StoreLFSPath, 0700); err != nil {
helper.Fail500(w, fmt.Errorf("lfsAuthorizeHandler: mkdir StoreLFSPath: %v", err))
return
}
handleFunc(w, r, a)
}, "/authorize")
}
func handleStoreLfsObject(h http.Handler) api.HandleFunc {
return func(w http.ResponseWriter, r *http.Request, a *api.Response) {
file, err := ioutil.TempFile(a.StoreLFSPath, a.LfsOid)
if err != nil {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
return
}
defer os.Remove(file.Name())
defer file.Close()
hash := sha256.New()
hw := io.MultiWriter(hash, file)
written, err := io.Copy(hw, r.Body)
if err != nil {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: write tempfile: %v", err))
return
}
file.Close()
if written != a.LfsSize {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: expected size %d, wrote %d", a.LfsSize, written))
return
}
shaStr := hex.EncodeToString(hash.Sum(nil))
if shaStr != a.LfsOid {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", a.LfsOid, shaStr))
return
}
// Inject header and body
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name()))
r.Body = ioutil.NopCloser(&bytes.Buffer{})
r.ContentLength = 0
// And proxy the request
h.ServeHTTP(w, r)
}
}
package proxy
import (
"../badgateway"
"net/http"
"net/http/httputil"
"net/url"
)
type Proxy struct {
Version string
reverseProxy *httputil.ReverseProxy
}
func NewProxy(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *Proxy {
p := Proxy{Version: version}
u := *myURL // Make a copy of p.URL
u.Path = ""
p.reverseProxy = httputil.NewSingleHostReverseProxy(&u)
if roundTripper != nil {
p.reverseProxy.Transport = roundTripper
} else {
p.reverseProxy.Transport = badgateway.NewRoundTripper("", 0)
}
return &p
}
func HeaderClone(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Clone request
req := *r
req.Header = HeaderClone(r.Header)
// Set Workhorse version
req.Header.Set("Gitlab-Workhorse", p.Version)
rw := newSendFileResponseWriter(w, &req)
defer rw.Flush()
p.reverseProxy.ServeHTTP(&rw, &req)
}
...@@ -4,9 +4,10 @@ via the X-Sendfile mechanism. All that is needed in the Rails code is the ...@@ -4,9 +4,10 @@ via the X-Sendfile mechanism. All that is needed in the Rails code is the
'send_file' method. 'send_file' method.
*/ */
package main package proxy
import ( import (
"../helper"
"log" "log"
"net/http" "net/http"
) )
...@@ -63,7 +64,7 @@ func (s *sendFileResponseWriter) WriteHeader(status int) { ...@@ -63,7 +64,7 @@ func (s *sendFileResponseWriter) WriteHeader(status int) {
// Serve the file // Serve the file
log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI) log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI)
content, fi, err := openFile(file) content, fi, err := helper.OpenFile(file)
if err != nil { if err != nil {
http.NotFound(s.rw, s.req) http.NotFound(s.rw, s.req)
return return
......
package main package staticpages
import ( import (
"../helper"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"path/filepath" "path/filepath"
) )
func handleDeployPage(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc { func (s *Static) DeployPage(handler http.Handler) http.Handler {
return func(w http.ResponseWriter, r *gitRequest) { deployPage := filepath.Join(s.DocumentRoot, "index.html")
deployPage := filepath.Join(*documentRoot, "index.html")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := ioutil.ReadFile(deployPage) data, err := ioutil.ReadFile(deployPage)
if err != nil { if err != nil {
handler(w, r) handler.ServeHTTP(w, r)
return return
} }
setNoCacheHeaders(w.Header()) helper.SetNoCacheHeaders(w.Header())
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write(data) w.Write(data)
} })
} }
package main package staticpages
import ( import (
"../helper"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -19,9 +20,10 @@ func TestIfNoDeployPageExist(t *testing.T) { ...@@ -19,9 +20,10 @@ func TestIfNoDeployPageExist(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) { st := &Static{dir}
st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, nil) })).ServeHTTP(w, nil)
if !executed { if !executed {
t.Error("The handler should get executed") t.Error("The handler should get executed")
} }
...@@ -40,14 +42,15 @@ func TestIfDeployPageExist(t *testing.T) { ...@@ -40,14 +42,15 @@ func TestIfDeployPageExist(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) { st := &Static{dir}
st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, nil) })).ServeHTTP(w, nil)
if executed { if executed {
t.Error("The handler should not get executed") t.Error("The handler should not get executed")
} }
w.Flush() w.Flush()
assertResponseCode(t, w, 200) helper.AssertResponseCode(t, w, 200)
assertResponseBody(t, w, deployPage) helper.AssertResponseBody(t, w, deployPage)
} }
package main package staticpages
import ( import (
"../helper"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
...@@ -12,7 +13,7 @@ type errorPageResponseWriter struct { ...@@ -12,7 +13,7 @@ type errorPageResponseWriter struct {
rw http.ResponseWriter rw http.ResponseWriter
status int status int
hijacked bool hijacked bool
path *string path string
} }
func (s *errorPageResponseWriter) Header() http.Header { func (s *errorPageResponseWriter) Header() http.Header {
...@@ -37,14 +38,14 @@ func (s *errorPageResponseWriter) WriteHeader(status int) { ...@@ -37,14 +38,14 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.status = status s.status = status
if 400 <= s.status && s.status <= 599 { if 400 <= s.status && s.status <= 599 {
errorPageFile := filepath.Join(*s.path, fmt.Sprintf("%d.html", s.status)) errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status))
// check if custom error page exists, serve this page instead // check if custom error page exists, serve this page instead
if data, err := ioutil.ReadFile(errorPageFile); err == nil { if data, err := ioutil.ReadFile(errorPageFile); err == nil {
s.hijacked = true s.hijacked = true
log.Printf("ErrorPage: serving predefined error page: %d", s.status) log.Printf("ErrorPage: serving predefined error page: %d", s.status)
setNoCacheHeaders(s.rw.Header()) helper.SetNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", "text/html; charset=utf-8") s.rw.Header().Set("Content-Type", "text/html; charset=utf-8")
s.rw.WriteHeader(s.status) s.rw.WriteHeader(s.status)
s.rw.Write(data) s.rw.Write(data)
...@@ -59,16 +60,16 @@ func (s *errorPageResponseWriter) Flush() { ...@@ -59,16 +60,16 @@ func (s *errorPageResponseWriter) Flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
func handleRailsError(documentRoot *string, enabled *bool, handler serviceHandleFunc) serviceHandleFunc { func (st *Static) ErrorPages(enabled bool, handler http.Handler) http.Handler {
if !*enabled { if !enabled {
return handler return handler
} }
return func(w http.ResponseWriter, r *gitRequest) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := errorPageResponseWriter{ rw := errorPageResponseWriter{
rw: w, rw: w,
path: documentRoot, path: st.DocumentRoot,
} }
defer rw.Flush() defer rw.Flush()
handler(&rw, r) handler.ServeHTTP(&rw, r)
} })
} }
package main package staticpages
import ( import (
"../helper"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
...@@ -21,16 +22,16 @@ func TestIfErrorPageIsPresented(t *testing.T) { ...@@ -21,16 +22,16 @@ func TestIfErrorPageIsPresented(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600) ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
enabled := true
handleRailsError(&dir, &enabled, func(w http.ResponseWriter, r *gitRequest) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, "Not Found") fmt.Fprint(w, "Not Found")
})(w, nil) })
st := &Static{dir}
st.ErrorPages(true, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
assertResponseCode(t, w, 404) helper.AssertResponseCode(t, w, 404)
assertResponseBody(t, w, errorPage) helper.AssertResponseBody(t, w, errorPage)
} }
func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
...@@ -42,16 +43,16 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { ...@@ -42,16 +43,16 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
errorResponse := "ERROR" errorResponse := "ERROR"
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
enabled := true
handleRailsError(&dir, &enabled, func(w http.ResponseWriter, r *gitRequest) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, errorResponse) fmt.Fprint(w, errorResponse)
})(w, nil) })
st := &Static{dir}
st.ErrorPages(true, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
assertResponseCode(t, w, 404) helper.AssertResponseCode(t, w, 404)
assertResponseBody(t, w, errorResponse) helper.AssertResponseBody(t, w, errorResponse)
} }
func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
...@@ -65,15 +66,14 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { ...@@ -65,15 +66,14 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
enabled := false
serverError := "Interesting Server Error" serverError := "Interesting Server Error"
handleRailsError(&dir, &enabled, func(w http.ResponseWriter, r *gitRequest) { h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(500) w.WriteHeader(500)
fmt.Fprint(w, serverError) fmt.Fprint(w, serverError)
})(w, nil) })
st := &Static{dir}
st.ErrorPages(false, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
helper.AssertResponseCode(t, w, 500)
assertResponseCode(t, w, 500) helper.AssertResponseBody(t, w, serverError)
assertResponseBody(t, w, serverError)
} }
package main package staticpages
import ( import (
"../helper"
"../urlprefix"
"log" "log"
"net/http" "net/http"
"os" "os"
...@@ -19,13 +21,13 @@ const ( ...@@ -19,13 +21,13 @@ const (
// BUG/QUIRK: If a client requests 'foo%2Fbar' and 'foo/bar' exists, // BUG/QUIRK: If a client requests 'foo%2Fbar' and 'foo/bar' exists,
// handleServeFile will serve foo/bar instead of passing the request // handleServeFile will serve foo/bar instead of passing the request
// upstream. // upstream.
func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serviceHandleFunc) serviceHandleFunc { func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoundHandler http.Handler) http.Handler {
return func(w http.ResponseWriter, r *gitRequest) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
file := filepath.Join(*documentRoot, r.relativeURIPath) file := filepath.Join(s.DocumentRoot, prefix.Strip(r.URL.Path))
// The filepath.Join does Clean traversing directories up // The filepath.Join does Clean traversing directories up
if !strings.HasPrefix(file, *documentRoot) { if !strings.HasPrefix(file, s.DocumentRoot) {
fail500(w, &os.PathError{ helper.Fail500(w, &os.PathError{
Op: "open", Op: "open",
Path: file, Path: file,
Err: os.ErrInvalid, Err: os.ErrInvalid,
...@@ -39,7 +41,7 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv ...@@ -39,7 +41,7 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv
// Serve pre-gzipped assets // Serve pre-gzipped assets
if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") { if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") {
content, fi, err = openFile(file + ".gz") content, fi, err = helper.OpenFile(file + ".gz")
if err == nil { if err == nil {
w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Content-Encoding", "gzip")
} }
...@@ -47,13 +49,13 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv ...@@ -47,13 +49,13 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv
// If not found, open the original file // If not found, open the original file
if content == nil || err != nil { if content == nil || err != nil {
content, fi, err = openFile(file) content, fi, err = helper.OpenFile(file)
} }
if err != nil { if err != nil {
if notFoundHandler != nil { if notFoundHandler != nil {
notFoundHandler(w, r) notFoundHandler.ServeHTTP(w, r)
} else { } else {
http.NotFound(w, r.Request) http.NotFound(w, r)
} }
return return
} }
...@@ -68,6 +70,6 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv ...@@ -68,6 +70,6 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv
} }
log.Printf("Send static file %q (%q) for %s %q", file, w.Header().Get("Content-Encoding"), r.Method, r.RequestURI) log.Printf("Send static file %q (%q) for %s %q", file, w.Header().Get("Content-Encoding"), r.Method, r.RequestURI)
http.ServeContent(w, r.Request, filepath.Base(file), fi.ModTime(), content) http.ServeContent(w, r, filepath.Base(file), fi.ModTime(), content)
} })
} }
package main package staticpages
import ( import (
"../helper"
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"io/ioutil" "io/ioutil"
...@@ -14,14 +15,11 @@ import ( ...@@ -14,14 +15,11 @@ import (
func TestServingNonExistingFile(t *testing.T) { func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/static/file",
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 404) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 404)
} }
func TestServingDirectory(t *testing.T) { func TestServingDirectory(t *testing.T) {
...@@ -32,41 +30,31 @@ func TestServingDirectory(t *testing.T) { ...@@ -32,41 +30,31 @@ func TestServingDirectory(t *testing.T) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/",
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 404) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 404)
} }
func TestServingMalformedUri(t *testing.T) { func TestServingMalformedUri(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/../../../static/file",
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 500) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 404)
} }
func TestExecutingHandlerWhenNoFileFound(t *testing.T) { func TestExecutingHandlerWhenNoFileFound(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/static/file",
}
executed := false executed := false
handleServeFile(&dir, CacheDisabled, func(w http.ResponseWriter, r *gitRequest) { st := &Static{dir}
executed = (r == request) st.ServeExisting("/", CacheDisabled, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
})(nil, request) executed = (r == httpRequest)
})).ServeHTTP(nil, httpRequest)
if !executed { if !executed {
t.Error("The handler should get executed") t.Error("The handler should get executed")
} }
...@@ -80,17 +68,14 @@ func TestServingTheActualFile(t *testing.T) { ...@@ -80,17 +68,14 @@ func TestServingTheActualFile(t *testing.T) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/file",
}
fileContent := "STATIC" fileContent := "STATIC"
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 200) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 200)
if w.Body.String() != fileContent { if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String()) t.Error("We should serve the file: ", w.Body.String())
} }
...@@ -104,10 +89,6 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) { ...@@ -104,10 +89,6 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/file",
}
if enableGzip { if enableGzip {
httpRequest.Header.Set("Accept-Encoding", "gzip, deflate") httpRequest.Header.Set("Accept-Encoding", "gzip, deflate")
...@@ -124,16 +105,17 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) { ...@@ -124,16 +105,17 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 200) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 200)
if enableGzip { if enableGzip {
assertResponseHeader(t, w, "Content-Encoding", "gzip") helper.AssertResponseHeader(t, w, "Content-Encoding", "gzip")
if bytes.Compare(w.Body.Bytes(), fileGzipContent.Bytes()) != 0 { if bytes.Compare(w.Body.Bytes(), fileGzipContent.Bytes()) != 0 {
t.Error("We should serve the pregzipped file") t.Error("We should serve the pregzipped file")
} }
} else { } else {
assertResponseCode(t, w, 200) helper.AssertResponseCode(t, w, 200)
assertResponseHeader(t, w, "Content-Encoding", "") helper.AssertResponseHeader(t, w, "Content-Encoding", "")
if w.Body.String() != fileContent { if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String()) t.Error("We should serve the file: ", w.Body.String())
} }
......
package staticpages
type Static struct {
DocumentRoot string
}
package upload
import (
"../api"
"net/http"
)
func Artifacts(myAPI *api.API, h http.Handler) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
r.Header.Set(tempPathHeader, a.TempPath)
handleFileUploads(h).ServeHTTP(w, r)
}, "/authorize")
}
package main package upload
import ( import (
"../helper"
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
...@@ -11,7 +12,9 @@ import ( ...@@ -11,7 +12,9 @@ import (
"os" "os"
) )
func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cleanup func(), err error) { const tempPathHeader = "Gitlab-Workhorse-Temp-Path"
func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string) (cleanup func(), err error) {
// Create multipart reader // Create multipart reader
reader, err := r.MultipartReader() reader, err := r.MultipartReader()
if err != nil { if err != nil {
...@@ -47,12 +50,12 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle ...@@ -47,12 +50,12 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
// Copy form field // Copy form field
if filename := p.FileName(); filename != "" { if filename := p.FileName(); filename != "" {
// Create temporary directory where the uploaded file will be stored // Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(r.TempPath, 0700); err != nil { if err := os.MkdirAll(tempPath, 0700); err != nil {
return cleanup, err return cleanup, err
} }
// Create temporary file in path returned by Authorization filter // Create temporary file in path returned by Authorization filter
file, err := ioutil.TempFile(r.TempPath, "upload_") file, err := ioutil.TempFile(tempPath, "upload_")
if err != nil { if err != nil {
return cleanup, err return cleanup, err
} }
...@@ -83,39 +86,43 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle ...@@ -83,39 +86,43 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
return cleanup, nil return cleanup, nil
} }
func handleFileUploads(w http.ResponseWriter, r *gitRequest) { func handleFileUploads(h http.Handler) http.Handler {
if r.TempPath == "" { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fail500(w, errors.New("handleFileUploads: TempPath empty")) tempPath := r.Header.Get(tempPathHeader)
return if tempPath == "" {
} helper.Fail500(w, errors.New("handleFileUploads: TempPath empty"))
return
}
r.Header.Del(tempPathHeader)
var body bytes.Buffer var body bytes.Buffer
writer := multipart.NewWriter(&body) writer := multipart.NewWriter(&body)
defer writer.Close() defer writer.Close()
// Rewrite multipart form data // Rewrite multipart form data
cleanup, err := rewriteFormFilesFromMultipart(r, writer) cleanup, err := rewriteFormFilesFromMultipart(r, writer, tempPath)
if err != nil { if err != nil {
if err == http.ErrNotMultipart { if err == http.ErrNotMultipart {
proxyRequest(w, r) h.ServeHTTP(w, r)
} else { } else {
fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err)) helper.Fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err))
}
return
} }
return
}
if cleanup != nil { if cleanup != nil {
defer cleanup() defer cleanup()
} }
// Close writer // Close writer
writer.Close() writer.Close()
// Hijack the request // Hijack the request
r.Body = ioutil.NopCloser(&body) r.Body = ioutil.NopCloser(&body)
r.ContentLength = int64(body.Len()) r.ContentLength = int64(body.Len())
r.Header.Set("Content-Type", writer.FormDataContentType()) r.Header.Set("Content-Type", writer.FormDataContentType())
// Proxy the request // Proxy the request
proxyRequest(w, r) h.ServeHTTP(w, r)
})
} }
package main package upload
import ( import (
"../helper"
"../proxy"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
...@@ -14,19 +16,17 @@ import ( ...@@ -14,19 +16,17 @@ import (
"testing" "testing"
) )
var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
func TestUploadTempPathRequirement(t *testing.T) { func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{ request := &http.Request{}
authorizationResponse: authorizationResponse{ handleFileUploads(nilHandler).ServeHTTP(response, request)
TempPath: "", helper.AssertResponseCode(t, response, 500)
},
}
handleFileUploads(response, &request)
assertResponseCode(t, response, 500)
} }
func TestUploadHandlerForwardingRawData(t *testing.T) { func TestUploadHandlerForwardingRawData(t *testing.T) {
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" { if r.Method != "PATCH" {
t.Fatal("Expected PATCH request") t.Fatal("Expected PATCH request")
} }
...@@ -40,6 +40,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -40,6 +40,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
w.WriteHeader(202) w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE") fmt.Fprint(w, "RESPONSE")
}) })
defer ts.Close()
httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST")) httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST"))
if err != nil { if err != nil {
...@@ -53,15 +54,11 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -53,15 +54,11 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{
Request: httpRequest, httpRequest.Header.Set(tempPathHeader, tempPath)
u: newUpstream(ts.URL, nil),
authorizationResponse: authorizationResponse{ handleFileUploads(proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil)).ServeHTTP(response, httpRequest)
TempPath: tempPath, helper.AssertResponseCode(t, response, 202)
},
}
handleFileUploads(response, &request)
assertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { if response.Body.String() != "RESPONSE" {
t.Fatal("Expected RESPONSE in response body") t.Fatal("Expected RESPONSE in response body")
} }
...@@ -76,7 +73,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -76,7 +73,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
} }
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" { if r.Method != "PUT" {
t.Fatal("Expected PUT request") t.Fatal("Expected PUT request")
} }
...@@ -131,17 +128,11 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -131,17 +128,11 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Body = ioutil.NopCloser(&buffer) httpRequest.Body = ioutil.NopCloser(&buffer)
httpRequest.ContentLength = int64(buffer.Len()) httpRequest.ContentLength = int64(buffer.Len())
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
httpRequest.Header.Set(tempPathHeader, tempPath)
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{
Request: httpRequest, handleFileUploads(proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil)).ServeHTTP(response, httpRequest)
u: newUpstream(ts.URL, nil), helper.AssertResponseCode(t, response, 202)
authorizationResponse: authorizationResponse{
TempPath: tempPath,
},
}
handleFileUploads(response, &request)
assertResponseCode(t, response, 202)
if _, err := os.Stat(filePath); !os.IsNotExist(err) { if _, err := os.Stat(filePath); !os.IsNotExist(err) {
t.Fatal("expected the file to be deleted") t.Fatal("expected the file to be deleted")
......
package main package upstream
import ( import (
"../helper"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
...@@ -13,9 +14,9 @@ func TestDevelopmentModeEnabled(t *testing.T) { ...@@ -13,9 +14,9 @@ func TestDevelopmentModeEnabled(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) { NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, &gitRequest{Request: r}) })).ServeHTTP(w, r)
if !executed { if !executed {
t.Error("The handler should get executed") t.Error("The handler should get executed")
} }
...@@ -28,11 +29,11 @@ func TestDevelopmentModeDisabled(t *testing.T) { ...@@ -28,11 +29,11 @@ func TestDevelopmentModeDisabled(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) { NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, &gitRequest{Request: r}) })).ServeHTTP(w, r)
if executed { if executed {
t.Error("The handler should not get executed") t.Error("The handler should not get executed")
} }
assertResponseCode(t, w, 404) helper.AssertResponseCode(t, w, 404)
} }
package main package upstream
import ( import (
"../helper"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
) )
func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func contentEncodingHandler(h http.Handler) http.Handler {
return func(w http.ResponseWriter, r *gitRequest) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body io.ReadCloser var body io.ReadCloser
var err error var err error
...@@ -24,7 +25,7 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc { ...@@ -24,7 +25,7 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
} }
if err != nil { if err != nil {
fail500(w, fmt.Errorf("contentEncodingHandler: %v", err)) helper.Fail500(w, fmt.Errorf("contentEncodingHandler: %v", err))
return return
} }
defer body.Close() defer body.Close()
...@@ -32,6 +33,6 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc { ...@@ -32,6 +33,6 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
r.Body = body r.Body = body
r.Header.Del("Content-Encoding") r.Header.Del("Content-Encoding")
handleFunc(w, r) h.ServeHTTP(w, r)
} })
} }
package main package upstream
import ( import (
"../helper"
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
...@@ -27,17 +28,16 @@ func TestGzipEncoding(t *testing.T) { ...@@ -27,17 +28,16 @@ func TestGzipEncoding(t *testing.T) {
} }
req.Header.Set("Content-Encoding", "gzip") req.Header.Set("Content-Encoding", "gzip")
request := gitRequest{Request: req} contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
if _, ok := r.Body.(*gzip.Reader); !ok { if _, ok := r.Body.(*gzip.Reader); !ok {
t.Fatal("Expected gzip reader for body, but it's:", reflect.TypeOf(r.Body)) t.Fatal("Expected gzip reader for body, but it's:", reflect.TypeOf(r.Body))
} }
if r.Header.Get("Content-Encoding") != "" { if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted") t.Fatal("Content-Encoding should be deleted")
} }
})(resp, &request) })).ServeHTTP(resp, req)
assertResponseCode(t, resp, 200) helper.AssertResponseCode(t, resp, 200)
} }
func TestNoEncoding(t *testing.T) { func TestNoEncoding(t *testing.T) {
...@@ -52,17 +52,16 @@ func TestNoEncoding(t *testing.T) { ...@@ -52,17 +52,16 @@ func TestNoEncoding(t *testing.T) {
} }
req.Header.Set("Content-Encoding", "") req.Header.Set("Content-Encoding", "")
request := gitRequest{Request: req} contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.Body != body { if r.Body != body {
t.Fatal("Expected the same body") t.Fatal("Expected the same body")
} }
if r.Header.Get("Content-Encoding") != "" { if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted") t.Fatal("Content-Encoding should be deleted")
} }
})(resp, &request) })).ServeHTTP(resp, req)
assertResponseCode(t, resp, 200) helper.AssertResponseCode(t, resp, 200)
} }
func TestInvalidEncoding(t *testing.T) { func TestInvalidEncoding(t *testing.T) {
...@@ -74,10 +73,9 @@ func TestInvalidEncoding(t *testing.T) { ...@@ -74,10 +73,9 @@ func TestInvalidEncoding(t *testing.T) {
} }
req.Header.Set("Content-Encoding", "application/unknown") req.Header.Set("Content-Encoding", "application/unknown")
request := gitRequest{Request: req} contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
t.Fatal("it shouldn't be executed") t.Fatal("it shouldn't be executed")
})(resp, &request) })).ServeHTTP(resp, req)
assertResponseCode(t, resp, 500) helper.AssertResponseCode(t, resp, 500)
} }
package upstream
import "net/http"
func NotFoundUnless(pass bool, handler http.Handler) http.Handler {
if pass {
return handler
} else {
return http.HandlerFunc(http.NotFound)
}
}
package upstream
import (
apipkg "../api"
"../git"
"../lfs"
proxypkg "../proxy"
"../staticpages"
"../upload"
"net/http"
"regexp"
)
type route struct {
method string
regex *regexp.Regexp
handler http.Handler
}
const projectPattern = `^/[^/]+/[^/]+/`
const gitProjectPattern = `^/[^/]+/[^/]+\.git/`
const apiPattern = `^/api/`
// A project ID in an API request is either a number or two strings 'namespace/project'
const projectsAPIPattern = `^/api/v3/projects/((\d+)|([^/]+/[^/]+))/`
const ciAPIPattern = `^/ci/api/`
// Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
func (u *Upstream) configureRoutes() {
api := apipkg.NewAPI(
u.Backend,
u.Version,
u.RoundTripper,
)
static := &staticpages.Static{u.DocumentRoot}
proxy := proxypkg.NewProxy(
u.Backend,
u.Version,
u.RoundTripper,
)
u.Routes = []route{
// Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(api)},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(api, proxy)},
// Repository Archive
route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), git.GetArchive(api)},
// Repository Archive API
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(api)},
// CI Artifacts API
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(api, proxy))},
// Explicitly proxy API requests
route{"", regexp.MustCompile(apiPattern), proxy},
route{"", regexp.MustCompile(ciAPIPattern), proxy},
// Serve assets
route{"", regexp.MustCompile(`^/assets/`),
static.ServeExisting(u.URLPrefix, staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode,
proxy,
),
),
},
// For legacy reasons, user uploads are stored under the document root.
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through static.ServeExisting.
route{"", regexp.MustCompile(`^/uploads/`), static.ErrorPages(u.DevelopmentMode, proxy)},
// Serve static files or forward the requests
route{"", nil,
static.ServeExisting(u.URLPrefix, staticpages.CacheDisabled,
static.DeployPage(
static.ErrorPages(u.DevelopmentMode,
proxy,
),
),
),
},
}
}
/*
The upstream type implements http.Handler.
In this file we handle request routing and interaction with the authBackend.
*/
package upstream
import (
"../badgateway"
"../helper"
"../urlprefix"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
var DefaultBackend = helper.URLMustParse("http://localhost:8080")
type Upstream struct {
Backend *url.URL
Version string
DocumentRoot string
DevelopmentMode bool
URLPrefix urlprefix.Prefix
Routes []route
RoundTripper *badgateway.RoundTripper
}
func NewUpstream(backend *url.URL, socket string, version string, documentRoot string, developmentMode bool, proxyHeadersTimeout time.Duration) *Upstream {
up := Upstream{
Backend: backend,
Version: version,
DocumentRoot: documentRoot,
DevelopmentMode: developmentMode,
RoundTripper: badgateway.NewRoundTripper(socket, proxyHeadersTimeout),
}
if backend == nil {
up.Backend = DefaultBackend
}
up.configureRoutes()
up.configureURLPrefix()
return &up
}
func (u *Upstream) configureURLPrefix() {
relativeURLRoot := u.Backend.Path
if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/"
}
u.URLPrefix = urlprefix.Prefix(relativeURLRoot)
}
func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
w := helper.NewLoggingResponseWriter(ow)
defer w.Log(r)
// Drop WebSocket connection and CONNECT method
if r.RequestURI == "*" {
helper.HTTPError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
helper.HTTPError(&w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
// Check URL Root
URIPath := urlprefix.CleanURIPath(r.URL.Path)
prefix := u.URLPrefix
if !prefix.Match(URIPath) {
helper.HTTPError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
// Look for a matching Git service
var ro route
foundService := false
for _, ro = range u.Routes {
if ro.method != "" && r.Method != ro.method {
continue
}
if ro.regex == nil || ro.regex.MatchString(prefix.Strip(URIPath)) {
foundService = true
break
}
}
if !foundService {
// The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found.
helper.HTTPError(&w, r, "Forbidden", http.StatusForbidden)
return
}
ro.handler.ServeHTTP(&w, r)
}
package urlprefix
import (
"path"
"strings"
)
type Prefix string
func (p Prefix) Strip(path string) string {
return CleanURIPath(strings.TrimPrefix(path, string(p)))
}
func (p Prefix) Match(path string) bool {
pre := string(p)
return strings.HasPrefix(path, pre) || path+"/" == pre
}
// Borrowed from: net/http/server.go
// Return the canonical path for p, eliminating . and .. elements.
func CleanURIPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
/*
In this file we handle git lfs objects downloads and uploads
*/
package main
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
)
func lfsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.StoreLFSPath == "" {
fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty"))
return
}
if r.LfsOid == "" {
fail500(w, errors.New("lfsAuthorizeHandler: LfsOid empty"))
return
}
if err := os.MkdirAll(r.StoreLFSPath, 0700); err != nil {
fail500(w, fmt.Errorf("lfsAuthorizeHandler: mkdir StoreLFSPath: %v", err))
return
}
handleFunc(w, r)
}, "/authorize")
}
func handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) {
file, err := ioutil.TempFile(r.StoreLFSPath, r.LfsOid)
if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
return
}
defer os.Remove(file.Name())
defer file.Close()
hash := sha256.New()
hw := io.MultiWriter(hash, file)
written, err := io.Copy(hw, r.Body)
if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: write tempfile: %v", err))
return
}
file.Close()
if written != r.LfsSize {
fail500(w, fmt.Errorf("handleStoreLfsObject: expected size %d, wrote %d", r.LfsSize, written))
return
}
shaStr := hex.EncodeToString(hash.Sum(nil))
if shaStr != r.LfsOid {
fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", r.LfsOid, shaStr))
return
}
// Inject header and body
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name()))
r.Body = ioutil.NopCloser(&bytes.Buffer{})
r.ContentLength = 0
// And proxy the request
proxyRequest(w, r)
}
...@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type. ...@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type.
package main package main
import ( import (
"./internal/upstream"
"flag" "flag"
"fmt" "fmt"
"log" "log"
...@@ -21,7 +22,6 @@ import ( ...@@ -21,7 +22,6 @@ import (
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"regexp"
"syscall" "syscall"
"time" "time"
) )
...@@ -33,95 +33,13 @@ var printVersion = flag.Bool("version", false, "Print version and exit") ...@@ -33,95 +33,13 @@ var printVersion = flag.Bool("version", false, "Print version and exit")
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server") var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)") var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022") var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022")
var authBackend = flag.String("authBackend", "http://localhost:8080", "Authentication/authorization backend") var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content") var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request") var proxyHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request")
var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app") var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app")
type httpRoute struct {
method string
regex *regexp.Regexp
handleFunc serviceHandleFunc
}
const projectPattern = `^/[^/]+/[^/]+/`
const gitProjectPattern = `^/[^/]+/[^/]+\.git/`
const apiPattern = `^/api/`
// A project ID in an API request is either a number or two strings 'namespace/project'
const projectsAPIPattern = `^/api/v3/projects/((\d+)|([^/]+/[^/]+))/`
const ciAPIPattern = `^/ci/api/`
// Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
var httpRoutes = [...]httpRoute{
// Git Clone
httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)},
// Repository Archive
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
// Repository Archive API
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
// CI Artifacts API
httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))},
// Explicitly proxy API requests
httpRoute{"", regexp.MustCompile(apiPattern), proxyRequest},
httpRoute{"", regexp.MustCompile(ciAPIPattern), proxyRequest},
// Serve assets
httpRoute{"", regexp.MustCompile(`^/assets/`),
handleServeFile(documentRoot, CacheExpireMax,
handleDevelopmentMode(developmentMode,
handleDeployPage(documentRoot,
handleRailsError(documentRoot, developmentMode,
proxyRequest,
),
),
),
),
},
// For legacy reasons, user uploads are stored under the document root.
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through handleServeFile.
httpRoute{"", regexp.MustCompile(`^/uploads/`),
handleRailsError(documentRoot, developmentMode,
proxyRequest,
),
},
// Serve static files or forward the requests
httpRoute{"", nil,
handleServeFile(documentRoot, CacheDisabled,
handleDeployPage(documentRoot,
handleRailsError(documentRoot, developmentMode,
proxyRequest,
),
),
),
},
}
func main() { func main() {
flag.Usage = func() { flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
...@@ -153,23 +71,6 @@ func main() { ...@@ -153,23 +71,6 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
// Create Proxy Transport
authTransport := http.DefaultTransport
if *authSocket != "" {
dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
authTransport = &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", *authSocket)
},
ResponseHeaderTimeout: *responseHeadersTimeout,
}
}
proxyTransport := &proxyRoundTripper{transport: authTransport}
// The profiler will only be activated by HTTP requests. HTTP // The profiler will only be activated by HTTP requests. HTTP
// requests can only reach the profiler if we start a listener. So by // requests can only reach the profiler if we start a listener. So by
// having no profiler HTTP listener by default, the profiler is // having no profiler HTTP listener by default, the profiler is
...@@ -180,6 +81,14 @@ func main() { ...@@ -180,6 +81,14 @@ func main() {
}() }()
} }
upstream := newUpstream(*authBackend, proxyTransport) up := upstream.NewUpstream(
log.Fatal(http.Serve(listener, upstream)) *authBackend,
*authSocket,
Version,
*documentRoot,
*developmentMode,
*proxyHeadersTimeout,
)
log.Fatal(http.Serve(listener, up))
} }
package main package main
import ( import (
"./internal/api"
"./internal/helper"
"./internal/upstream"
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
...@@ -18,9 +22,9 @@ import ( ...@@ -18,9 +22,9 @@ import (
"time" "time"
) )
const scratchDir = "test/scratch" const scratchDir = "testdata/scratch"
const testRepoRoot = "test/data" const testRepoRoot = "testdata/data"
const testDocumentRoot = "test/public" const testDocumentRoot = "testdata/public"
const testRepo = "group/test.git" const testRepo = "group/test.git"
const testProject = "group/test" const testProject = "group/test"
...@@ -325,7 +329,7 @@ func TestAllowedStaticFile(t *testing.T) { ...@@ -325,7 +329,7 @@ func TestAllowedStaticFile(t *testing.T) {
} }
proxied := false proxied := false
ts := testServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
proxied = true proxied = true
w.WriteHeader(404) w.WriteHeader(404)
}) })
...@@ -339,21 +343,21 @@ func TestAllowedStaticFile(t *testing.T) { ...@@ -339,21 +343,21 @@ func TestAllowedStaticFile(t *testing.T) {
} { } {
resp, err := http.Get(ws.URL + resource) resp, err := http.Get(ws.URL + resource)
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
} }
defer resp.Body.Close() defer resp.Body.Close()
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if _, err := io.Copy(buf, resp.Body); err != nil { if _, err := io.Copy(buf, resp.Body); err != nil {
t.Fatal(err) t.Error(err)
} }
if buf.String() != content { if buf.String() != content {
t.Fatalf("GET %q: Expected %q, got %q", resource, content, buf.String()) t.Errorf("GET %q: Expected %q, got %q", resource, content, buf.String())
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
t.Fatalf("GET %q: expected 200, got %d", resource, resp.StatusCode) t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
} }
if proxied { if proxied {
t.Fatalf("GET %q: should not have made it to backend", resource) t.Errorf("GET %q: should not have made it to backend", resource)
} }
} }
} }
...@@ -365,7 +369,7 @@ func TestAllowedPublicUploadsFile(t *testing.T) { ...@@ -365,7 +369,7 @@ func TestAllowedPublicUploadsFile(t *testing.T) {
} }
proxied := false proxied := false
ts := testServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
proxied = true proxied = true
w.Header().Add("X-Sendfile", *documentRoot+r.URL.Path) w.Header().Add("X-Sendfile", *documentRoot+r.URL.Path)
w.WriteHeader(200) w.WriteHeader(200)
...@@ -406,7 +410,7 @@ func TestDeniedPublicUploadsFile(t *testing.T) { ...@@ -406,7 +410,7 @@ func TestDeniedPublicUploadsFile(t *testing.T) {
} }
proxied := false proxied := false
ts := testServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) { ts := helper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
proxied = true proxied = true
w.WriteHeader(404) w.WriteHeader(404)
}) })
...@@ -439,6 +443,50 @@ func TestDeniedPublicUploadsFile(t *testing.T) { ...@@ -439,6 +443,50 @@ func TestDeniedPublicUploadsFile(t *testing.T) {
} }
} }
func TestArtifactsUpload(t *testing.T) {
reqBody := &bytes.Buffer{}
writer := multipart.NewWriter(reqBody)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART")
writer.Close()
ts := helper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/authorize") {
if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil {
t.Fatal(err)
}
return
}
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
if len(r.MultipartForm.Value) != 2 { // 1 file name, 1 file path
t.Error("Expected to receive exactly 2 values")
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
w.WriteHeader(200)
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := `/ci/api/v1/builds/123/artifacts`
resp, err := http.Post(ws.URL+resource, writer.FormDataContentType(), reqBody)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
func setupStaticFile(fpath, content string) error { func setupStaticFile(fpath, content string) error {
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err != nil { if err != nil {
...@@ -476,26 +524,8 @@ func newBranch() string { ...@@ -476,26 +524,8 @@ func newBranch() string {
return fmt.Sprintf("branch-%d", time.Now().UnixNano()) return fmt.Sprintf("branch-%d", time.Now().UnixNano())
} }
func testServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if url != nil && !url.MatchString(r.URL.Path) {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(404)
return
}
if version := r.Header.Get("Gitlab-Workhorse"); version == "" {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(403)
return
}
handler(w, r)
}))
}
func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Server { func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Server {
return testServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) { return helper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) {
// Write pure string // Write pure string
if data, ok := body.(string); ok { if data, ok := body.(string); ok {
log.Println("UPSTREAM", r.Method, r.URL, code) log.Println("UPSTREAM", r.Method, r.URL, code)
...@@ -520,7 +550,15 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se ...@@ -520,7 +550,15 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
} }
func startWorkhorseServer(authBackend string) *httptest.Server { func startWorkhorseServer(authBackend string) *httptest.Server {
return httptest.NewServer(newUpstream(authBackend, nil)) u := upstream.NewUpstream(
helper.URLMustParse(authBackend),
"",
"123",
testDocumentRoot,
false,
0,
)
return httptest.NewServer(u)
} }
func runOrFail(t *testing.T, cmd *exec.Cmd) { func runOrFail(t *testing.T, cmd *exec.Cmd) {
...@@ -532,7 +570,7 @@ func runOrFail(t *testing.T, cmd *exec.Cmd) { ...@@ -532,7 +570,7 @@ func runOrFail(t *testing.T, cmd *exec.Cmd) {
} }
func gitOkBody(t *testing.T) interface{} { func gitOkBody(t *testing.T) interface{} {
return &authorizationResponse{ return &api.Response{
GL_ID: "user-123", GL_ID: "user-123",
RepoPath: repoPath(t), RepoPath: repoPath(t),
} }
...@@ -545,7 +583,7 @@ func archiveOkBody(t *testing.T, archiveName string) interface{} { ...@@ -545,7 +583,7 @@ func archiveOkBody(t *testing.T, archiveName string) interface{} {
} }
archivePath := path.Join(cwd, cacheDir, archiveName) archivePath := path.Join(cwd, cacheDir, archiveName)
return &authorizationResponse{ return &api.Response{
RepoPath: repoPath(t), RepoPath: repoPath(t),
ArchivePath: archivePath, ArchivePath: archivePath,
CommitId: "c7fbe50c7c7419d9701eebe64b1fdacc3df5b9dd", CommitId: "c7fbe50c7c7419d9701eebe64b1fdacc3df5b9dd",
......
package main package main
import ( import (
"./internal/badgateway"
"./internal/helper"
"./internal/proxy"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
...@@ -12,8 +15,12 @@ import ( ...@@ -12,8 +15,12 @@ import (
"time" "time"
) )
func newProxy(url string, rt *badgateway.RoundTripper) *proxy.Proxy {
return proxy.NewProxy(helper.URLMustParse(url), "123", rt)
}
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
t.Fatal("Expected POST request") t.Fatal("Expected POST request")
} }
...@@ -39,15 +46,10 @@ func TestProxyRequest(t *testing.T) { ...@@ -39,15 +46,10 @@ func TestProxyRequest(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, nil),
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) newProxy(ts.URL, nil).ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 202) helper.AssertResponseCode(t, w, 202)
assertResponseBody(t, w, "RESPONSE") helper.AssertResponseBody(t, w, "RESPONSE")
if w.Header().Get("Custom-Response-Header") != "test" { if w.Header().Get("Custom-Response-Header") != "test" {
t.Fatal("Expected custom response header") t.Fatal("Expected custom response header")
...@@ -61,23 +63,14 @@ func TestProxyError(t *testing.T) { ...@@ -61,23 +63,14 @@ func TestProxyError(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
transport := proxyRoundTripper{
transport: http.DefaultTransport,
}
request := gitRequest{
Request: httpRequest,
u: newUpstream("http://localhost:655575/", &transport),
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) newProxy("http://localhost:655575/", nil).ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 502) helper.AssertResponseCode(t, w, 502)
assertResponseBody(t, w, "dial tcp: invalid port 655575") helper.AssertResponseBody(t, w, "dial tcp: invalid port 655575")
} }
func TestProxyReadTimeout(t *testing.T) { func TestProxyReadTimeout(t *testing.T) {
ts := testServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Minute) time.Sleep(time.Minute)
}) })
...@@ -86,8 +79,8 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -86,8 +79,8 @@ func TestProxyReadTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
transport := &proxyRoundTripper{ rt := &badgateway.RoundTripper{
transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
...@@ -98,19 +91,15 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -98,19 +91,15 @@ func TestProxyReadTimeout(t *testing.T) {
}, },
} }
request := gitRequest{ p := newProxy(ts.URL, rt)
Request: httpRequest,
u: newUpstream(ts.URL, transport),
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) p.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 502) helper.AssertResponseCode(t, w, 502)
assertResponseBody(t, w, "net/http: timeout awaiting response headers") helper.AssertResponseBody(t, w, "net/http: timeout awaiting response headers")
} }
func TestProxyHandlerTimeout(t *testing.T) { func TestProxyHandlerTimeout(t *testing.T) {
ts := testServerWithHandler(nil, ts := helper.TestServerWithHandler(nil,
http.TimeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.TimeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second) time.Sleep(time.Second)
}), time.Millisecond, "Request took too long").ServeHTTP, }), time.Millisecond, "Request took too long").ServeHTTP,
...@@ -121,17 +110,8 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -121,17 +110,8 @@ func TestProxyHandlerTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
transport := &proxyRoundTripper{
transport: http.DefaultTransport,
}
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, transport),
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) newProxy(ts.URL, nil).ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 503) helper.AssertResponseCode(t, w, 503)
assertResponseBody(t, w, "Request took too long") helper.AssertResponseBody(t, w, "Request took too long")
} }
/*
The upstream type implements http.Handler.
In this file we handle request routing and interaction with the authBackend.
*/
package main
import (
"fmt"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
)
type serviceHandleFunc func(w http.ResponseWriter, r *gitRequest)
type upstream struct {
httpClient *http.Client
httpProxy *httputil.ReverseProxy
authBackend string
relativeURLRoot string
}
type authorizationResponse struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull'
GL_ID string
// RepoPath is the full path on disk to the Git repository the request is
// about
RepoPath string
// ArchivePath is the full path where we should find/create a cached copy
// of a requested archive
ArchivePath string
// ArchivePrefix is used to put extracted archive contents in a
// subdirectory
ArchivePrefix string
// CommitId is used do prevent race conditions between the 'time of check'
// in the GitLab Rails app and the 'time of use' in gitlab-workhorse.
CommitId string
// StoreLFSPath is provided by the GitLab Rails application
// to mark where the tmp file should be placed
StoreLFSPath string
// LFS object id
LfsOid string
// LFS object size
LfsSize int64
// TmpPath is the path where we should store temporary files
// This is set by authorization middleware
TempPath string
}
// A gitRequest is an *http.Request decorated with attributes returned by the
// GitLab Rails application.
type gitRequest struct {
*http.Request
authorizationResponse
u *upstream
// This field contains the URL.Path stripped from RelativeUrlRoot
relativeURIPath string
}
func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream {
gitlabURL, err := url.Parse(authBackend)
if err != nil {
log.Fatalln(err)
}
relativeURLRoot := gitlabURL.Path
if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/"
}
// If the relative URL is '/foobar' and we tell httputil.ReverseProxy to proxy
// to 'http://example.com/foobar' then we get a redirect loop, so we clear the
// Path field here.
gitlabURL.Path = ""
up := &upstream{
authBackend: authBackend,
httpClient: &http.Client{Transport: authTransport},
httpProxy: httputil.NewSingleHostReverseProxy(gitlabURL),
relativeURLRoot: relativeURLRoot,
}
up.httpProxy.Transport = authTransport
return up
}
func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
var g httpRoute
w := newLoggingResponseWriter(ow)
defer w.Log(r)
// Drop WebSocket connection and CONNECT method
if r.RequestURI == "*" {
httpError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
httpError(&w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
// Check URL Root
URIPath := cleanURIPath(r.URL.Path)
if !strings.HasPrefix(URIPath, u.relativeURLRoot) && URIPath+"/" != u.relativeURLRoot {
httpError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
// Strip prefix and add "/"
// To match against non-relative URL
// Making it simpler for our matcher
relativeURIPath := cleanURIPath(strings.TrimPrefix(URIPath, u.relativeURLRoot))
// Look for a matching Git service
foundService := false
for _, g = range httpRoutes {
if g.method != "" && r.Method != g.method {
continue
}
if g.regex == nil || g.regex.MatchString(relativeURIPath) {
foundService = true
break
}
}
if !foundService {
// The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found.
httpError(&w, r, "Forbidden", http.StatusForbidden)
return
}
request := gitRequest{
Request: r,
relativeURIPath: relativeURIPath,
u: u,
}
g.handleFunc(&w, &request)
}
package main
import (
"flag"
"net/url"
)
type urlFlag struct {
*url.URL
}
func (u *urlFlag) Set(s string) error {
myURL, err := url.Parse(s)
if err != nil {
return err
}
u.URL = myURL
return nil
}
func URLFlag(name string, value *url.URL, usage string) **url.URL {
f := &urlFlag{value}
flag.CommandLine.Var(f, name, usage)
return &f.URL
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment