Commit 9f3966b8 authored by Kamil Trzcinski's avatar Kamil Trzcinski

Make everything simpler, using non-pointer variables.

parent cff33472
...@@ -2,25 +2,19 @@ package main ...@@ -2,25 +2,19 @@ package main
import ( import (
"io/ioutil" "io/ioutil"
"log"
"net/http" "net/http"
"path/filepath"
) )
func handleDeployPage(deployPage string, handler serviceHandleFunc) serviceHandleFunc { func handleDeployPage(deployPage *string, handler serviceHandleFunc) serviceHandleFunc {
deployPage, err := filepath.Abs(deployPage)
if err != nil {
log.Fatalln(err)
}
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *gitRequest) {
data, err := ioutil.ReadFile(deployPage) data, err := ioutil.ReadFile(*deployPage)
if err != nil { if err != nil {
handler(w, r) handler(w, r)
return return
} }
w.Header().Set("Content-Type", "text/html") setNoCacheHeaders(w.Header())
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write(data) w.Write(data)
} }
......
...@@ -5,19 +5,14 @@ import ( ...@@ -5,19 +5,14 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"net/http" "net/http"
"path/filepath"
) )
type errorPageResponseWriter struct { type errorPageResponseWriter struct {
rw http.ResponseWriter rw http.ResponseWriter
status int status int
hijacked bool hijacked bool
} errorPages *string
func newErrorPageResponseWriter(rw http.ResponseWriter) *errorPageResponseWriter {
s := &errorPageResponseWriter{
rw: rw,
}
return s
} }
func (s *errorPageResponseWriter) Header() http.Header { func (s *errorPageResponseWriter) Header() http.Header {
...@@ -41,22 +36,20 @@ func (s *errorPageResponseWriter) WriteHeader(status int) { ...@@ -41,22 +36,20 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.status = status s.status = status
switch s.status { if 400 <= s.status && s.status <= 599 {
case 404, 422, 500, 502: errorPageFile := filepath.Join(*errorPages, fmt.Sprintf("%d.html", s.status))
data, err := ioutil.ReadFile(fmt.Sprintf("public/%d.html", s.status))
if err != nil {
break
}
log.Printf("ErroPage: serving predefined error page: %d", s.status) // check if custom error page exists, serve this page instead
s.hijacked = true if data, err := ioutil.ReadFile(errorPageFile); err == nil {
s.rw.Header().Set("Content-Type", "text/html") s.hijacked = true
s.rw.WriteHeader(s.status)
s.rw.Write(data)
return
default: log.Printf("ErrorPage: serving predefined error page: %d", s.status)
break setNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", "text/html; charset=utf-8")
s.rw.WriteHeader(s.status)
s.rw.Write(data)
return
}
} }
s.rw.WriteHeader(status) s.rw.WriteHeader(status)
...@@ -66,10 +59,13 @@ func (s *errorPageResponseWriter) Flush() { ...@@ -66,10 +59,13 @@ func (s *errorPageResponseWriter) Flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
func handleRailsError(handler serviceHandleFunc) serviceHandleFunc { func handleRailsError(errorPages *string, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *gitRequest) {
rw := newErrorPageResponseWriter(w) rw := errorPageResponseWriter{
rw: w,
errorPages: errorPages,
}
defer rw.Flush() defer rw.Flush()
handler(rw, r) handler(&rw, r)
} }
} }
...@@ -80,3 +80,9 @@ func setHttpPostForm(r *http.Request, values url.Values) { ...@@ -80,3 +80,9 @@ func setHttpPostForm(r *http.Request, values url.Values) {
r.ContentLength = int64(dataBuffer.Len()) r.ContentLength = int64(dataBuffer.Len())
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
} }
func setNoCacheHeaders(header http.Header) {
header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
header.Set("Pragma", "no-cache")
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
}
...@@ -13,8 +13,8 @@ type loggingResponseWriter struct { ...@@ -13,8 +13,8 @@ type loggingResponseWriter struct {
started time.Time started time.Time
} }
func newLoggingResponseWriter(rw http.ResponseWriter) *loggingResponseWriter { func newLoggingResponseWriter(rw http.ResponseWriter) loggingResponseWriter {
return &loggingResponseWriter{ return loggingResponseWriter{
rw: rw, rw: rw,
started: time.Now(), started: time.Now(),
} }
......
...@@ -37,6 +37,9 @@ var authBackend = flag.String("authBackend", "http://localhost:8080", "Authentic ...@@ -37,6 +37,9 @@ var authBackend = flag.String("authBackend", "http://localhost:8080", "Authentic
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var relativeUrlRoot = flag.String("relativeUrlRoot", "/", "GitLab relative URL root") var relativeUrlRoot = flag.String("relativeUrlRoot", "/", "GitLab relative URL root")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var deployPage = flag.String("deployPage", "public/index.html", "Path to file that will always be served if present")
var errorPages = flag.String("errorPages", "public/index.html", "The folder containing custom error pages, ie.: 500.html")
type httpRoute struct { type httpRoute struct {
method string method string
...@@ -45,6 +48,8 @@ type httpRoute struct { ...@@ -45,6 +48,8 @@ type httpRoute struct {
} }
// Routing table // Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
var httpRoutes = [...]httpRoute{ var httpRoutes = [...]httpRoute{
httpRoute{"GET", regexp.MustCompile(`/info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)}, httpRoute{"GET", regexp.MustCompile(`/info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)},
httpRoute{"POST", regexp.MustCompile(`/git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, httpRoute{"POST", regexp.MustCompile(`/git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
...@@ -66,10 +71,11 @@ var httpRoutes = [...]httpRoute{ ...@@ -66,10 +71,11 @@ var httpRoutes = [...]httpRoute{
httpRoute{"", regexp.MustCompile(`^/ci/api/`), proxyRequest}, httpRoute{"", regexp.MustCompile(`^/ci/api/`), proxyRequest},
// Serve static files and forward otherwise // Serve static files and forward otherwise
httpRoute{"", nil, handleServeFile("public", httpRoute{"", nil, handleServeFile(documentRoot,
handleDeployPage("public/index.html", handleDeployPage(deployPage,
handleRailsError(proxyRequest), handleRailsError(errorPages,
))}, proxyRequest,
)))},
} }
func main() { func main() {
......
...@@ -23,5 +23,5 @@ func proxyRequest(w http.ResponseWriter, r *gitRequest) { ...@@ -23,5 +23,5 @@ func proxyRequest(w http.ResponseWriter, r *gitRequest) {
req.Header.Set("Gitlab-Workhorse", Version) req.Header.Set("Gitlab-Workhorse", Version)
rw := newSendFileResponseWriter(w, &req) rw := newSendFileResponseWriter(w, &req)
defer rw.Flush() defer rw.Flush()
r.u.httpProxy.ServeHTTP(rw, &req) r.u.httpProxy.ServeHTTP(&rw, &req)
} }
...@@ -20,8 +20,8 @@ type sendFileResponseWriter struct { ...@@ -20,8 +20,8 @@ type sendFileResponseWriter struct {
req *http.Request req *http.Request
} }
func newSendFileResponseWriter(rw http.ResponseWriter, req *http.Request) *sendFileResponseWriter { func newSendFileResponseWriter(rw http.ResponseWriter, req *http.Request) sendFileResponseWriter {
s := &sendFileResponseWriter{ s := sendFileResponseWriter{
rw: rw, rw: rw,
req: req, req: req,
} }
......
...@@ -9,21 +9,12 @@ import ( ...@@ -9,21 +9,12 @@ import (
"strings" "strings"
) )
func handleServeFile(rootDir string, notFoundHandler serviceHandleFunc) serviceHandleFunc { func handleServeFile(documentRoot *string, notFoundHandler serviceHandleFunc) serviceHandleFunc {
rootDir, err := filepath.Abs(rootDir)
if err != nil {
log.Fatalln(err)
}
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *gitRequest) {
file := filepath.Join(rootDir, r.relativeUriPath) file := filepath.Join(*documentRoot, r.relativeUriPath)
file, err := filepath.Abs(file)
if err != nil {
fail500(w, fmt.Errorf("invalid path:"+file, err))
return
}
if !strings.HasPrefix(file, rootDir) { // The filepath.Join does Clean traversing directories up
if !strings.HasPrefix(file, *documentRoot) {
fail500(w, fmt.Errorf("invalid path: "+file, os.ErrInvalid)) fail500(w, fmt.Errorf("invalid path: "+file, os.ErrInvalid))
return return
} }
......
...@@ -54,9 +54,11 @@ type authorizationResponse struct { ...@@ -54,9 +54,11 @@ type authorizationResponse struct {
// GitLab Rails application. // GitLab Rails application.
type gitRequest struct { type gitRequest struct {
*http.Request *http.Request
relativeUriPath string
authorizationResponse authorizationResponse
u *upstream u *upstream
// This field contains the URL.Path stripped from RelativeUrlRoot
relativeUriPath string
} }
func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream { func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream {
...@@ -81,7 +83,9 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -81,7 +83,9 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
defer w.Log(r) defer w.Log(r)
// Strip prefix and add "/" // Strip prefix and add "/"
relativeUriPath := "/" + strings.TrimPrefix(r.RequestURI, *relativeUrlRoot) // To match against non-relative URL
// Making it simpler for our matcher
relativeUriPath := "/" + strings.TrimPrefix(r.URL.Path, *relativeUrlRoot)
// Look for a matching Git service // Look for a matching Git service
foundService := false foundService := false
...@@ -98,7 +102,7 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -98,7 +102,7 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
if !foundService { if !foundService {
// The protocol spec in git/Documentation/technical/http-protocol.txt // The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found. // says we must return 403 if no matching service is found.
http.Error(w, "Forbidden", 403) http.Error(&w, "Forbidden", 403)
return return
} }
...@@ -108,5 +112,5 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -108,5 +112,5 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
u: u, u: u,
} }
g.handleFunc(w, &request) g.handleFunc(&w, &request)
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment