From c42e60a3d261288287dab5cb8ebec237e574c1e0 Mon Sep 17 00:00:00 2001
From: Abiola Ibrahim <abiola89@gmail.com>
Date: Sun, 7 Jun 2015 20:31:34 +0100
Subject: [PATCH] Git: fix for data races.

---
 config/setup/git_test.go              |  5 ++--
 middleware/git/git.go                 | 22 ++++------------
 middleware/git/git_test.go            |  2 +-
 middleware/git/gittest/gittest.go     | 12 +++++++++
 middleware/git/logger.go              | 38 +++++++++++++++++++++++++++
 middleware/git/service.go             |  2 +-
 middleware/git/webhook/github_hook.go | 18 +++++--------
 7 files changed, 65 insertions(+), 34 deletions(-)
 create mode 100644 middleware/git/logger.go

diff --git a/config/setup/git_test.go b/config/setup/git_test.go
index 5c49c5eb..226f8305 100644
--- a/config/setup/git_test.go
+++ b/config/setup/git_test.go
@@ -2,7 +2,6 @@ package setup
 
 import (
 	"io/ioutil"
-	"log"
 	"strings"
 	"testing"
 	"time"
@@ -42,7 +41,7 @@ func TestIntervals(t *testing.T) {
 	}
 
 	for i, test := range tests {
-		git.Logger = nil
+		git.SetLogger(gittest.NewLogger(gittest.Open("file")))
 
 		c1 := newTestController(test)
 		repo, err := gitParse(c1)
@@ -61,7 +60,7 @@ func TestIntervals(t *testing.T) {
 
 		// switch logger to test file
 		logFile := gittest.Open("file")
-		git.Logger = log.New(logFile, "", 0)
+		git.SetLogger(gittest.NewLogger(logFile))
 
 		// sleep for the interval
 		gittest.Sleep(repo.Interval)
diff --git a/middleware/git/git.go b/middleware/git/git.go
index 74db1bb7..85335483 100644
--- a/middleware/git/git.go
+++ b/middleware/git/git.go
@@ -3,7 +3,6 @@ package git
 import (
 	"bytes"
 	"fmt"
-	"log"
 	"os"
 	"strings"
 	"sync"
@@ -30,21 +29,10 @@ var shell string
 // git requirements.
 var initMutex = sync.Mutex{}
 
-// Logger is used to log errors; if nil, the default log.Logger is used.
-var Logger *log.Logger
-
 // Services holds all git pulling services and provides the function to
 // stop them.
 var Services = &services{}
 
-// logger is an helper function to retrieve the available logger
-func logger() *log.Logger {
-	if Logger == nil {
-		Logger = log.New(os.Stderr, "", log.LstdFlags)
-	}
-	return Logger
-}
-
 // Repo is the structure that holds required information
 // of a git repository.
 type Repo struct {
@@ -84,7 +72,7 @@ func (r *Repo) Pull() error {
 		if err = r.pull(); err == nil {
 			break
 		}
-		logger().Println(err)
+		Logger().Println(err)
 	}
 
 	if err != nil {
@@ -94,7 +82,7 @@ func (r *Repo) Pull() error {
 	// check if there are new changes,
 	// then execute post pull command
 	if r.lastCommit == lastCommit {
-		logger().Println("No new changes.")
+		Logger().Println("No new changes.")
 		return nil
 	}
 	return r.postPullCommand()
@@ -121,7 +109,7 @@ func (r *Repo) pull() error {
 	if err = runCmd(gitBinary, params, dir); err == nil {
 		r.pulled = true
 		r.lastPull = time.Now()
-		logger().Printf("%v pulled.\n", r.URL)
+		Logger().Printf("%v pulled.\n", r.URL)
 		r.lastCommit, err = r.getMostRecentCommit()
 	}
 	return err
@@ -162,7 +150,7 @@ func (r *Repo) pullWithKey(params []string) error {
 	if err = runCmd(script.Name(), nil, dir); err == nil {
 		r.pulled = true
 		r.lastPull = time.Now()
-		logger().Printf("%v pulled.\n", r.URL)
+		Logger().Printf("%v pulled.\n", r.URL)
 		r.lastCommit, err = r.getMostRecentCommit()
 	}
 	return err
@@ -241,7 +229,7 @@ func (r *Repo) postPullCommand() error {
 	}
 
 	if err = runCmd(c, args, r.Path); err == nil {
-		logger().Printf("Command %v successful.\n", r.Then)
+		Logger().Printf("Command %v successful.\n", r.Then)
 	}
 	return err
 }
diff --git a/middleware/git/git_test.go b/middleware/git/git_test.go
index 695e3b5b..5ce183f1 100644
--- a/middleware/git/git_test.go
+++ b/middleware/git/git_test.go
@@ -73,7 +73,7 @@ func TestGit(t *testing.T) {
 
 	// pull with success
 	logFile := gittest.Open("file")
-	Logger = log.New(logFile, "", 0)
+	SetLogger(log.New(logFile, "", 0))
 	tests := []struct {
 		repo   *Repo
 		output string
diff --git a/middleware/git/gittest/gittest.go b/middleware/git/gittest/gittest.go
index a275b281..94f6d045 100644
--- a/middleware/git/gittest/gittest.go
+++ b/middleware/git/gittest/gittest.go
@@ -4,7 +4,9 @@ package gittest
 
 import (
 	"io"
+	"log"
 	"os"
+	"sync"
 	"time"
 
 	"github.com/mholt/caddy/middleware/git/gitos"
@@ -39,12 +41,18 @@ func Sleep(d time.Duration) {
 	FakeOS.Sleep(d)
 }
 
+// NewLogger creates a logger that logs to f
+func NewLogger(f gitos.File) *log.Logger {
+	return log.New(f, "", 0)
+}
+
 // fakeFile is a mock gitos.File.
 type fakeFile struct {
 	name    string
 	dir     bool
 	content []byte
 	info    fakeInfo
+	sync.Mutex
 }
 
 func (f fakeFile) Name() string {
@@ -65,6 +73,8 @@ func (f fakeFile) Chmod(mode os.FileMode) error {
 }
 
 func (f *fakeFile) Read(b []byte) (int, error) {
+	f.Lock()
+	defer f.Unlock()
 	if len(f.content) == 0 {
 		return 0, io.EOF
 	}
@@ -74,6 +84,8 @@ func (f *fakeFile) Read(b []byte) (int, error) {
 }
 
 func (f *fakeFile) Write(b []byte) (int, error) {
+	f.Lock()
+	defer f.Unlock()
 	f.content = append(f.content, b...)
 	return len(b), nil
 }
diff --git a/middleware/git/logger.go b/middleware/git/logger.go
new file mode 100644
index 00000000..2500239c
--- /dev/null
+++ b/middleware/git/logger.go
@@ -0,0 +1,38 @@
+package git
+
+import (
+	"log"
+	"os"
+	"sync"
+)
+
+// logger is used to log errors
+var logger = &gitLogger{l: log.New(os.Stderr, "", log.LstdFlags)}
+
+// gitLogger wraps log.Logger with mutex for thread safety.
+type gitLogger struct {
+	l *log.Logger
+	sync.RWMutex
+}
+
+func (g *gitLogger) logger() *log.Logger {
+	g.RLock()
+	defer g.RUnlock()
+	return g.l
+}
+
+func (g *gitLogger) setLogger(l *log.Logger) {
+	g.Lock()
+	g.l = l
+	g.Unlock()
+}
+
+// Logger gets the currently available logger
+func Logger() *log.Logger {
+	return logger.logger()
+}
+
+// SetLogger sets the current logger to l
+func SetLogger(l *log.Logger) {
+	logger.setLogger(l)
+}
diff --git a/middleware/git/service.go b/middleware/git/service.go
index 224f37a6..89b63c65 100644
--- a/middleware/git/service.go
+++ b/middleware/git/service.go
@@ -27,7 +27,7 @@ func Start(repo *Repo) {
 			case <-s.ticker.C():
 				err := repo.Pull()
 				if err != nil {
-					logger().Println(err)
+					Logger().Println(err)
 				}
 			case <-s.halt:
 				s.ticker.Stop()
diff --git a/middleware/git/webhook/github_hook.go b/middleware/git/webhook/github_hook.go
index e3ca70fd..8689a260 100644
--- a/middleware/git/webhook/github_hook.go
+++ b/middleware/git/webhook/github_hook.go
@@ -6,12 +6,12 @@ import (
 	"encoding/hex"
 	"encoding/json"
 	"errors"
-	"github.com/mholt/caddy/middleware/git"
 	"io/ioutil"
 	"log"
 	"net/http"
-	"os"
 	"strings"
+
+	"github.com/mholt/caddy/middleware/git"
 )
 
 type GithubHook struct{}
@@ -28,15 +28,9 @@ type ghPush struct {
 	Ref string `json:"ref"`
 }
 
-// Logger is used to log errors; if nil, the default log.Logger is used.
-var Logger *log.Logger
-
 // logger is an helper function to retrieve the available logger
 func logger() *log.Logger {
-	if Logger == nil {
-		Logger = log.New(os.Stderr, "", log.LstdFlags)
-	}
-	return Logger
+	return git.Logger()
 }
 
 func (g GithubHook) DoesHandle(h http.Header) bool {
@@ -97,7 +91,7 @@ func (g GithubHook) handleSignature(r *http.Request, body []byte, secret string)
 	signature := r.Header.Get("X-Hub-Signature")
 	if signature != "" {
 		if secret == "" {
-			logger().Print("Unable to verify request signature. Secret not set in caddyfile!")
+			logger().Print("Unable to verify request signature. Secret not set in caddyfile!\n")
 		} else {
 			mac := hmac.New(sha1.New, []byte(secret))
 			mac.Write(body)
@@ -129,7 +123,7 @@ func (g GithubHook) handlePush(body []byte, repo *git.Repo) error {
 
 	branch := refSlice[2]
 	if branch == repo.Branch {
-		logger().Print("Received pull notification for the tracking branch, updating...")
+		logger().Print("Received pull notification for the tracking branch, updating...\n")
 		repo.Pull()
 	}
 
@@ -148,7 +142,7 @@ func (g GithubHook) handleRelease(body []byte, repo *git.Repo) error {
 		return errors.New("The release request contained an invalid TagName.")
 	}
 
-	logger().Printf("Received new release '%s'. -> Updating local repository to this release.", release.Release.Name)
+	logger().Printf("Received new release '%s'. -> Updating local repository to this release.\n", release.Release.Name)
 
 	// Update the local branch to the release tag name
 	// this will pull the release tag.
-- 
2.30.9