Commit d186cf65 authored by Kamil Trzcinski's avatar Kamil Trzcinski

Make the regexp to match full URL and added tests

parent e9616610
...@@ -8,12 +8,17 @@ install: gitlab-workhorse ...@@ -8,12 +8,17 @@ install: gitlab-workhorse
install gitlab-workhorse ${PREFIX}/bin/ install gitlab-workhorse ${PREFIX}/bin/
.PHONY: test .PHONY: test
test: test/data/test.git clean-workhorse gitlab-workhorse test: test/data/group/test.git clean-workhorse gitlab-workhorse
go fmt | awk '{ print "Please run go fmt"; exit 1 }' go fmt | awk '{ print "Please run go fmt"; exit 1 }'
go test go test
test/data/test.git: test/data coverage: test/data/group/test.git
git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git test/data/test.git go test -cover -coverprofile=test.coverage
go tool cover -html=test.coverage -o coverage.html
rm -f test.coverage
test/data/group/test.git: test/data
git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git test/data/group/test.git
test/data: test/data:
mkdir -p test/data mkdir -p test/data
......
...@@ -3,11 +3,13 @@ package main ...@@ -3,11 +3,13 @@ package main
import ( import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"path/filepath"
) )
func handleDeployPage(deployPage *string, handler serviceHandleFunc) serviceHandleFunc { func handleDeployPage(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *gitRequest) {
data, err := ioutil.ReadFile(*deployPage) deployPage := filepath.Join(*documentRoot, "index.html")
data, err := ioutil.ReadFile(deployPage)
if err != nil { if err != nil {
handler(w, r) handler(w, r)
return return
......
package main
import (
"testing"
"io/ioutil"
"os"
"net/http"
"net/http/httptest"
"path/filepath"
)
func TestIfNoDeployPageExist(t *testing.T) {
dir, err := ioutil.TempDir("", "deploy")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
w := httptest.NewRecorder()
executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) {
executed = true
})(w, nil)
if !executed {
t.Error("The handler should get executed")
}
}
func TestIfDeployPageExist(t *testing.T) {
dir, err := ioutil.TempDir("", "deploy")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
deployPage := "DEPLOY"
ioutil.WriteFile(filepath.Join(dir, "index.html"), []byte(deployPage), 0600)
w := httptest.NewRecorder()
executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) {
executed = true
})(w, nil)
if executed {
t.Error("The handler should not get executed")
}
w.Flush()
if w.Code != 200 {
t.Error("Page should be 200")
}
if w.Body.String() != deployPage {
t.Error("Page should be deploy: ", w.Body.String())
}
}
...@@ -9,10 +9,10 @@ import ( ...@@ -9,10 +9,10 @@ import (
) )
type errorPageResponseWriter struct { type errorPageResponseWriter struct {
rw http.ResponseWriter rw http.ResponseWriter
status int status int
hijacked bool hijacked bool
errorPages *string path *string
} }
func (s *errorPageResponseWriter) Header() http.Header { func (s *errorPageResponseWriter) Header() http.Header {
...@@ -37,7 +37,7 @@ func (s *errorPageResponseWriter) WriteHeader(status int) { ...@@ -37,7 +37,7 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.status = status s.status = status
if 400 <= s.status && s.status <= 599 { if 400 <= s.status && s.status <= 599 {
errorPageFile := filepath.Join(*errorPages, fmt.Sprintf("%d.html", s.status)) errorPageFile := filepath.Join(*s.path, fmt.Sprintf("%d.html", s.status))
// check if custom error page exists, serve this page instead // check if custom error page exists, serve this page instead
if data, err := ioutil.ReadFile(errorPageFile); err == nil { if data, err := ioutil.ReadFile(errorPageFile); err == nil {
...@@ -59,11 +59,11 @@ func (s *errorPageResponseWriter) Flush() { ...@@ -59,11 +59,11 @@ func (s *errorPageResponseWriter) Flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
func handleRailsError(errorPages *string, handler serviceHandleFunc) serviceHandleFunc { func handleRailsError(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *gitRequest) {
rw := errorPageResponseWriter{ rw := errorPageResponseWriter{
rw: w, rw: w,
errorPages: errorPages, path: documentRoot,
} }
defer rw.Flush() defer rw.Flush()
handler(&rw, r) handler(&rw, r)
......
package main
import (
"testing"
"io/ioutil"
"path/filepath"
"net/http/httptest"
"os"
"net/http"
"fmt"
)
func TestIfErrorPageIsPresented(t *testing.T) {
dir, err := ioutil.TempDir("", "error_page")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
errorPage := "ERROR"
ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600)
w := httptest.NewRecorder()
handleRailsError(&dir, func(w http.ResponseWriter, r *gitRequest) {
w.WriteHeader(404)
fmt.Fprint(w, "Not Found")
})(w, nil)
w.Flush()
if w.Code != 404 {
t.Error("Page should be 404")
}
if w.Body.String() != errorPage {
t.Error("Page should be custom error page: ", w.Body.String())
}
}
func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
dir, err := ioutil.TempDir("", "error_page")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
w := httptest.NewRecorder()
errorResponse := "ERROR"
handleRailsError(&dir, func(w http.ResponseWriter, r *gitRequest) {
w.WriteHeader(404)
fmt.Fprint(w, errorResponse)
})(w, nil)
w.Flush()
if w.Code != 404 {
t.Error("Page should be 400")
}
if w.Body.String() != errorResponse {
t.Error("Page should be response error: ", w.Body.String())
}
}
...@@ -14,11 +14,6 @@ import ( ...@@ -14,11 +14,6 @@ import (
"syscall" "syscall"
) )
func fail400(w http.ResponseWriter, err error) {
http.Error(w, "Bad request", 400)
logError(err)
}
func fail500(w http.ResponseWriter, err error) { func fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500) http.Error(w, "Internal server error", 500)
logError(err) logError(err)
......
...@@ -22,7 +22,6 @@ import ( ...@@ -22,7 +22,6 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"regexp" "regexp"
"strings"
"syscall" "syscall"
"time" "time"
) )
...@@ -38,8 +37,7 @@ var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to ...@@ -38,8 +37,7 @@ var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var relativeUrlRoot = flag.String("relativeUrlRoot", "/", "GitLab relative URL root") var relativeUrlRoot = flag.String("relativeUrlRoot", "/", "GitLab relative URL root")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content") var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var deployPage = flag.String("deployPage", "public/index.html", "Path to file that will always be served if present") var proxyTimeout = flag.Duration("proxyTimeout", 5 * time.Minute, "Proxy request timeout")
var errorPages = flag.String("errorPages", "public/index.html", "The folder containing custom error pages, ie.: 500.html")
type httpRoute struct { type httpRoute struct {
method string method string
...@@ -51,17 +49,23 @@ type httpRoute struct { ...@@ -51,17 +49,23 @@ type httpRoute struct {
// We match against URI not containing the relativeUrlRoot: // We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP // see upstream.ServeHTTP
var httpRoutes = [...]httpRoute{ var httpRoutes = [...]httpRoute{
httpRoute{"GET", regexp.MustCompile(`/info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)}, // Git Clone
httpRoute{"POST", regexp.MustCompile(`/git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, httpRoute{"GET", regexp.MustCompile(`^/[^/]+/[^/]+\.git/info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)},
httpRoute{"POST", regexp.MustCompile(`/git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, httpRoute{"POST", regexp.MustCompile(`/[^/]+/[^/]+\.git/git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"GET", regexp.MustCompile(`/repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"POST", regexp.MustCompile(`/[^/]+/[^/]+\.git/git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"GET", regexp.MustCompile(`/repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(`/[^/]+/[^/]+/repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(`/repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(`/[^/]+/[^/]+/repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(`/repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(`/[^/]+/[^/]+/repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(`/repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(`/[^/]+/[^/]+/repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(`/[^/]+/[^/]+/repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(`/api/v3/projects/[^/]+/repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(`/api/v3/projects/[^/]+/repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(`/api/v3/projects/[^/]+/repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(`/api/v3/projects/[^/]+/repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(`/api/v3/projects/[^/]+/repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
// Git LFS // Git LFS
httpRoute{"PUT", regexp.MustCompile(`/gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)}, httpRoute{"PUT", regexp.MustCompile(`/[^/]+/[^/]+\.git/gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)},
// CI artifacts // CI artifacts
httpRoute{"POST", regexp.MustCompile(`^/ci/api/v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))}, httpRoute{"POST", regexp.MustCompile(`^/ci/api/v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))},
...@@ -72,8 +76,8 @@ var httpRoutes = [...]httpRoute{ ...@@ -72,8 +76,8 @@ var httpRoutes = [...]httpRoute{
// Serve static files and forward otherwise // Serve static files and forward otherwise
httpRoute{"", nil, handleServeFile(documentRoot, httpRoute{"", nil, handleServeFile(documentRoot,
handleDeployPage(deployPage, handleDeployPage(documentRoot,
handleRailsError(errorPages, handleRailsError(documentRoot,
proxyRequest, proxyRequest,
)))}, )))},
} }
...@@ -92,10 +96,6 @@ func main() { ...@@ -92,10 +96,6 @@ func main() {
os.Exit(0) os.Exit(0)
} }
if !strings.HasSuffix(*relativeUrlRoot, "/") {
*relativeUrlRoot += "/"
}
log.Printf("Starting %s", version) log.Printf("Starting %s", version)
// Good housekeeping for Unix sockets: unlink before binding // Good housekeeping for Unix sockets: unlink before binding
...@@ -137,9 +137,13 @@ func main() { ...@@ -137,9 +137,13 @@ func main() {
}() }()
} }
upstream := newUpstream(*authBackend, authTransport)
upstream.SetRelativeUrlRoot(*relativeUrlRoot)
upstream.SetProxyTimeout(*proxyTimeout)
// Because net/http/pprof installs itself in the DefaultServeMux // Because net/http/pprof installs itself in the DefaultServeMux
// we create a fresh one for the Git server. // we create a fresh one for the Git server.
serveMux := http.NewServeMux() serveMux := http.NewServeMux()
serveMux.Handle(*relativeUrlRoot, newUpstream(*authBackend, authTransport)) serveMux.Handle(upstream.relativeUrlRoot, upstream)
log.Fatal(http.Serve(listener, serveMux)) log.Fatal(http.Serve(listener, serveMux))
} }
...@@ -18,8 +18,8 @@ import ( ...@@ -18,8 +18,8 @@ import (
const scratchDir = "test/scratch" const scratchDir = "test/scratch"
const testRepoRoot = "test/data" const testRepoRoot = "test/data"
const testRepo = "test.git" const testRepo = "group/test.git"
const testProject = "test" const testProject = "group/test"
var checkoutDir = path.Join(scratchDir, "test") var checkoutDir = path.Join(scratchDir, "test")
var cacheDir = path.Join(scratchDir, "cache") var cacheDir = path.Join(scratchDir, "cache")
...@@ -276,7 +276,6 @@ func preparePushRepo(t *testing.T) { ...@@ -276,7 +276,6 @@ func preparePushRepo(t *testing.T) {
} }
cloneCmd := exec.Command("git", "clone", path.Join(testRepoRoot, testRepo), checkoutDir) cloneCmd := exec.Command("git", "clone", path.Join(testRepoRoot, testRepo), checkoutDir)
runOrFail(t, cloneCmd) runOrFail(t, cloneCmd)
return
} }
func newBranch() string { func newBranch() string {
...@@ -368,88 +367,3 @@ func repoPath(t *testing.T) string { ...@@ -368,88 +367,3 @@ func repoPath(t *testing.T) string {
return path.Join(cwd, testRepoRoot, testRepo) return path.Join(cwd, testRepoRoot, testRepo)
} }
func TestDeniedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
prepareDownloadDir(t)
deniedXSendfileDownload(t, contentFilename, url)
}
func TestAllowedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
prepareDownloadDir(t)
allowedXSendfileDownload(t, contentFilename, url)
}
func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
contentPath := path.Join(cacheDir, contentFilename)
prepareDownloadDir(t)
// Prepare test server and backend
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("UPSTREAM", r.Method, r.URL)
if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" {
t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType)
}
w.Header().Set("X-Sendfile", contentPath)
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename))
w.Header().Set("Content-Type", fmt.Sprintf(`application/octet-stream`))
w.WriteHeader(200)
}))
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
if err := os.MkdirAll(cacheDir, 0755); err != nil {
t.Fatal(err)
}
contentBytes := []byte("content")
if err := ioutil.WriteFile(contentPath, contentBytes, 0644); err != nil {
t.Fatal(err)
}
downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath))
downloadCmd.Dir = scratchDir
runOrFail(t, downloadCmd)
actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename))
if err != nil {
t.Fatal(err)
}
if bytes.Compare(actual, contentBytes) != 0 {
t.Fatal("Unexpected file contents in download")
}
}
func deniedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
prepareDownloadDir(t)
// Prepare test server and backend
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("UPSTREAM", r.Method, r.URL)
if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" {
t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType)
}
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename))
w.WriteHeader(200)
fmt.Fprint(w, "Denied")
}))
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath))
downloadCmd.Dir = scratchDir
runOrFail(t, downloadCmd)
actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename))
if err != nil {
t.Fatal(err)
}
if bytes.Compare(actual, []byte("Denied")) != 0 {
t.Fatal("Unexpected file contents in download")
}
}
...@@ -23,5 +23,6 @@ func proxyRequest(w http.ResponseWriter, r *gitRequest) { ...@@ -23,5 +23,6 @@ func proxyRequest(w http.ResponseWriter, r *gitRequest) {
req.Header.Set("Gitlab-Workhorse", Version) req.Header.Set("Gitlab-Workhorse", Version)
rw := newSendFileResponseWriter(w, &req) rw := newSendFileResponseWriter(w, &req)
defer rw.Flush() defer rw.Flush()
r.u.httpProxy.ServeHTTP(&rw, &req) r.u.httpProxy.ServeHTTP(&rw, &req)
} }
package main
import (
"testing"
"fmt"
"path"
"net/http/httptest"
"net/http"
"log"
"os"
"io/ioutil"
"os/exec"
"bytes"
)
func TestDeniedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
prepareDownloadDir(t)
deniedXSendfileDownload(t, contentFilename, url)
}
func TestAllowedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
prepareDownloadDir(t)
allowedXSendfileDownload(t, contentFilename, url)
}
func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
contentPath := path.Join(cacheDir, contentFilename)
prepareDownloadDir(t)
// Prepare test server and backend
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("UPSTREAM", r.Method, r.URL)
if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" {
t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType)
}
w.Header().Set("X-Sendfile", contentPath)
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename))
w.Header().Set("Content-Type", fmt.Sprintf(`application/octet-stream`))
w.WriteHeader(200)
}))
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
if err := os.MkdirAll(cacheDir, 0755); err != nil {
t.Fatal(err)
}
contentBytes := []byte("content")
if err := ioutil.WriteFile(contentPath, contentBytes, 0644); err != nil {
t.Fatal(err)
}
downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath))
downloadCmd.Dir = scratchDir
runOrFail(t, downloadCmd)
actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename))
if err != nil {
t.Fatal(err)
}
if bytes.Compare(actual, contentBytes) != 0 {
t.Fatal("Unexpected file contents in download")
}
}
func deniedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
prepareDownloadDir(t)
// Prepare test server and backend
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("UPSTREAM", r.Method, r.URL)
if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" {
t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType)
}
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename))
w.WriteHeader(200)
fmt.Fprint(w, "Denied")
}))
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath))
downloadCmd.Dir = scratchDir
runOrFail(t, downloadCmd)
actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename))
if err != nil {
t.Fatal(err)
}
if bytes.Compare(actual, []byte("Denied")) != 0 {
t.Fatal("Unexpected file contents in download")
}
}
package main package main
import ( import (
"fmt"
"log" "log"
"net/http" "net/http"
"os" "os"
...@@ -15,7 +14,11 @@ func handleServeFile(documentRoot *string, notFoundHandler serviceHandleFunc) se ...@@ -15,7 +14,11 @@ func handleServeFile(documentRoot *string, notFoundHandler serviceHandleFunc) se
// The filepath.Join does Clean traversing directories up // The filepath.Join does Clean traversing directories up
if !strings.HasPrefix(file, *documentRoot) { if !strings.HasPrefix(file, *documentRoot) {
fail500(w, fmt.Errorf("invalid path: "+file, os.ErrInvalid)) fail500(w, &os.PathError{
Op: "open",
Path: file,
Err: os.ErrInvalid,
})
return return
} }
......
package main
import (
"io/ioutil"
"testing"
"net/http/httptest"
"os"
"net/http"
"path/filepath"
)
func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory"
request := &gitRequest{
relativeUriPath: "/static/file",
}
w := httptest.NewRecorder()
handleServeFile(&dir, nil)(w, request)
if w.Code != 404 {
t.Fatal("Expected to receive 404, since no default handler is provided")
}
}
func TestServingDirectory(t *testing.T) {
dir, err := ioutil.TempDir("", "deploy")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
request := &gitRequest{
relativeUriPath: "/",
}
w := httptest.NewRecorder()
handleServeFile(&dir, nil)(w, request)
if w.Code != 404 {
t.Fatal("Expected to receive 404, since we will serve the directory")
}
}
func TestServingMalformedUri(t *testing.T) {
dir := "/path/to/non/existing/directory"
request := &gitRequest{
relativeUriPath: "/../../../static/file",
}
w := httptest.NewRecorder()
handleServeFile(&dir, nil)(w, request)
if w.Code != 500 {
t.Fatal("Expected to receive 500, since client provided invalid URI")
}
}
func TestExecutingHandlerWhenNoFileFound(t *testing.T) {
dir := "/path/to/non/existing/directory"
request := &gitRequest{
relativeUriPath: "/static/file",
}
executed := false
handleServeFile(&dir, func(w http.ResponseWriter, r *gitRequest) {
executed = (r == request)
})(nil, request)
if !executed {
t.Error("The handler should get executed")
}
}
func TestServingTheActualFile(t *testing.T) {
dir, err := ioutil.TempDir("", "deploy")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeUriPath: "/file",
}
fileContent := "DEPLOY"
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder()
handleServeFile(&dir, nil)(w, request)
if w.Code != 200 {
t.Fatal("Expected to receive 200, since we serve existing file")
}
if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String())
}
}
...@@ -41,7 +41,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -41,7 +41,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
fmt.Fprint(w, "RESPONSE") fmt.Fprint(w, "RESPONSE")
}) })
httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST")) httpRequest, err := http.NewRequest("PATCH", ts.URL + "/url/path", bytes.NewBufferString("REQUEST"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -123,7 +123,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -123,7 +123,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
fmt.Fprint(file, "test") fmt.Fprint(file, "test")
writer.Close() writer.Close()
httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", nil) httpRequest, err := http.NewRequest("PUT", ts.URL + "/url/path", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -12,42 +12,44 @@ import ( ...@@ -12,42 +12,44 @@ import (
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"strings" "strings"
"time"
) )
type serviceHandleFunc func(w http.ResponseWriter, r *gitRequest) type serviceHandleFunc func(w http.ResponseWriter, r *gitRequest)
type upstream struct { type upstream struct {
httpClient *http.Client httpClient *http.Client
authBackend string httpProxy *httputil.ReverseProxy
httpProxy *httputil.ReverseProxy authBackend string
relativeUrlRoot string
} }
type authorizationResponse struct { type authorizationResponse struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git // GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull' // push' and 'git pull'
GL_ID string GL_ID string
// RepoPath is the full path on disk to the Git repository the request is // RepoPath is the full path on disk to the Git repository the request is
// about // about
RepoPath string RepoPath string
// ArchivePath is the full path where we should find/create a cached copy // ArchivePath is the full path where we should find/create a cached copy
// of a requested archive // of a requested archive
ArchivePath string ArchivePath string
// ArchivePrefix is used to put extracted archive contents in a // ArchivePrefix is used to put extracted archive contents in a
// subdirectory // subdirectory
ArchivePrefix string ArchivePrefix string
// CommitId is used do prevent race conditions between the 'time of check' // CommitId is used do prevent race conditions between the 'time of check'
// in the GitLab Rails app and the 'time of use' in gitlab-workhorse. // in the GitLab Rails app and the 'time of use' in gitlab-workhorse.
CommitId string CommitId string
// StoreLFSPath is provided by the GitLab Rails application // StoreLFSPath is provided by the GitLab Rails application
// to mark where the tmp file should be placed // to mark where the tmp file should be placed
StoreLFSPath string StoreLFSPath string
// LFS object id // LFS object id
LfsOid string LfsOid string
// LFS object size // LFS object size
LfsSize int64 LfsSize int64
// TmpPath is the path where we should store temporary files // TmpPath is the path where we should store temporary files
// This is set by authorization middleware // This is set by authorization middleware
TempPath string TempPath string
} }
// A gitRequest is an *http.Request decorated with attributes returned by the // A gitRequest is an *http.Request decorated with attributes returned by the
...@@ -55,7 +57,7 @@ type authorizationResponse struct { ...@@ -55,7 +57,7 @@ type authorizationResponse struct {
type gitRequest struct { type gitRequest struct {
*http.Request *http.Request
authorizationResponse authorizationResponse
u *upstream u *upstream
// This field contains the URL.Path stripped from RelativeUrlRoot // This field contains the URL.Path stripped from RelativeUrlRoot
relativeUriPath string relativeUriPath string
...@@ -71,11 +73,24 @@ func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream ...@@ -71,11 +73,24 @@ func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream
authBackend: authBackend, authBackend: authBackend,
httpClient: &http.Client{Transport: authTransport}, httpClient: &http.Client{Transport: authTransport},
httpProxy: httputil.NewSingleHostReverseProxy(u), httpProxy: httputil.NewSingleHostReverseProxy(u),
relativeUrlRoot: "/",
} }
up.httpProxy.Transport = authTransport up.httpProxy.Transport = authTransport
return up return up
} }
func (u *upstream) SetRelativeUrlRoot(relativeUrlRoot string) {
u.relativeUrlRoot = relativeUrlRoot
if !strings.HasSuffix(u.relativeUrlRoot, "/") {
u.relativeUrlRoot += "/"
}
}
func (u *upstream) SetProxyTimeout(timeout time.Duration) {
u.httpClient.Timeout = timeout
}
func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
var g httpRoute var g httpRoute
...@@ -85,7 +100,7 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -85,7 +100,7 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
// Strip prefix and add "/" // Strip prefix and add "/"
// To match against non-relative URL // To match against non-relative URL
// Making it simpler for our matcher // Making it simpler for our matcher
relativeUriPath := "/" + strings.TrimPrefix(r.URL.Path, *relativeUrlRoot) relativeUriPath := "/" + strings.TrimPrefix(r.URL.Path, u.relativeUrlRoot)
// Look for a matching Git service // Look for a matching Git service
foundService := false foundService := false
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment