Commit 0e777472 authored by Kim "BKC" Carlbäcker's avatar Kim "BKC" Carlbäcker Committed by Nick Thomas

Bug-fix for redis connection

parent a806da23
...@@ -105,12 +105,16 @@ SentinelMaster = "mymaster" ...@@ -105,12 +105,16 @@ SentinelMaster = "mymaster"
Optional fields are as follows: Optional fields are as follows:
``` ```
[redis] [redis]
ReadTimeout = 1000 DB = 0
ReadTimeout = "1s"
KeepAlivePeriod = "5m"
MaxIdle = 1 MaxIdle = 1
MaxActive = 1 MaxActive = 1
``` ```
- `ReadTimeout` is how many milliseconds that a redis read-command can take. Defaults to `1000` - `DB` is the Database to connect to. Defaults to `0`
- `ReadTimeout` is how long a redis read-command can take. Defaults to `1s`
- `KeepAlivePeriod` is how long the redis connection is to be kept alive without anything flowing through it. Defaults to `5m`
- `MaxIdle` is how many idle connections can be in the redis-pool at once. Defaults to 1 - `MaxIdle` is how many idle connections can be in the redis-pool at once. Defaults to 1
- `MaxActive` is how many connections the pool can keep. Defaults to 1 - `MaxActive` is how many connections the pool can keep. Defaults to 1
......
...@@ -17,14 +17,27 @@ func (u *TomlURL) UnmarshalText(text []byte) error { ...@@ -17,14 +17,27 @@ func (u *TomlURL) UnmarshalText(text []byte) error {
return err return err
} }
type TomlDuration struct {
time.Duration
}
func (d *TomlDuration) UnmarshalTest(text []byte) error {
temp, err := time.ParseDuration(string(text))
d.Duration = temp
return err
}
type RedisConfig struct { type RedisConfig struct {
URL TomlURL URL TomlURL
Sentinel []TomlURL Sentinel []TomlURL
SentinelMaster string SentinelMaster string
Password string Password string
ReadTimeout *int DB *int
MaxIdle *int ReadTimeout *TomlDuration
MaxActive *int WriteTimeout *TomlDuration
KeepAlivePeriod *TomlDuration
MaxIdle *int
MaxActive *int
} }
type Config struct { type Config struct {
......
package redis package redis
import ( import (
"errors"
"fmt" "fmt"
"log" "log"
"strings" "strings"
...@@ -34,7 +33,7 @@ var ( ...@@ -34,7 +33,7 @@ var (
totalMessages = prometheus.NewCounter( totalMessages = prometheus.NewCounter(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "gitlab_workhorse_keywather_total_messages", Name: "gitlab_workhorse_keywather_total_messages",
Help: "How many messages gitlab-workhorse has recieved in total on pubsub.", Help: "How many messages gitlab-workhorse has received in total on pubsub.",
}, },
) )
) )
...@@ -58,13 +57,11 @@ type KeyChan struct { ...@@ -58,13 +57,11 @@ type KeyChan struct {
Chan chan string Chan chan string
} }
func processInner(conn redis.Conn) { func processInner(conn redis.Conn) error {
redisReconnectTimeout.Reset()
defer conn.Close() defer conn.Close()
psc := redis.PubSubConn{Conn: conn} psc := redis.PubSubConn{Conn: conn}
if err := psc.Subscribe(keySubChannel); err != nil { if err := psc.Subscribe(keySubChannel); err != nil {
return return err
} }
defer psc.Unsubscribe(keySubChannel) defer psc.Unsubscribe(keySubChannel)
...@@ -72,20 +69,38 @@ func processInner(conn redis.Conn) { ...@@ -72,20 +69,38 @@ func processInner(conn redis.Conn) {
switch v := psc.Receive().(type) { switch v := psc.Receive().(type) {
case redis.Message: case redis.Message:
totalMessages.Inc() totalMessages.Inc()
msg := strings.SplitN(string(v.Data), "=", 2) dataStr := string(v.Data)
msg := strings.SplitN(dataStr, "=", 2)
if len(msg) != 2 { if len(msg) != 2 {
helper.LogError(nil, errors.New("Redis subscribe error: got an invalid notification")) helper.LogError(nil, fmt.Errorf("Redis receive error: got an invalid notification: %q", dataStr))
continue continue
} }
key, value := msg[0], msg[1] key, value := msg[0], msg[1]
notifyChanWatchers(key, value) notifyChanWatchers(key, value)
case error: case error:
helper.LogError(nil, fmt.Errorf("Redis subscribe error: %s", v)) helper.LogError(nil, fmt.Errorf("Redis receive error: %s", v))
return // Intermittent error, return nil so that it doesn't wait before reconnect
return nil
} }
} }
} }
func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) {
conn, err := dialer()
if err != nil {
return nil, err
}
// Make sure Redis is actually connected
conn.Do("PING")
if err := conn.Err(); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
// Process redis subscriptions // Process redis subscriptions
// //
// NOTE: There Can Only Be One! // NOTE: There Can Only Be One!
...@@ -97,13 +112,19 @@ func Process(reconnect bool) { ...@@ -97,13 +112,19 @@ func Process(reconnect bool) {
for loop { for loop {
loop = reconnect loop = reconnect
log.Println("Connecting to redis") log.Println("Connecting to redis")
conn, err := redisDialFunc()
conn, err := dialPubSub(workerDialFunc)
if err != nil { if err != nil {
helper.LogError(nil, fmt.Errorf("Failed to connect to redis: %s", err)) helper.LogError(nil, fmt.Errorf("Failed to connect to redis: %s", err))
time.Sleep(redisReconnectTimeout.Duration()) time.Sleep(redisReconnectTimeout.Duration())
continue continue
} }
processInner(conn) redisReconnectTimeout.Reset()
if err = processInner(conn); err != nil {
helper.LogError(nil, fmt.Errorf("Failed to process redis-queue: %s", err))
continue
}
} }
} }
......
...@@ -103,7 +103,6 @@ func TestWatchKeyNoChange(t *testing.T) { ...@@ -103,7 +103,6 @@ func TestWatchKeyNoChange(t *testing.T) {
processMessages(1, "something") processMessages(1, "something")
wg.Wait() wg.Wait()
} }
func TestWatchKeyTimeout(t *testing.T) { func TestWatchKeyTimeout(t *testing.T) {
......
...@@ -3,6 +3,8 @@ package redis ...@@ -3,6 +3,8 @@ package redis
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"net/url"
"time" "time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config" "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
...@@ -18,10 +20,26 @@ var ( ...@@ -18,10 +20,26 @@ var (
) )
const ( const (
defaultMaxIdle = 1 // Max Idle Connections in the pool.
defaultMaxActive = 1 defaultMaxIdle = 1
// Max Active Connections in the pool.
defaultMaxActive = 1
// Timeout for Read operations on the pool. 1 second is technically overkill,
// it's just for sanity.
defaultReadTimeout = 1 * time.Second defaultReadTimeout = 1 * time.Second
// Timeout for Write operations on the pool. 1 second is technically overkill,
// it's just for sanity.
defaultWriteTimeout = 1 * time.Second
// Timeout before killing Idle connections in the pool. 3 minutes seemed good.
// If you _actually_ hit this timeout often, you should consider turning of
// redis-support since it's not necessary at that point...
defaultIdleTimeout = 3 * time.Minute defaultIdleTimeout = 3 * time.Minute
// KeepAlivePeriod is to keep a TCP connection open for an extended period of
// time without being killed. This is used both in the pool, and in the
// worker-connection.
// See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more
// information.
defaultKeepAlivePeriod = 5 * time.Minute
) )
var ( var (
...@@ -65,37 +83,91 @@ func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { ...@@ -65,37 +83,91 @@ func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel {
} }
} }
var redisDialFunc func() (redis.Conn, error) var poolDialFunc func() (redis.Conn, error)
var workerDialFunc func() (redis.Conn, error)
func dialOptionsBuilder(cfg *config.RedisConfig) []redis.DialOption { func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption {
readTimeout := defaultReadTimeout readTimeout := defaultReadTimeout
if cfg.ReadTimeout != nil { writeTimeout := defaultWriteTimeout
readTimeout = time.Millisecond * time.Duration(*cfg.ReadTimeout)
if cfg != nil {
if cfg.ReadTimeout != nil {
readTimeout = cfg.ReadTimeout.Duration
}
if cfg.WriteTimeout != nil {
writeTimeout = cfg.WriteTimeout.Duration
}
}
return []redis.DialOption{
redis.DialReadTimeout(readTimeout),
redis.DialWriteTimeout(writeTimeout),
}
}
func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption {
var dopts []redis.DialOption
if setTimeouts {
dopts = timeoutDialOptions(cfg)
}
if cfg == nil {
return dopts
} }
dopts := []redis.DialOption{redis.DialReadTimeout(readTimeout)}
if cfg.Password != "" { if cfg.Password != "" {
dopts = append(dopts, redis.DialPassword(cfg.Password)) dopts = append(dopts, redis.DialPassword(cfg.Password))
} }
if cfg.DB != nil {
dopts = append(dopts, redis.DialDatabase(*cfg.DB))
}
return dopts return dopts
} }
// DefaultDialFunc should always used. Only exception is for unit-tests. func keepAliveDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
func DefaultDialFunc(cfg *config.RedisConfig) func() (redis.Conn, error) { return func(network, address string) (net.Conn, error) {
dopts := dialOptionsBuilder(cfg) addr, err := net.ResolveTCPAddr(network, address)
innerDial := func() (redis.Conn, error) { if err != nil {
return redis.Dial(cfg.URL.Scheme, cfg.URL.Host, dopts...) return nil, err
}
tc, err := net.DialTCP(network, nil, addr)
if err != nil {
return nil, err
}
if err := tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err := tc.SetKeepAlivePeriod(timeout); err != nil {
return nil, err
}
return tc, nil
} }
if sntnl != nil { }
innerDial = func() (redis.Conn, error) {
address, err := sntnl.MasterAddr() type redisDialerFunc func() (redis.Conn, error)
if err != nil {
return nil, err func sentinelDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration) redisDialerFunc {
} return func() (redis.Conn, error) {
return redis.Dial("tcp", address, dopts...) address, err := sntnl.MasterAddr()
if err != nil {
return nil, err
}
dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
return redis.Dial("tcp", address, dopts...)
}
}
func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url url.URL) redisDialerFunc {
return func() (redis.Conn, error) {
if url.Scheme == "unix" {
return redis.Dial(url.Scheme, url.Path, dopts...)
} }
dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
return redis.Dial(url.Scheme, url.Host, dopts...)
} }
}
func countDialer(dialer redisDialerFunc) redisDialerFunc {
return func() (redis.Conn, error) { return func() (redis.Conn, error) {
c, err := innerDial() c, err := dialer()
if err == nil { if err == nil {
totalConnections.Inc() totalConnections.Inc()
} }
...@@ -103,8 +175,21 @@ func DefaultDialFunc(cfg *config.RedisConfig) func() (redis.Conn, error) { ...@@ -103,8 +175,21 @@ func DefaultDialFunc(cfg *config.RedisConfig) func() (redis.Conn, error) {
} }
} }
// DefaultDialFunc should always used. Only exception is for unit-tests.
func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) {
keepAlivePeriod := defaultKeepAlivePeriod
if cfg.KeepAlivePeriod != nil {
keepAlivePeriod = cfg.KeepAlivePeriod.Duration
}
dopts := dialOptionsBuilder(cfg, setReadTimeout)
if sntnl != nil {
return countDialer(sentinelDialer(dopts, keepAlivePeriod))
}
return countDialer(defaultDialer(dopts, keepAlivePeriod, cfg.URL.URL))
}
// Configure redis-connection // Configure redis-connection
func Configure(cfg *config.RedisConfig, dialFunc func() (redis.Conn, error)) { func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) {
if cfg == nil { if cfg == nil {
return return
} }
...@@ -117,12 +202,13 @@ func Configure(cfg *config.RedisConfig, dialFunc func() (redis.Conn, error)) { ...@@ -117,12 +202,13 @@ func Configure(cfg *config.RedisConfig, dialFunc func() (redis.Conn, error)) {
maxActive = *cfg.MaxActive maxActive = *cfg.MaxActive
} }
sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel) sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel)
redisDialFunc = dialFunc workerDialFunc = dialFunc(cfg, false)
poolDialFunc = dialFunc(cfg, true)
pool = &redis.Pool{ pool = &redis.Pool{
MaxIdle: maxIdle, // Keep at most X hot connections MaxIdle: maxIdle, // Keep at most X hot connections
MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited
IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed
Dial: redisDialFunc, Dial: poolDialFunc,
Wait: true, Wait: true,
} }
if sntnl != nil { if sntnl != nil {
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"time" "time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config" "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"github.com/garyburd/redigo/redis" "github.com/garyburd/redigo/redis"
"github.com/rafaeljusto/redigomock" "github.com/rafaeljusto/redigomock"
...@@ -17,8 +18,10 @@ import ( ...@@ -17,8 +18,10 @@ import (
func setupMockPool() (*redigomock.Conn, func()) { func setupMockPool() (*redigomock.Conn, func()) {
conn := redigomock.NewConn() conn := redigomock.NewConn()
cfg := &config.RedisConfig{URL: config.TomlURL{}} cfg := &config.RedisConfig{URL: config.TomlURL{}}
Configure(cfg, func() (redis.Conn, error) { Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) {
return conn, nil return func() (redis.Conn, error) {
return conn, nil
}
}) })
return conn, func() { return conn, func() {
pool = nil pool = nil
...@@ -33,7 +36,7 @@ func TestConfigureNoConfig(t *testing.T) { ...@@ -33,7 +36,7 @@ func TestConfigureNoConfig(t *testing.T) {
func TestConfigureMinimalConfig(t *testing.T) { func TestConfigureMinimalConfig(t *testing.T) {
cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""} cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""}
Configure(cfg, DefaultDialFunc(cfg)) Configure(cfg, DefaultDialFunc)
if assert.NotNil(t, pool, "Pool should not be nil") { if assert.NotNil(t, pool, "Pool should not be nil") {
assert.Equal(t, 1, pool.MaxIdle) assert.Equal(t, 1, pool.MaxIdle)
assert.Equal(t, 1, pool.MaxActive) assert.Equal(t, 1, pool.MaxActive)
...@@ -43,7 +46,8 @@ func TestConfigureMinimalConfig(t *testing.T) { ...@@ -43,7 +46,8 @@ func TestConfigureMinimalConfig(t *testing.T) {
} }
func TestConfigureFullConfig(t *testing.T) { func TestConfigureFullConfig(t *testing.T) {
i, a, r := 4, 10, 3 i, a := 4, 10
r := config.TomlDuration{Duration: 3}
cfg := &config.RedisConfig{ cfg := &config.RedisConfig{
URL: config.TomlURL{}, URL: config.TomlURL{},
Password: "", Password: "",
...@@ -51,7 +55,7 @@ func TestConfigureFullConfig(t *testing.T) { ...@@ -51,7 +55,7 @@ func TestConfigureFullConfig(t *testing.T) {
MaxActive: &a, MaxActive: &a,
ReadTimeout: &r, ReadTimeout: &r,
} }
Configure(cfg, DefaultDialFunc(cfg)) Configure(cfg, DefaultDialFunc)
if assert.NotNil(t, pool, "Pool should not be nil") { if assert.NotNil(t, pool, "Pool should not be nil") {
assert.Equal(t, i, pool.MaxIdle) assert.Equal(t, i, pool.MaxIdle)
assert.Equal(t, a, pool.MaxActive) assert.Equal(t, a, pool.MaxActive)
...@@ -88,3 +92,51 @@ func TestGetStringFail(t *testing.T) { ...@@ -88,3 +92,51 @@ func TestGetStringFail(t *testing.T) {
_, err := GetString("foobar") _, err := GetString("foobar")
assert.Error(t, err, "Expected error when not connected to redis") assert.Error(t, err, "Expected error when not connected to redis")
} }
func TestSentinelConnNoSentinel(t *testing.T) {
s := sentinelConn("", []config.TomlURL{})
assert.Nil(t, s, "Sentinel without urls should return nil")
}
func TestSentinelConnTwoURLs(t *testing.T) {
urls := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"}
var sentinelUrls []config.TomlURL
for _, url := range urls {
parsedURL := helper.URLMustParse(url)
sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
}
s := sentinelConn("foobar", sentinelUrls)
assert.Equal(t, len(urls), len(s.Addrs))
for i := range urls {
assert.Equal(t, urls[i], s.Addrs[i])
}
}
func TestDialOptionsBuildersPassword(t *testing.T) {
dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false)
assert.Equal(t, 1, len(dopts))
}
func TestDialOptionsBuildersSetTimeouts(t *testing.T) {
dopts := dialOptionsBuilder(nil, true)
assert.Equal(t, 2, len(dopts))
}
func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) {
cfg := &config.RedisConfig{
ReadTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
WriteTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
}
dopts := dialOptionsBuilder(cfg, true)
assert.Equal(t, 2, len(dopts))
}
func TestDialOptionsBuildersSelectDB(t *testing.T) {
db := 3
dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false)
assert.Equal(t, 1, len(dopts))
}
...@@ -133,7 +133,7 @@ func main() { ...@@ -133,7 +133,7 @@ func main() {
cfg.Redis = cfgFromFile.Redis cfg.Redis = cfgFromFile.Redis
redis.Configure(cfg.Redis, redis.DefaultDialFunc(cfg.Redis)) redis.Configure(cfg.Redis, redis.DefaultDialFunc)
go redis.Process(true) go redis.Process(true)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment