Commit 1a7009e4 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'passthrough' into 'master'

Make GitLab Workhorse passthrough to Unicorn

It works :)


See merge request !12
parents 8f551c2e 9c051cb0
...@@ -8,12 +8,17 @@ install: gitlab-workhorse ...@@ -8,12 +8,17 @@ install: gitlab-workhorse
install gitlab-workhorse ${PREFIX}/bin/ install gitlab-workhorse ${PREFIX}/bin/
.PHONY: test .PHONY: test
test: test/data/test.git clean-workhorse gitlab-workhorse test: test/data/group/test.git clean-workhorse gitlab-workhorse
go fmt | awk '{ print "Please run go fmt"; exit 1 }' go fmt | awk '{ print "Please run go fmt"; exit 1 }'
go test go test
test/data/test.git: test/data coverage: test/data/group/test.git
git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git test/data/test.git go test -cover -coverprofile=test.coverage
go tool cover -html=test.coverage -o coverage.html
rm -f test.coverage
test/data/group/test.git: test/data
git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git test/data/group/test.git
test/data: test/data:
mkdir -p test/data mkdir -p test/data
......
...@@ -2,13 +2,55 @@ package main ...@@ -2,13 +2,55 @@ package main
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings" "strings"
) )
func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) {
url := u.authBackend + r.URL.RequestURI() + suffix
authReq, err := http.NewRequest(r.Method, url, body)
if err != nil {
return nil, err
}
// Forward all headers from our client to the auth backend. This includes
// HTTP Basic authentication credentials (the 'Authorization' header).
for k, v := range r.Header {
authReq.Header[k] = v
}
// Clean some headers when issuing a new request without body
if body == nil {
authReq.Header.Del("Content-Type")
authReq.Header.Del("Content-Encoding")
authReq.Header.Del("Content-Length")
authReq.Header.Del("Content-Disposition")
authReq.Header.Del("Accept-Encoding")
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
authReq.Header.Del("Transfer-Encoding")
authReq.Header.Del("Connection")
authReq.Header.Del("Keep-Alive")
authReq.Header.Del("Proxy-Authenticate")
authReq.Header.Del("Proxy-Authorization")
authReq.Header.Del("Te")
authReq.Header.Del("Trailers")
authReq.Header.Del("Upgrade")
}
// Also forward the Host header, which is excluded from the Header map by the http libary.
// This allows the Host header received by the backend to be consistent with other
// requests not going through gitlab-workhorse.
authReq.Host = r.Host
// Set a custom header for the request. This can be used in some
// configurations (Passenger) to solve auth request routing problems.
authReq.Header.Set("Gitlab-Workhorse", Version)
return authReq, nil
}
func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc { func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *gitRequest) {
authReq, err := r.u.newUpstreamRequest(r.Request, nil, suffix) authReq, err := r.u.newUpstreamRequest(r.Request, nil, suffix)
...@@ -65,19 +107,3 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan ...@@ -65,19 +107,3 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan
handleFunc(w, r) handleFunc(w, r)
} }
} }
func repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.RepoPath == "" {
fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty"))
return
}
if !looksLikeRepo(r.RepoPath) {
http.Error(w, "Not Found", 404)
return
}
handleFunc(w, r)
}, "")
}
package main
import (
"io/ioutil"
"net/http"
"path/filepath"
)
func handleDeployPage(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
deployPage := filepath.Join(*documentRoot, "index.html")
data, err := ioutil.ReadFile(deployPage)
if err != nil {
handler(w, r)
return
}
setNoCacheHeaders(w.Header())
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write(data)
}
}
package main
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
func TestIfNoDeployPageExist(t *testing.T) {
dir, err := ioutil.TempDir("", "deploy")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
w := httptest.NewRecorder()
executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) {
executed = true
})(w, nil)
if !executed {
t.Error("The handler should get executed")
}
}
func TestIfDeployPageExist(t *testing.T) {
dir, err := ioutil.TempDir("", "deploy")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
deployPage := "DEPLOY"
ioutil.WriteFile(filepath.Join(dir, "index.html"), []byte(deployPage), 0600)
w := httptest.NewRecorder()
executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) {
executed = true
})(w, nil)
if executed {
t.Error("The handler should not get executed")
}
w.Flush()
assertResponseCode(t, w, 200)
assertResponseBody(t, w, deployPage)
}
package main
import "net/http"
func handleDevelopmentMode(developmentMode *bool, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
if !*developmentMode {
http.NotFound(w, r.Request)
return
}
handler(w, r)
}
}
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestDevelopmentModeEnabled(t *testing.T) {
developmentMode := true
r, _ := http.NewRequest("GET", "/something", nil)
w := httptest.NewRecorder()
executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) {
executed = true
})(w, &gitRequest{Request: r})
if !executed {
t.Error("The handler should get executed")
}
}
func TestDevelopmentModeDisabled(t *testing.T) {
developmentMode := false
r, _ := http.NewRequest("GET", "/something", nil)
w := httptest.NewRecorder()
executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) {
executed = true
})(w, &gitRequest{Request: r})
if executed {
t.Error("The handler should not get executed")
}
assertResponseCode(t, w, 404)
}
package main
import (
"fmt"
"io/ioutil"
"log"
"net/http"
"path/filepath"
)
type errorPageResponseWriter struct {
rw http.ResponseWriter
status int
hijacked bool
path *string
}
func (s *errorPageResponseWriter) Header() http.Header {
return s.rw.Header()
}
func (s *errorPageResponseWriter) Write(data []byte) (n int, err error) {
if s.status == 0 {
s.WriteHeader(http.StatusOK)
}
if s.hijacked {
return 0, nil
}
return s.rw.Write(data)
}
func (s *errorPageResponseWriter) WriteHeader(status int) {
if s.status != 0 {
return
}
s.status = status
if 400 <= s.status && s.status <= 599 {
errorPageFile := filepath.Join(*s.path, fmt.Sprintf("%d.html", s.status))
// check if custom error page exists, serve this page instead
if data, err := ioutil.ReadFile(errorPageFile); err == nil {
s.hijacked = true
log.Printf("ErrorPage: serving predefined error page: %d", s.status)
setNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", "text/html; charset=utf-8")
s.rw.WriteHeader(s.status)
s.rw.Write(data)
return
}
}
s.rw.WriteHeader(status)
}
func (s *errorPageResponseWriter) Flush() {
s.WriteHeader(http.StatusOK)
}
func handleRailsError(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
rw := errorPageResponseWriter{
rw: w,
path: documentRoot,
}
defer rw.Flush()
handler(&rw, r)
}
}
package main
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
func TestIfErrorPageIsPresented(t *testing.T) {
dir, err := ioutil.TempDir("", "error_page")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
errorPage := "ERROR"
ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600)
w := httptest.NewRecorder()
handleRailsError(&dir, func(w http.ResponseWriter, r *gitRequest) {
w.WriteHeader(404)
fmt.Fprint(w, "Not Found")
})(w, nil)
w.Flush()
assertResponseCode(t, w, 404)
assertResponseBody(t, w, errorPage)
}
func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
dir, err := ioutil.TempDir("", "error_page")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
w := httptest.NewRecorder()
errorResponse := "ERROR"
handleRailsError(&dir, func(w http.ResponseWriter, r *gitRequest) {
w.WriteHeader(404)
fmt.Fprint(w, errorResponse)
})(w, nil)
w.Flush()
assertResponseCode(t, w, 404)
assertResponseBody(t, w, errorResponse)
}
...@@ -5,13 +5,43 @@ In this file we handle the Git 'smart HTTP' protocol ...@@ -5,13 +5,43 @@ In this file we handle the Git 'smart HTTP' protocol
package main package main
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"os"
"path"
"path/filepath" "path/filepath"
"strings" "strings"
) )
func looksLikeRepo(p string) bool {
// If /path/to/foo.git/objects exists then let's assume it is a valid Git
// repository.
if _, err := os.Stat(path.Join(p, "objects")); err != nil {
log.Print(err)
return false
}
return true
}
func repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.RepoPath == "" {
fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty"))
return
}
if !looksLikeRepo(r.RepoPath) {
http.Error(w, "Not Found", 404)
return
}
handleFunc(w, r)
}, "")
}
func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) { func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) {
rpc := r.URL.Query().Get("service") rpc := r.URL.Query().Get("service")
if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") { if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") {
......
...@@ -5,23 +5,16 @@ Miscellaneous helpers: logging, errors, subprocesses ...@@ -5,23 +5,16 @@ Miscellaneous helpers: logging, errors, subprocesses
package main package main
import ( import (
"errors"
"fmt" "fmt"
"io"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"net/url"
"os" "os"
"os/exec" "os/exec"
"strings" "path"
"syscall" "syscall"
) )
func fail400(w http.ResponseWriter, err error) {
http.Error(w, "Bad request", 400)
logError(err)
}
func fail500(w http.ResponseWriter, err error) { func fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500) http.Error(w, "Internal server error", 500)
logError(err) logError(err)
...@@ -31,6 +24,15 @@ func logError(err error) { ...@@ -31,6 +24,15 @@ func logError(err error) {
log.Printf("error: %v", err) log.Printf("error: %v", err)
} }
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
// Git subprocess helpers // Git subprocess helpers
func gitCommand(gl_id string, name string, args ...string) *exec.Cmd { func gitCommand(gl_id string, name string, args ...string) *exec.Cmd {
cmd := exec.Command(name, args...) cmd := exec.Command(name, args...)
...@@ -63,20 +65,56 @@ func cleanUpProcessGroup(cmd *exec.Cmd) { ...@@ -63,20 +65,56 @@ func cleanUpProcessGroup(cmd *exec.Cmd) {
cmd.Wait() cmd.Wait()
} }
func forwardResponseToClient(w http.ResponseWriter, r *http.Response) { func setNoCacheHeaders(header http.Header) {
log.Printf("PROXY:%s %q %d", r.Request.Method, r.Request.URL, r.StatusCode) header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
header.Set("Pragma", "no-cache")
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
}
for k, v := range r.Header { func openFile(path string) (file *os.File, fi os.FileInfo, err error) {
w.Header()[k] = v file, err = os.Open(path)
if err != nil {
return
} }
w.WriteHeader(r.StatusCode) defer func() {
io.Copy(w, r.Body) if err != nil {
file.Close()
}
}()
fi, err = file.Stat()
if err != nil {
return
}
// The os.Open can also open directories
if fi.IsDir() {
err = &os.PathError{
Op: "open",
Path: path,
Err: errors.New("path is directory"),
}
return
}
return
} }
func setHttpPostForm(r *http.Request, values url.Values) { // Borrowed from: net/http/server.go
dataBuffer := strings.NewReader(values.Encode()) // Return the canonical path for p, eliminating . and .. elements.
r.Body = ioutil.NopCloser(dataBuffer) func cleanURIPath(p string) string {
r.ContentLength = int64(dataBuffer.Len()) if p == "" {
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
} }
...@@ -10,3 +10,15 @@ func assertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expec ...@@ -10,3 +10,15 @@ func assertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expec
t.Fatalf("for HTTP request expected to get %d, got %d instead", expectedCode, response.Code) t.Fatalf("for HTTP request expected to get %d, got %d instead", expectedCode, response.Code)
} }
} }
func assertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) {
if response.Body.String() != expectedBody {
t.Fatalf("for HTTP request expected to receive %q, got %q instead as body", expectedBody, response.Body.String())
}
}
func assertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) {
if response.Header().Get(header) != expectedValue {
t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header))
}
}
...@@ -5,6 +5,7 @@ In this file we handle git lfs objects downloads and uploads ...@@ -5,6 +5,7 @@ In this file we handle git lfs objects downloads and uploads
package main package main
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
...@@ -67,20 +68,12 @@ func handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) { ...@@ -67,20 +68,12 @@ func handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) {
fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", r.LfsOid, shaStr)) fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", r.LfsOid, shaStr))
return return
} }
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name()))
storeReq, err := r.u.newUpstreamRequest(r.Request, nil, "")
if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: newUpstreamRequest: %v", err))
return
}
storeResponse, err := r.u.httpClient.Do(storeReq) // Inject header and body
if err != nil { r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name()))
fail500(w, fmt.Errorf("handleStoreLfsObject: do %v: %v", storeReq.URL.Path, err)) r.Body = ioutil.NopCloser(&bytes.Buffer{})
return r.ContentLength = 0
}
defer storeResponse.Body.Close()
forwardResponseToClient(w, storeResponse) // And proxy the request
proxyRequest(w, r)
} }
package main
import (
"fmt"
"net/http"
"time"
)
type loggingResponseWriter struct {
rw http.ResponseWriter
status int
written int64
started time.Time
}
func newLoggingResponseWriter(rw http.ResponseWriter) loggingResponseWriter {
return loggingResponseWriter{
rw: rw,
started: time.Now(),
}
}
func (l *loggingResponseWriter) Header() http.Header {
return l.rw.Header()
}
func (l *loggingResponseWriter) Write(data []byte) (n int, err error) {
if l.status == 0 {
l.WriteHeader(http.StatusOK)
}
n, err = l.rw.Write(data)
l.written += int64(n)
return
}
func (l *loggingResponseWriter) WriteHeader(status int) {
if l.status != 0 {
return
}
l.status = status
l.rw.WriteHeader(status)
}
func (l *loggingResponseWriter) Log(r *http.Request) {
duration := time.Since(l.started)
fmt.Printf("%s %s - - [%s] %q %d %d %q %q %f\n",
r.Host, r.RemoteAddr, l.started,
fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto),
l.status, l.written, r.Referer(), r.UserAgent(), duration.Seconds(),
)
}
...@@ -21,20 +21,97 @@ import ( ...@@ -21,20 +21,97 @@ import (
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"regexp"
"syscall" "syscall"
"time" "time"
) )
// Current version of GitLab Workhorse
var Version = "(unknown version)" // Set at build time in the Makefile var Version = "(unknown version)" // Set at build time in the Makefile
var printVersion = flag.Bool("version", false, "Print version and exit")
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022")
var authBackend = flag.String("authBackend", "http://localhost:8080", "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var relativeURLRoot = flag.String("relativeURLRoot", "/", "GitLab relative URL root")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request")
var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app")
type httpRoute struct {
method string
regex *regexp.Regexp
handleFunc serviceHandleFunc
}
const projectPattern = `^/[^/]+/[^/]+/`
const gitProjectPattern = `^/[^/]+/[^/]+\.git/`
const apiPattern = `^/api/`
const projectsAPIPattern = `^/api/v3/projects/[^/]+/`
const ciAPIPattern = `^/ci/api/`
// Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
var httpRoutes = [...]httpRoute{
// Git Clone
httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)},
// Repository Archive
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
// Repository Archive API
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
// CI Artifacts API
httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))},
// Explicitly proxy API requests
httpRoute{"", regexp.MustCompile(apiPattern), proxyRequest},
httpRoute{"", regexp.MustCompile(ciAPIPattern), proxyRequest},
// Serve assets
httpRoute{"", regexp.MustCompile(`^/assets/`),
handleServeFile(documentRoot, CacheExpireMax,
handleDevelopmentMode(developmentMode,
handleDeployPage(documentRoot,
handleRailsError(documentRoot,
proxyRequest,
),
),
),
),
},
// Serve static files or forward the requests
httpRoute{"", nil,
handleServeFile(documentRoot, CacheDisabled,
handleDeployPage(documentRoot,
handleRailsError(documentRoot,
proxyRequest,
),
),
),
},
}
func main() { func main() {
printVersion := flag.Bool("version", false, "Print version and exit")
listenAddr := flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
listenNetwork := flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
listenUmask := flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022")
authBackend := flag.String("authBackend", "http://localhost:8080", "Authentication/authorization backend")
authSocket := flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
pprofListenAddr := flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
flag.Usage = func() { flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
fmt.Fprintf(os.Stderr, "\n %s [OPTIONS]\n\nOptions:\n", os.Args[0]) fmt.Fprintf(os.Stderr, "\n %s [OPTIONS]\n\nOptions:\n", os.Args[0])
...@@ -65,7 +142,8 @@ func main() { ...@@ -65,7 +142,8 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
var authTransport http.RoundTripper // Create Proxy Transport
authTransport := http.DefaultTransport
if *authSocket != "" { if *authSocket != "" {
dialer := &net.Dialer{ dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport // The values below are taken from http.DefaultTransport
...@@ -76,8 +154,10 @@ func main() { ...@@ -76,8 +154,10 @@ func main() {
Dial: func(_, _ string) (net.Conn, error) { Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", *authSocket) return dialer.Dial("unix", *authSocket)
}, },
ResponseHeaderTimeout: *responseHeadersTimeout,
} }
} }
proxyTransport := &proxyRoundTripper{transport: authTransport}
// The profiler will only be activated by HTTP requests. HTTP // The profiler will only be activated by HTTP requests. HTTP
// requests can only reach the profiler if we start a listener. So by // requests can only reach the profiler if we start a listener. So by
...@@ -89,9 +169,7 @@ func main() { ...@@ -89,9 +169,7 @@ func main() {
}() }()
} }
// Because net/http/pprof installs itself in the DefaultServeMux upstream := newUpstream(*authBackend, proxyTransport)
// we create a fresh one for the Git server. upstream.SetRelativeURLRoot(*relativeURLRoot)
serveMux := http.NewServeMux() log.Fatal(http.Serve(listener, upstream))
serveMux.Handle("/", newUpstream(*authBackend, authTransport))
log.Fatal(http.Serve(listener, serveMux))
} }
...@@ -18,8 +18,8 @@ import ( ...@@ -18,8 +18,8 @@ import (
const scratchDir = "test/scratch" const scratchDir = "test/scratch"
const testRepoRoot = "test/data" const testRepoRoot = "test/data"
const testRepo = "test.git" const testRepo = "group/test.git"
const testProject = "test" const testProject = "group/test"
var checkoutDir = path.Join(scratchDir, "test") var checkoutDir = path.Join(scratchDir, "test")
var cacheDir = path.Join(scratchDir, "cache") var cacheDir = path.Join(scratchDir, "cache")
...@@ -276,7 +276,6 @@ func preparePushRepo(t *testing.T) { ...@@ -276,7 +276,6 @@ func preparePushRepo(t *testing.T) {
} }
cloneCmd := exec.Command("git", "clone", path.Join(testRepoRoot, testRepo), checkoutDir) cloneCmd := exec.Command("git", "clone", path.Join(testRepoRoot, testRepo), checkoutDir)
runOrFail(t, cloneCmd) runOrFail(t, cloneCmd)
return
} }
func newBranch() string { func newBranch() string {
...@@ -367,89 +366,3 @@ func repoPath(t *testing.T) string { ...@@ -367,89 +366,3 @@ func repoPath(t *testing.T) string {
} }
return path.Join(cwd, testRepoRoot, testRepo) return path.Join(cwd, testRepoRoot, testRepo)
} }
func TestDeniedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
prepareDownloadDir(t)
deniedXSendfileDownload(t, contentFilename, url)
}
func TestAllowedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
prepareDownloadDir(t)
allowedXSendfileDownload(t, contentFilename, url)
}
func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
contentPath := path.Join(cacheDir, contentFilename)
prepareDownloadDir(t)
// Prepare test server and backend
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("UPSTREAM", r.Method, r.URL)
if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" {
t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType)
}
w.Header().Set("X-Sendfile", contentPath)
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename))
w.Header().Set("Content-Type", fmt.Sprintf(`application/octet-stream`))
w.WriteHeader(200)
}))
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
if err := os.MkdirAll(cacheDir, 0755); err != nil {
t.Fatal(err)
}
contentBytes := []byte("content")
if err := ioutil.WriteFile(contentPath, contentBytes, 0644); err != nil {
t.Fatal(err)
}
downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath))
downloadCmd.Dir = scratchDir
runOrFail(t, downloadCmd)
actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename))
if err != nil {
t.Fatal(err)
}
if bytes.Compare(actual, contentBytes) != 0 {
t.Fatal("Unexpected file contents in download")
}
}
func deniedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
prepareDownloadDir(t)
// Prepare test server and backend
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("UPSTREAM", r.Method, r.URL)
if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" {
t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType)
}
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename))
w.WriteHeader(200)
fmt.Fprint(w, "Denied")
}))
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath))
downloadCmd.Dir = scratchDir
runOrFail(t, downloadCmd)
actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename))
if err != nil {
t.Fatal(err)
}
if bytes.Compare(actual, []byte("Denied")) != 0 {
t.Fatal("Unexpected file contents in download")
}
}
package main package main
import ( import (
"bytes"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
) )
func proxyRequest(w http.ResponseWriter, r *gitRequest) { type proxyRoundTripper struct {
upRequest, err := r.u.newUpstreamRequest(r.Request, r.Body, "") transport http.RoundTripper
}
func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = p.transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error
// is that the Rails app is not responding, in which case users
// and administrators expect to see a 502 error. To show 502s
// instead of 500s we catch the RoundTrip error here and inject a
// 502 response.
if err != nil { if err != nil {
fail500(w, fmt.Errorf("proxyRequest: newUpstreamRequest: %v", err)) logError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
return
res = &http.Response{
StatusCode: http.StatusBadGateway,
Status: http.StatusText(http.StatusBadGateway),
Request: r,
ProtoMajor: r.ProtoMajor,
ProtoMinor: r.ProtoMinor,
Proto: r.Proto,
Header: make(http.Header),
Trailer: make(http.Header),
Body: ioutil.NopCloser(bytes.NewBufferString(err.Error())),
}
res.Header.Set("Content-Type", "text/plain")
err = nil
} }
return
}
upResponse, err := r.u.httpClient.Do(upRequest) func headerClone(h http.Header) http.Header {
if err != nil { h2 := make(http.Header, len(h))
fail500(w, fmt.Errorf("proxyRequest: do %v: %v", upRequest.URL.Path, err)) for k, vv := range h {
return vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
} }
defer upResponse.Body.Close() return h2
}
func proxyRequest(w http.ResponseWriter, r *gitRequest) {
// Clone request
req := *r.Request
req.Header = headerClone(r.Header)
// Set Workhorse version
req.Header.Set("Gitlab-Workhorse", Version)
rw := newSendFileResponseWriter(w, &req)
defer rw.Flush()
forwardResponseToClient(w, upResponse) r.u.httpProxy.ServeHTTP(&rw, &req)
} }
...@@ -4,10 +4,12 @@ import ( ...@@ -4,10 +4,12 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"regexp" "regexp"
"testing" "testing"
"time"
) )
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
...@@ -42,15 +44,94 @@ func TestProxyRequest(t *testing.T) { ...@@ -42,15 +44,94 @@ func TestProxyRequest(t *testing.T) {
u: newUpstream(ts.URL, nil), u: newUpstream(ts.URL, nil),
} }
response := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(response, &request) proxyRequest(w, &request)
assertResponseCode(t, response, 202) assertResponseCode(t, w, 202)
assertResponseBody(t, w, "RESPONSE")
if response.Body.String() != "RESPONSE" { if w.Header().Get("Custom-Response-Header") != "test" {
t.Fatal("Expected RESPONSE in response body:", response.Body.String()) t.Fatal("Expected custom response header")
} }
}
if response.Header().Get("Custom-Response-Header") != "test" { func TestProxyError(t *testing.T) {
t.Fatal("Expected custom response header") httpRequest, err := http.NewRequest("POST", "/url/path", bytes.NewBufferString("REQUEST"))
if err != nil {
t.Fatal(err)
}
httpRequest.Header.Set("Custom-Header", "test")
transport := proxyRoundTripper{
transport: http.DefaultTransport,
}
request := gitRequest{
Request: httpRequest,
u: newUpstream("http://localhost:655575/", &transport),
}
w := httptest.NewRecorder()
proxyRequest(w, &request)
assertResponseCode(t, w, 502)
assertResponseBody(t, w, "dial tcp: invalid port 655575")
}
func TestProxyReadTimeout(t *testing.T) {
ts := testServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Minute)
})
httpRequest, err := http.NewRequest("POST", "http://localhost/url/path", nil)
if err != nil {
t.Fatal(err)
} }
transport := &proxyRoundTripper{
transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: time.Millisecond,
},
}
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, transport),
}
w := httptest.NewRecorder()
proxyRequest(w, &request)
assertResponseCode(t, w, 502)
assertResponseBody(t, w, "net/http: timeout awaiting response headers")
}
func TestProxyHandlerTimeout(t *testing.T) {
ts := testServerWithHandler(nil,
http.TimeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second)
}), time.Millisecond, "Request took too long").ServeHTTP,
)
httpRequest, err := http.NewRequest("POST", "http://localhost/url/path", nil)
if err != nil {
t.Fatal(err)
}
transport := &proxyRoundTripper{
transport: http.DefaultTransport,
}
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, transport),
}
w := httptest.NewRecorder()
proxyRequest(w, &request)
assertResponseCode(t, w, 503)
assertResponseBody(t, w, "Request took too long")
} }
/*
The xSendFile middleware transparently sends static files in HTTP responses
via the X-Sendfile mechanism. All that is needed in the Rails code is the
'send_file' method.
*/
package main
import (
"log"
"net/http"
)
type sendFileResponseWriter struct {
rw http.ResponseWriter
status int
hijacked bool
req *http.Request
}
func newSendFileResponseWriter(rw http.ResponseWriter, req *http.Request) sendFileResponseWriter {
s := sendFileResponseWriter{
rw: rw,
req: req,
}
req.Header.Set("X-Sendfile-Type", "X-Sendfile")
return s
}
func (s *sendFileResponseWriter) Header() http.Header {
return s.rw.Header()
}
func (s *sendFileResponseWriter) Write(data []byte) (n int, err error) {
if s.status == 0 {
s.WriteHeader(http.StatusOK)
}
if s.hijacked {
return
}
return s.rw.Write(data)
}
func (s *sendFileResponseWriter) WriteHeader(status int) {
if s.status != 0 {
return
}
s.status = status
// Check X-Sendfile header
file := s.Header().Get("X-Sendfile")
s.Header().Del("X-Sendfile")
// If file is empty or status is not 200 pass through header
if file == "" || s.status != http.StatusOK {
s.rw.WriteHeader(s.status)
return
}
// Mark this connection as hijacked
s.hijacked = true
// Serve the file
log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI)
content, fi, err := openFile(file)
if err != nil {
http.NotFound(s.rw, s.req)
return
}
defer content.Close()
http.ServeContent(s.rw, s.req, "", fi.ModTime(), content)
}
func (s *sendFileResponseWriter) Flush() {
s.WriteHeader(http.StatusOK)
}
package main
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path"
"testing"
)
func TestDeniedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
prepareDownloadDir(t)
deniedXSendfileDownload(t, contentFilename, url)
}
func TestAllowedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
prepareDownloadDir(t)
allowedXSendfileDownload(t, contentFilename, url)
}
func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
contentPath := path.Join(cacheDir, contentFilename)
prepareDownloadDir(t)
// Prepare test server and backend
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("UPSTREAM", r.Method, r.URL)
if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" {
t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType)
}
w.Header().Set("X-Sendfile", contentPath)
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename))
w.Header().Set("Content-Type", fmt.Sprintf(`application/octet-stream`))
w.WriteHeader(200)
}))
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
if err := os.MkdirAll(cacheDir, 0755); err != nil {
t.Fatal(err)
}
contentBytes := []byte("content")
if err := ioutil.WriteFile(contentPath, contentBytes, 0644); err != nil {
t.Fatal(err)
}
downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath))
downloadCmd.Dir = scratchDir
runOrFail(t, downloadCmd)
actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename))
if err != nil {
t.Fatal(err)
}
if bytes.Compare(actual, contentBytes) != 0 {
t.Fatal("Unexpected file contents in download")
}
}
func deniedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
prepareDownloadDir(t)
// Prepare test server and backend
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("UPSTREAM", r.Method, r.URL)
if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" {
t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType)
}
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename))
w.WriteHeader(200)
fmt.Fprint(w, "Denied")
}))
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath))
downloadCmd.Dir = scratchDir
runOrFail(t, downloadCmd)
actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename))
if err != nil {
t.Fatal(err)
}
if bytes.Compare(actual, []byte("Denied")) != 0 {
t.Fatal("Unexpected file contents in download")
}
}
package main
import (
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
type CacheMode int
const (
CacheDisabled CacheMode = iota
CacheExpireMax
)
func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
file := filepath.Join(*documentRoot, r.relativeURIPath)
// The filepath.Join does Clean traversing directories up
if !strings.HasPrefix(file, *documentRoot) {
fail500(w, &os.PathError{
Op: "open",
Path: file,
Err: os.ErrInvalid,
})
return
}
var content *os.File
var fi os.FileInfo
var err error
// Serve pre-gzipped assets
if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") {
content, fi, err = openFile(file + ".gz")
if err == nil {
w.Header().Set("Content-Encoding", "gzip")
}
}
// If not found, open the original file
if content == nil || err != nil {
content, fi, err = openFile(file)
}
if err != nil {
if notFoundHandler != nil {
notFoundHandler(w, r)
} else {
http.NotFound(w, r.Request)
}
return
}
defer content.Close()
switch cache {
case CacheExpireMax:
// Cache statically served files for 1 year
cacheUntil := time.Now().AddDate(1, 0, 0).Format(http.TimeFormat)
w.Header().Set("Cache-Control", "public")
w.Header().Set("Expires", cacheUntil)
}
log.Printf("Send static file %q (%q) for %s %q", file, w.Header().Get("Content-Encoding"), r.Method, r.RequestURI)
http.ServeContent(w, r.Request, filepath.Base(file), fi.ModTime(), content)
}
}
package main
import (
"bytes"
"compress/gzip"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/static/file",
}
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 404)
}
func TestServingDirectory(t *testing.T) {
dir, err := ioutil.TempDir("", "deploy")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/",
}
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 404)
}
func TestServingMalformedUri(t *testing.T) {
dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/../../../static/file",
}
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 500)
}
func TestExecutingHandlerWhenNoFileFound(t *testing.T) {
dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/static/file",
}
executed := false
handleServeFile(&dir, CacheDisabled, func(w http.ResponseWriter, r *gitRequest) {
executed = (r == request)
})(nil, request)
if !executed {
t.Error("The handler should get executed")
}
}
func TestServingTheActualFile(t *testing.T) {
dir, err := ioutil.TempDir("", "deploy")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/file",
}
fileContent := "STATIC"
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 200)
if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String())
}
}
func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
dir, err := ioutil.TempDir("", "deploy")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/file",
}
if enableGzip {
httpRequest.Header.Set("Accept-Encoding", "gzip, deflate")
}
fileContent := "STATIC"
var fileGzipContent bytes.Buffer
fileGzip := gzip.NewWriter(&fileGzipContent)
fileGzip.Write([]byte(fileContent))
fileGzip.Close()
ioutil.WriteFile(filepath.Join(dir, "file.gz"), fileGzipContent.Bytes(), 0600)
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request)
assertResponseCode(t, w, 200)
if enableGzip {
assertResponseHeader(t, w, "Content-Encoding", "gzip")
if bytes.Compare(w.Body.Bytes(), fileGzipContent.Bytes()) != 0 {
t.Error("We should serve the pregzipped file")
}
} else {
assertResponseCode(t, w, 200)
assertResponseHeader(t, w, "Content-Encoding", "")
if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String())
}
}
}
func TestServingThePregzippedFile(t *testing.T) {
testServingThePregzippedFile(t, true)
}
func TestServingThePregzippedFileWithoutEncoding(t *testing.T) {
testServingThePregzippedFile(t, false)
}
...@@ -111,25 +111,11 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) { ...@@ -111,25 +111,11 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) {
// Close writer // Close writer
writer.Close() writer.Close()
// Create request // Hijack the request
upstreamRequest, err := r.u.newUpstreamRequest(r.Request, nil, "") r.Body = ioutil.NopCloser(&body)
if err != nil { r.ContentLength = int64(body.Len())
fail500(w, fmt.Errorf("handleFileUploads: newUpstreamRequest: %v", err)) r.Header.Set("Content-Type", writer.FormDataContentType())
return
}
// Set multipart form data
upstreamRequest.Body = ioutil.NopCloser(&body)
upstreamRequest.ContentLength = int64(body.Len())
upstreamRequest.Header.Set("Content-Type", writer.FormDataContentType())
// Forward request to backend
upstreamResponse, err := r.u.httpClient.Do(upstreamRequest)
if err != nil {
fail500(w, fmt.Errorf("handleFileUploads: do request %v: %v", upstreamRequest.URL.Path, err))
return
}
defer upstreamResponse.Body.Close()
forwardResponseToClient(w, upstreamResponse) // Proxy the request
proxyRequest(w, r)
} }
...@@ -7,25 +7,21 @@ In this file we handle request routing and interaction with the authBackend. ...@@ -7,25 +7,21 @@ In this file we handle request routing and interaction with the authBackend.
package main package main
import ( import (
"io" "fmt"
"log" "log"
"net/http" "net/http"
"os" "net/http/httputil"
"path" "net/url"
"regexp" "strings"
) )
type serviceHandleFunc func(w http.ResponseWriter, r *gitRequest) type serviceHandleFunc func(w http.ResponseWriter, r *gitRequest)
type upstream struct { type upstream struct {
httpClient *http.Client httpClient *http.Client
authBackend string httpProxy *httputil.ReverseProxy
} authBackend string
relativeURLRoot string
type gitService struct {
method string
regex *regexp.Regexp
handleFunc serviceHandleFunc
} }
type authorizationResponse struct { type authorizationResponse struct {
...@@ -56,50 +52,79 @@ type authorizationResponse struct { ...@@ -56,50 +52,79 @@ type authorizationResponse struct {
TempPath string TempPath string
} }
// A gitReqest is an *http.Request decorated with attributes returned by the // A gitRequest is an *http.Request decorated with attributes returned by the
// GitLab Rails application. // GitLab Rails application.
type gitRequest struct { type gitRequest struct {
*http.Request *http.Request
authorizationResponse authorizationResponse
u *upstream u *upstream
}
// Routing table // This field contains the URL.Path stripped from RelativeUrlRoot
var gitServices = [...]gitService{ relativeURIPath string
gitService{"GET", regexp.MustCompile(`/info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)},
gitService{"POST", regexp.MustCompile(`/git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
gitService{"POST", regexp.MustCompile(`/git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
gitService{"GET", regexp.MustCompile(`/repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
gitService{"GET", regexp.MustCompile(`/repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
gitService{"GET", regexp.MustCompile(`/repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
gitService{"GET", regexp.MustCompile(`/repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
gitService{"GET", regexp.MustCompile(`/repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
gitService{"GET", regexp.MustCompile(`/uploads/`), handleSendFile},
// Git LFS
gitService{"PUT", regexp.MustCompile(`/gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)},
gitService{"GET", regexp.MustCompile(`/gitlab-lfs/objects/([0-9a-f]{64})\z`), handleSendFile},
// CI artifacts
gitService{"GET", regexp.MustCompile(`/builds/download\z`), handleSendFile},
gitService{"GET", regexp.MustCompile(`/ci/api/v1/builds/[0-9]+/artifacts\z`), handleSendFile},
gitService{"POST", regexp.MustCompile(`/ci/api/v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))},
gitService{"DELETE", regexp.MustCompile(`/ci/api/v1/builds/[0-9]+/artifacts\z`), proxyRequest},
} }
func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream { func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream {
return &upstream{&http.Client{Transport: authTransport}, authBackend} u, err := url.Parse(authBackend)
if err != nil {
log.Fatalln(err)
}
up := &upstream{
authBackend: authBackend,
httpClient: &http.Client{Transport: authTransport},
httpProxy: httputil.NewSingleHostReverseProxy(u),
relativeURLRoot: "/",
}
up.httpProxy.Transport = authTransport
return up
}
func (u *upstream) SetRelativeURLRoot(relativeURLRoot string) {
u.relativeURLRoot = relativeURLRoot
if !strings.HasSuffix(u.relativeURLRoot, "/") {
u.relativeURLRoot += "/"
}
} }
func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
var g gitService var g httpRoute
w := newLoggingResponseWriter(ow)
defer w.Log(r)
log.Printf("%s %q", r.Method, r.URL) // Drop WebSocket connection and CONNECT method
if r.RequestURI == "*" {
httpError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
httpError(&w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
// Check URL Root
URIPath := cleanURIPath(r.URL.Path)
if !strings.HasPrefix(URIPath, u.relativeURLRoot) {
httpError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
// Strip prefix and add "/"
// To match against non-relative URL
// Making it simpler for our matcher
relativeURIPath := cleanURIPath(strings.TrimPrefix(URIPath, u.relativeURLRoot))
// Look for a matching Git service // Look for a matching Git service
foundService := false foundService := false
for _, g = range gitServices { for _, g = range httpRoutes {
if r.Method == g.method && g.regex.MatchString(r.URL.Path) { if g.method != "" && r.Method != g.method {
continue
}
if g.regex == nil || g.regex.MatchString(relativeURIPath) {
foundService = true foundService = true
break break
} }
...@@ -107,57 +132,15 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -107,57 +132,15 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !foundService { if !foundService {
// The protocol spec in git/Documentation/technical/http-protocol.txt // The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found. // says we must return 403 if no matching service is found.
http.Error(w, "Forbidden", 403) httpError(&w, r, "Forbidden", http.StatusForbidden)
return return
} }
request := gitRequest{ request := gitRequest{
Request: r, Request: r,
u: u, relativeURIPath: relativeURIPath,
} u: u,
g.handleFunc(w, &request)
}
func looksLikeRepo(p string) bool {
// If /path/to/foo.git/objects exists then let's assume it is a valid Git
// repository.
if _, err := os.Stat(path.Join(p, "objects")); err != nil {
log.Print(err)
return false
} }
return true
}
func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) {
url := u.authBackend + r.URL.RequestURI() + suffix
authReq, err := http.NewRequest(r.Method, url, body)
if err != nil {
return nil, err
}
// Forward all headers from our client to the auth backend. This includes
// HTTP Basic authentication credentials (the 'Authorization' header).
for k, v := range r.Header {
authReq.Header[k] = v
}
// Clean some headers when issuing a new request without body
if body == nil {
authReq.Header.Del("Content-Type")
authReq.Header.Del("Content-Encoding")
authReq.Header.Del("Content-Length")
authReq.Header.Del("Content-Disposition")
authReq.Header.Del("Accept-Encoding")
authReq.Header.Del("Transfer-Encoding")
}
// Also forward the Host header, which is excluded from the Header map by the http libary.
// This allows the Host header received by the backend to be consistent with other
// requests not going through gitlab-workhorse.
authReq.Host = r.Host
// Set a custom header for the request. This can be used in some
// configurations (Passenger) to solve auth request routing problems.
authReq.Header.Set("Gitlab-Workhorse", Version)
return authReq, nil g.handleFunc(&w, &request)
} }
/*
The xSendFile middleware transparently sends static files in HTTP responses
via the X-Sendfile mechanism. All that is needed in the Rails code is the
'send_file' method.
*/
package main
import (
"fmt"
"io"
"log"
"net/http"
"os"
)
func handleSendFile(w http.ResponseWriter, r *gitRequest) {
upRequest, err := r.u.newUpstreamRequest(r.Request, r.Body, "")
if err != nil {
fail500(w, fmt.Errorf("handleSendFile: newUpstreamRequest: %v", err))
return
}
upRequest.Header.Set("X-Sendfile-Type", "X-Sendfile")
upResponse, err := r.u.httpClient.Do(upRequest)
r.Body.Close()
if err != nil {
fail500(w, fmt.Errorf("handleSendfile: do upstream request: %v", err))
return
}
defer upResponse.Body.Close()
// Get X-Sendfile
sendfile := upResponse.Header.Get("X-Sendfile")
upResponse.Header.Del("X-Sendfile")
// Copy headers from Rails upResponse
for k, v := range upResponse.Header {
w.Header()[k] = v
}
// Use accelerated file serving
if sendfile == "" {
// Copy request body otherwise
w.WriteHeader(upResponse.StatusCode)
// Copy body from Rails upResponse
if _, err := io.Copy(w, upResponse.Body); err != nil {
fail500(w, fmt.Errorf("handleSendFile: copy upstream response: %v", err))
}
return
}
log.Printf("Serving file %q", sendfile)
upResponse.Body.Close()
content, err := os.Open(sendfile)
if err != nil {
fail500(w, fmt.Errorf("handleSendile: open sendfile: %v", err))
return
}
defer content.Close()
fi, err := content.Stat()
if err != nil {
fail500(w, fmt.Errorf("handleSendfile: get mtime: %v", err))
return
}
http.ServeContent(w, r.Request, "", fi.ModTime(), content)
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment