Commit 9fd59b22 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Use http.HandlerFunc and move API into package

parent bd748961
...@@ -5,6 +5,7 @@ In this file we handle 'git archive' downloads ...@@ -5,6 +5,7 @@ In this file we handle 'git archive' downloads
package main package main
import ( import (
"./internal/api"
"./internal/helper" "./internal/helper"
"fmt" "fmt"
"io" "io"
...@@ -18,7 +19,7 @@ import ( ...@@ -18,7 +19,7 @@ import (
"time" "time"
) )
func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) { func handleGetArchive(w http.ResponseWriter, r *http.Request, a *api.Response) {
var format string var format string
urlPath := r.URL.Path urlPath := r.URL.Path
switch filepath.Base(urlPath) { switch filepath.Base(urlPath) {
...@@ -31,7 +32,7 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -31,7 +32,7 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) {
case "archive.tar.bz2": case "archive.tar.bz2":
format = "tar.bz2" format = "tar.bz2"
default: default:
fail500(w, fmt.Errorf("handleGetArchive: invalid format: %s", urlPath)) helper.Fail500(w, fmt.Errorf("handleGetArchive: invalid format: %s", urlPath))
return return
} }
...@@ -54,7 +55,7 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -54,7 +55,7 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) {
// to finalize the cached archive. // to finalize the cached archive.
tempFile, err := prepareArchiveTempfile(path.Dir(a.ArchivePath), archiveFilename) tempFile, err := prepareArchiveTempfile(path.Dir(a.ArchivePath), archiveFilename)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: create tempfile: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: create tempfile: %v", err))
return return
} }
defer tempFile.Close() defer tempFile.Close()
...@@ -65,12 +66,12 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -65,12 +66,12 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) {
archiveCmd := gitCommand("", "git", "--git-dir="+a.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+a.ArchivePrefix+"/", a.CommitId) archiveCmd := gitCommand("", "git", "--git-dir="+a.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+a.ArchivePrefix+"/", a.CommitId)
archiveStdout, err := archiveCmd.StdoutPipe() archiveStdout, err := archiveCmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: archive stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: archive stdout: %v", err))
return return
} }
defer archiveStdout.Close() defer archiveStdout.Close()
if err := archiveCmd.Start(); err != nil { if err := archiveCmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", archiveCmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", archiveCmd.Args, err))
return return
} }
defer cleanUpProcessGroup(archiveCmd) // Ensure brute force subprocess clean-up defer cleanUpProcessGroup(archiveCmd) // Ensure brute force subprocess clean-up
...@@ -83,13 +84,13 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -83,13 +84,13 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) {
stdout, err = compressCmd.StdoutPipe() stdout, err = compressCmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: compress stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: compress stdout: %v", err))
return return
} }
defer stdout.Close() defer stdout.Close()
if err := compressCmd.Start(); err != nil { if err := compressCmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", compressCmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", compressCmd.Args, err))
return return
} }
defer compressCmd.Wait() defer compressCmd.Wait()
......
package main package main
import ( import (
"./internal/api"
"net/http" "net/http"
) )
func artifactsAuthorizeHandler(api *API, h httpHandleFunc) httpHandleFunc { func artifactsAuthorizeHandler(myAPI *api.API, h http.HandlerFunc) http.HandlerFunc {
return api.preAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *apiResponse) { return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
r.Header.Set(tempPathHeader, a.TempPath) r.Header.Set(tempPathHeader, a.TempPath)
h(w, r) h(w, r)
}, "/authorize") }, "/authorize")
......
package main
import (
"./internal/proxy"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
)
func (api *API) newUpstreamRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) {
url := *api.URL
url.Path = r.URL.RequestURI() + suffix
authReq := &http.Request{
Method: r.Method,
URL: &url,
Header: proxy.HeaderClone(r.Header),
}
if body != nil {
authReq.Body = ioutil.NopCloser(body)
}
// Clean some headers when issuing a new request without body
if body == nil {
authReq.Header.Del("Content-Type")
authReq.Header.Del("Content-Encoding")
authReq.Header.Del("Content-Length")
authReq.Header.Del("Content-Disposition")
authReq.Header.Del("Accept-Encoding")
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
authReq.Header.Del("Transfer-Encoding")
authReq.Header.Del("Connection")
authReq.Header.Del("Keep-Alive")
authReq.Header.Del("Proxy-Authenticate")
authReq.Header.Del("Proxy-Authorization")
authReq.Header.Del("Te")
authReq.Header.Del("Trailers")
authReq.Header.Del("Upgrade")
}
// Also forward the Host header, which is excluded from the Header map by the http libary.
// This allows the Host header received by the backend to be consistent with other
// requests not going through gitlab-workhorse.
authReq.Host = r.Host
// Set a custom header for the request. This can be used in some
// configurations (Passenger) to solve auth request routing problems.
authReq.Header.Set("Gitlab-Workhorse", Version)
return authReq, nil
}
func (api *API) preAuthorizeHandler(h serviceHandleFunc, suffix string) httpHandleFunc {
return func(w http.ResponseWriter, r *http.Request) {
authReq, err := api.newUpstreamRequest(r, nil, suffix)
if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err))
return
}
authResponse, err := api.Do(authReq)
if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: do %v: %v", authReq.URL.Path, err))
return
}
defer authResponse.Body.Close()
if authResponse.StatusCode != 200 {
// The Git request is not allowed by the backend. Maybe the
// client needs to send HTTP Basic credentials. Forward the
// response from the auth backend to our client. This includes
// the 'WWW-Authenticate' header that acts as a hint that
// Basic auth credentials are needed.
for k, v := range authResponse.Header {
// Accomodate broken clients that do case-sensitive header lookup
if k == "Www-Authenticate" {
w.Header()["WWW-Authenticate"] = v
} else {
w.Header()[k] = v
}
}
w.WriteHeader(authResponse.StatusCode)
io.Copy(w, authResponse.Body)
return
}
a := &apiResponse{}
// The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth
// response body.
if err := json.NewDecoder(authResponse.Body).Decode(a); err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err))
return
}
// Don't hog a TCP connection in CLOSE_WAIT, we can already close it now
authResponse.Body.Close()
// Negotiate authentication (Kerberos) may need to return a WWW-Authenticate
// header to the client even in case of success as per RFC4559.
for k, v := range authResponse.Header {
// Case-insensitive comparison as per RFC7230
if strings.EqualFold(k, "WWW-Authenticate") {
w.Header()[k] = v
}
}
h(w, r, a)
}
}
package main package main
import ( import (
"./internal/api"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -8,7 +9,7 @@ import ( ...@@ -8,7 +9,7 @@ import (
"testing" "testing"
) )
func okHandler(w http.ResponseWriter, _ *http.Request, _ *apiResponse) { func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
w.WriteHeader(201) w.WriteHeader(201)
fmt.Fprint(w, "{\"status\":\"ok\"}") fmt.Fprint(w, "{\"status\":\"ok\"}")
} }
...@@ -26,7 +27,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api ...@@ -26,7 +27,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
api := newUpstream(ts.URL, "").API api := newUpstream(ts.URL, "").API
response := httptest.NewRecorder() response := httptest.NewRecorder()
api.preAuthorizeHandler(okHandler, suffix)(response, httpRequest) api.PreAuthorizeHandler(okHandler, suffix)(response, httpRequest)
assertResponseCode(t, response, expectedCode) assertResponseCode(t, response, expectedCode)
return response return response
} }
...@@ -35,7 +36,7 @@ func TestPreAuthorizeHappyPath(t *testing.T) { ...@@ -35,7 +36,7 @@ func TestPreAuthorizeHappyPath(t *testing.T) {
runPreAuthorizeHandler( runPreAuthorizeHandler(
t, "/authorize", t, "/authorize",
regexp.MustCompile(`/authorize\z`), regexp.MustCompile(`/authorize\z`),
&apiResponse{}, &api.Response{},
200, 201) 200, 201)
} }
...@@ -43,7 +44,7 @@ func TestPreAuthorizeSuffix(t *testing.T) { ...@@ -43,7 +44,7 @@ func TestPreAuthorizeSuffix(t *testing.T) {
runPreAuthorizeHandler( runPreAuthorizeHandler(
t, "/different-authorize", t, "/different-authorize",
regexp.MustCompile(`/authorize\z`), regexp.MustCompile(`/authorize\z`),
&apiResponse{}, &api.Response{},
200, 404) 200, 404)
} }
......
...@@ -6,7 +6,7 @@ import ( ...@@ -6,7 +6,7 @@ import (
"path/filepath" "path/filepath"
) )
func handleDeployPage(documentRoot *string, handler httpHandleFunc) httpHandleFunc { func handleDeployPage(documentRoot *string, handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
deployPage := filepath.Join(*documentRoot, "index.html") deployPage := filepath.Join(*documentRoot, "index.html")
data, err := ioutil.ReadFile(deployPage) data, err := ioutil.ReadFile(deployPage)
......
...@@ -2,7 +2,7 @@ package main ...@@ -2,7 +2,7 @@ package main
import "net/http" import "net/http"
func handleDevelopmentMode(developmentMode *bool, handler httpHandleFunc) httpHandleFunc { func handleDevelopmentMode(developmentMode *bool, handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !*developmentMode { if !*developmentMode {
http.NotFound(w, r) http.NotFound(w, r)
......
...@@ -59,7 +59,7 @@ func (s *errorPageResponseWriter) Flush() { ...@@ -59,7 +59,7 @@ func (s *errorPageResponseWriter) Flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
func handleRailsError(documentRoot *string, handler http.Handler) httpHandleFunc { func handleRailsError(documentRoot *string, handler http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
rw := errorPageResponseWriter{ rw := errorPageResponseWriter{
rw: w, rw: w,
......
...@@ -21,7 +21,7 @@ func TestIfErrorPageIsPresented(t *testing.T) { ...@@ -21,7 +21,7 @@ func TestIfErrorPageIsPresented(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600) ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h := httpHandleFunc(func(w http.ResponseWriter, _ *http.Request) { h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, "Not Found") fmt.Fprint(w, "Not Found")
}) })
...@@ -41,7 +41,7 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { ...@@ -41,7 +41,7 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
errorResponse := "ERROR" errorResponse := "ERROR"
h := httpHandleFunc(func(w http.ResponseWriter, _ *http.Request) { h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, errorResponse) fmt.Fprint(w, errorResponse)
}) })
......
...@@ -5,6 +5,7 @@ In this file we handle the Git 'smart HTTP' protocol ...@@ -5,6 +5,7 @@ In this file we handle the Git 'smart HTTP' protocol
package main package main
import ( import (
"./internal/api"
"./internal/helper" "./internal/helper"
"errors" "errors"
"fmt" "fmt"
...@@ -27,10 +28,10 @@ func looksLikeRepo(p string) bool { ...@@ -27,10 +28,10 @@ func looksLikeRepo(p string) bool {
return true return true
} }
func repoPreAuthorizeHandler(api *API, handleFunc serviceHandleFunc) httpHandleFunc { func repoPreAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.HandlerFunc {
return api.preAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *apiResponse) { return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if a.RepoPath == "" { if a.RepoPath == "" {
fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty")) helper.Fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty"))
return return
} }
...@@ -43,7 +44,7 @@ func repoPreAuthorizeHandler(api *API, handleFunc serviceHandleFunc) httpHandleF ...@@ -43,7 +44,7 @@ func repoPreAuthorizeHandler(api *API, handleFunc serviceHandleFunc) httpHandleF
}, "") }, "")
} }
func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *apiResponse) { func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response) {
rpc := r.URL.Query().Get("service") rpc := r.URL.Query().Get("service")
if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") { if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported // The 'dumb' Git HTTP protocol is not supported
...@@ -55,12 +56,12 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -55,12 +56,12 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *apiResponse) {
cmd := gitCommand(a.GL_ID, "git", subCommand(rpc), "--stateless-rpc", "--advertise-refs", a.RepoPath) cmd := gitCommand(a.GL_ID, "git", subCommand(rpc), "--stateless-rpc", "--advertise-refs", a.RepoPath)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetInfoRefs: stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: stdout: %v", err))
return return
} }
defer stdout.Close() defer stdout.Close()
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetInfoRefs: start %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: start %v: %v", cmd.Args, err))
return return
} }
defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
...@@ -87,14 +88,14 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -87,14 +88,14 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *apiResponse) {
} }
} }
func handlePostRPC(w http.ResponseWriter, r *http.Request, a *apiResponse) { func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) {
var err error var err error
// Get Git action from URL // Get Git action from URL
action := filepath.Base(r.URL.Path) action := filepath.Base(r.URL.Path)
if !(action == "git-upload-pack" || action == "git-receive-pack") { if !(action == "git-upload-pack" || action == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported // The 'dumb' Git HTTP protocol is not supported
fail500(w, fmt.Errorf("handlePostRPC: unsupported action: %s", r.URL.Path)) helper.Fail500(w, fmt.Errorf("handlePostRPC: unsupported action: %s", r.URL.Path))
return return
} }
...@@ -102,25 +103,25 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -102,25 +103,25 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *apiResponse) {
cmd := gitCommand(a.GL_ID, "git", subCommand(action), "--stateless-rpc", a.RepoPath) cmd := gitCommand(a.GL_ID, "git", subCommand(action), "--stateless-rpc", a.RepoPath)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handlePostRPC: stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handlePostRPC: stdout: %v", err))
return return
} }
defer stdout.Close() defer stdout.Close()
stdin, err := cmd.StdinPipe() stdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handlePostRPC: stdin: %v", err)) helper.Fail500(w, fmt.Errorf("handlePostRPC: stdin: %v", err))
return return
} }
defer stdin.Close() defer stdin.Close()
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
fail500(w, fmt.Errorf("handlePostRPC: start %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handlePostRPC: start %v: %v", cmd.Args, err))
return return
} }
defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
// Write the client request body to Git's standard input // Write the client request body to Git's standard input
if _, err := io.Copy(stdin, r.Body); err != nil { if _, err := io.Copy(stdin, r.Body); err != nil {
fail500(w, fmt.Errorf("handlePostRPC write to %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handlePostRPC write to %v: %v", cmd.Args, err))
return return
} }
// Signal to the Git subprocess that no more data is coming // Signal to the Git subprocess that no more data is coming
......
package main package main
import ( import (
"./internal/helper"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
) )
func contentEncodingHandler(h httpHandleFunc) httpHandleFunc { func contentEncodingHandler(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var body io.ReadCloser var body io.ReadCloser
var err error var err error
...@@ -24,7 +25,7 @@ func contentEncodingHandler(h httpHandleFunc) httpHandleFunc { ...@@ -24,7 +25,7 @@ func contentEncodingHandler(h httpHandleFunc) httpHandleFunc {
} }
if err != nil { if err != nil {
fail500(w, fmt.Errorf("contentEncodingHandler: %v", err)) helper.Fail500(w, fmt.Errorf("contentEncodingHandler: %v", err))
return return
} }
defer body.Close() defer body.Close()
......
...@@ -5,7 +5,6 @@ Miscellaneous helpers: logging, errors, subprocesses ...@@ -5,7 +5,6 @@ Miscellaneous helpers: logging, errors, subprocesses
package main package main
import ( import (
"./internal/helper"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
...@@ -14,11 +13,6 @@ import ( ...@@ -14,11 +13,6 @@ import (
"syscall" "syscall"
) )
func fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500)
helper.LogError(err)
}
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) { func httpError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) { if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error // Force client to disconnect if we render request error
......
...@@ -3,9 +3,15 @@ package helper ...@@ -3,9 +3,15 @@ package helper
import ( import (
"errors" "errors"
"log" "log"
"net/http"
"os" "os"
) )
func Fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500)
LogError(err)
}
func LogError(err error) { func LogError(err error) {
log.Printf("error: %v", err) log.Printf("error: %v", err)
} }
......
...@@ -5,6 +5,8 @@ In this file we handle git lfs objects downloads and uploads ...@@ -5,6 +5,8 @@ In this file we handle git lfs objects downloads and uploads
package main package main
import ( import (
"./internal/api"
"./internal/helper"
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
...@@ -17,21 +19,21 @@ import ( ...@@ -17,21 +19,21 @@ import (
"path/filepath" "path/filepath"
) )
func lfsAuthorizeHandler(api *API, handleFunc serviceHandleFunc) httpHandleFunc { func lfsAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.HandlerFunc {
return api.preAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *apiResponse) { return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if a.StoreLFSPath == "" { if a.StoreLFSPath == "" {
fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty")) helper.Fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty"))
return return
} }
if a.LfsOid == "" { if a.LfsOid == "" {
fail500(w, errors.New("lfsAuthorizeHandler: LfsOid empty")) helper.Fail500(w, errors.New("lfsAuthorizeHandler: LfsOid empty"))
return return
} }
if err := os.MkdirAll(a.StoreLFSPath, 0700); err != nil { if err := os.MkdirAll(a.StoreLFSPath, 0700); err != nil {
fail500(w, fmt.Errorf("lfsAuthorizeHandler: mkdia StoreLFSPath: %v", err)) helper.Fail500(w, fmt.Errorf("lfsAuthorizeHandler: mkdia StoreLFSPath: %v", err))
return return
} }
...@@ -39,11 +41,11 @@ func lfsAuthorizeHandler(api *API, handleFunc serviceHandleFunc) httpHandleFunc ...@@ -39,11 +41,11 @@ func lfsAuthorizeHandler(api *API, handleFunc serviceHandleFunc) httpHandleFunc
}, "/authorize") }, "/authorize")
} }
func handleStoreLfsObject(h http.Handler) serviceHandleFunc { func handleStoreLfsObject(h http.Handler) api.HandleFunc {
return func(w http.ResponseWriter, r *http.Request, a *apiResponse) { return func(w http.ResponseWriter, r *http.Request, a *api.Response) {
file, err := ioutil.TempFile(a.StoreLFSPath, a.LfsOid) file, err := ioutil.TempFile(a.StoreLFSPath, a.LfsOid)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err)) helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
return return
} }
defer os.Remove(file.Name()) defer os.Remove(file.Name())
...@@ -54,19 +56,19 @@ func handleStoreLfsObject(h http.Handler) serviceHandleFunc { ...@@ -54,19 +56,19 @@ func handleStoreLfsObject(h http.Handler) serviceHandleFunc {
written, err := io.Copy(hw, r.Body) written, err := io.Copy(hw, r.Body)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: write tempfile: %v", err)) helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: write tempfile: %v", err))
return return
} }
file.Close() file.Close()
if written != a.LfsSize { if written != a.LfsSize {
fail500(w, fmt.Errorf("handleStoreLfsObject: expected size %d, wrote %d", a.LfsSize, written)) helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: expected size %d, wrote %d", a.LfsSize, written))
return return
} }
shaStr := hex.EncodeToString(hash.Sum(nil)) shaStr := hex.EncodeToString(hash.Sum(nil))
if shaStr != a.LfsOid { if shaStr != a.LfsOid {
fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", a.LfsOid, shaStr)) helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", a.LfsOid, shaStr))
return return
} }
......
...@@ -46,12 +46,6 @@ type httpRoute struct { ...@@ -46,12 +46,6 @@ type httpRoute struct {
handler http.Handler handler http.Handler
} }
type httpHandleFunc func(http.ResponseWriter, *http.Request)
func (h httpHandleFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h(w, r)
}
const projectPattern = `^/[^/]+/[^/]+/` const projectPattern = `^/[^/]+/[^/]+/`
const gitProjectPattern = `^/[^/]+/[^/]+\.git/` const gitProjectPattern = `^/[^/]+/[^/]+\.git/`
......
package main package main
import ( import (
"./internal/api"
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
...@@ -340,7 +341,7 @@ func runOrFail(t *testing.T, cmd *exec.Cmd) { ...@@ -340,7 +341,7 @@ func runOrFail(t *testing.T, cmd *exec.Cmd) {
} }
func gitOkBody(t *testing.T) interface{} { func gitOkBody(t *testing.T) interface{} {
return &apiResponse{ return &api.Response{
GL_ID: "user-123", GL_ID: "user-123",
RepoPath: repoPath(t), RepoPath: repoPath(t),
} }
...@@ -353,7 +354,7 @@ func archiveOkBody(t *testing.T, archiveName string) interface{} { ...@@ -353,7 +354,7 @@ func archiveOkBody(t *testing.T, archiveName string) interface{} {
} }
archivePath := path.Join(cwd, cacheDir, archiveName) archivePath := path.Join(cwd, cacheDir, archiveName)
return &apiResponse{ return &api.Response{
RepoPath: repoPath(t), RepoPath: repoPath(t),
ArchivePath: archivePath, ArchivePath: archivePath,
CommitId: "c7fbe50c7c7419d9701eebe64b1fdacc3df5b9dd", CommitId: "c7fbe50c7c7419d9701eebe64b1fdacc3df5b9dd",
......
...@@ -17,13 +17,13 @@ const ( ...@@ -17,13 +17,13 @@ const (
CacheExpireMax CacheExpireMax
) )
func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler httpHandleFunc) httpHandleFunc { func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
file := filepath.Join(*documentRoot, u.relativeURIPath(cleanURIPath(r.URL.Path))) file := filepath.Join(*documentRoot, u.relativeURIPath(cleanURIPath(r.URL.Path)))
// The filepath.Join does Clean traversing directories up // The filepath.Join does Clean traversing directories up
if !strings.HasPrefix(file, *documentRoot) { if !strings.HasPrefix(file, *documentRoot) {
fail500(w, &os.PathError{ helper.Fail500(w, &os.PathError{
Op: "open", Op: "open",
Path: file, Path: file,
Err: os.ErrInvalid, Err: os.ErrInvalid,
......
package main package main
import ( import (
"./internal/helper"
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
...@@ -85,11 +86,11 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te ...@@ -85,11 +86,11 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
return cleanup, nil return cleanup, nil
} }
func handleFileUploads(h http.Handler) httpHandleFunc { func handleFileUploads(h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tempPath := r.Header.Get(tempPathHeader) tempPath := r.Header.Get(tempPathHeader)
if tempPath == "" { if tempPath == "" {
fail500(w, errors.New("handleFileUploads: TempPath empty")) helper.Fail500(w, errors.New("handleFileUploads: TempPath empty"))
return return
} }
r.Header.Del(tempPathHeader) r.Header.Del(tempPathHeader)
...@@ -104,7 +105,7 @@ func handleFileUploads(h http.Handler) httpHandleFunc { ...@@ -104,7 +105,7 @@ func handleFileUploads(h http.Handler) httpHandleFunc {
if err == http.ErrNotMultipart { if err == http.ErrNotMultipart {
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
} else { } else {
fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err)) helper.Fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err))
} }
return return
} }
......
...@@ -7,6 +7,7 @@ In this file we handle request routing and interaction with the authBackend. ...@@ -7,6 +7,7 @@ In this file we handle request routing and interaction with the authBackend.
package main package main
import ( import (
"./internal/api"
"./internal/proxy" "./internal/proxy"
"fmt" "fmt"
"log" "log"
...@@ -17,48 +18,13 @@ import ( ...@@ -17,48 +18,13 @@ import (
"time" "time"
) )
type serviceHandleFunc func(http.ResponseWriter, *http.Request, *apiResponse)
type API struct {
*http.Client
*url.URL
}
type upstream struct { type upstream struct {
API *API API *api.API
Proxy *proxy.Proxy Proxy *proxy.Proxy
authBackend string authBackend string
relativeURLRoot string relativeURLRoot string
} }
type apiResponse struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull'
GL_ID string
// RepoPath is the full path on disk to the Git repository the request is
// about
RepoPath string
// ArchivePath is the full path where we should find/create a cached copy
// of a requested archive
ArchivePath string
// ArchivePrefix is used to put extracted archive contents in a
// subdirectory
ArchivePrefix string
// CommitId is used do prevent race conditions between the 'time of check'
// in the GitLab Rails app and the 'time of use' in gitlab-workhorse.
CommitId string
// StoreLFSPath is provided by the GitLab Rails application
// to mark where the tmp file should be placed
StoreLFSPath string
// LFS object id
LfsOid string
// LFS object size
LfsSize int64
// TmpPath is the path where we should store temporary files
// This is set by authorization middleware
TempPath string
}
func newUpstream(authBackend string, authSocket string) *upstream { func newUpstream(authBackend string, authSocket string) *upstream {
parsedURL, err := url.Parse(authBackend) parsedURL, err := url.Parse(authBackend)
if err != nil { if err != nil {
...@@ -88,8 +54,12 @@ func newUpstream(authBackend string, authSocket string) *upstream { ...@@ -88,8 +54,12 @@ func newUpstream(authBackend string, authSocket string) *upstream {
proxyTransport := proxy.NewRoundTripper(authTransport) proxyTransport := proxy.NewRoundTripper(authTransport)
up := &upstream{ up := &upstream{
authBackend: authBackend, authBackend: authBackend,
API: &API{Client: &http.Client{Transport: proxyTransport}, URL: parsedURL}, API: &api.API{
Client: &http.Client{Transport: proxyTransport},
URL: parsedURL,
Version: Version,
},
Proxy: proxy.NewProxy(parsedURL, proxyTransport, Version), Proxy: proxy.NewProxy(parsedURL, proxyTransport, Version),
relativeURLRoot: relativeURLRoot, relativeURLRoot: relativeURLRoot,
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment