Commit 70b9e1f1 authored by Nick Thomas's avatar Nick Thomas

Merge branch 'bjk/vendor_bump' into 'master'

General vendoring cleanup

See merge request gitlab-org/gitlab-workhorse!310
parents 24a49325 6473b6e4
...@@ -6,7 +6,7 @@ import ( ...@@ -6,7 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/garyburd/redigo/redis" "github.com/gomodule/redigo/redis"
"github.com/jpillora/backoff" "github.com/jpillora/backoff"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
......
...@@ -7,8 +7,8 @@ import ( ...@@ -7,8 +7,8 @@ import (
"net/url" "net/url"
"time" "time"
sentinel "github.com/FZambia/go-sentinel" "github.com/FZambia/sentinel"
"github.com/garyburd/redigo/redis" "github.com/gomodule/redigo/redis"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
......
...@@ -4,7 +4,7 @@ import ( ...@@ -4,7 +4,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/garyburd/redigo/redis" "github.com/gomodule/redigo/redis"
"github.com/rafaeljusto/redigomock" "github.com/rafaeljusto/redigomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
......
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
go-sentinel go-sentinel
=========== ===========
Redis Sentinel support for [redigo](https://github.com/garyburd/redigo) library. Redis Sentinel support for [redigo](https://github.com/gomodule/redigo) library.
**API is unstable and can change at any moment** – use with tools like Glide, Godep etc. **API is unstable and can change at any moment** – use with tools like Glide, Godep etc.
Documentation Documentation
------------- -------------
- [API Reference](http://godoc.org/github.com/FZambia/go-sentinel) - [API Reference](http://godoc.org/github.com/FZambia/sentinel)
License License
------- -------
......
...@@ -3,11 +3,12 @@ package sentinel ...@@ -3,11 +3,12 @@ package sentinel
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/garyburd/redigo/redis" "github.com/gomodule/redigo/redis"
) )
// Sentinel provides a way to add high availability (HA) to Redis Pool using // Sentinel provides a way to add high availability (HA) to Redis Pool using
...@@ -90,9 +91,8 @@ type NoSentinelsAvailable struct { ...@@ -90,9 +91,8 @@ type NoSentinelsAvailable struct {
func (ns NoSentinelsAvailable) Error() string { func (ns NoSentinelsAvailable) Error() string {
if ns.lastError != nil { if ns.lastError != nil {
return fmt.Sprintf("redigo: no sentinels available; last error: %s", ns.lastError.Error()) return fmt.Sprintf("redigo: no sentinels available; last error: %s", ns.lastError.Error())
} else {
return fmt.Sprintf("redigo: no sentinels available")
} }
return fmt.Sprintf("redigo: no sentinels available")
} }
// putToTop puts Sentinel address to the top of address list - this means // putToTop puts Sentinel address to the top of address list - this means
...@@ -249,10 +249,10 @@ func (s *Sentinel) MasterAddr() (string, error) { ...@@ -249,10 +249,10 @@ func (s *Sentinel) MasterAddr() (string, error) {
return res.(string), nil return res.(string), nil
} }
// SlaveAddrs returns a slice with known slaves of current master instance. // SlaveAddrs returns a slice with known slave addresses of current master instance.
func (s *Sentinel) SlaveAddrs() ([]string, error) { func (s *Sentinel) SlaveAddrs() ([]string, error) {
res, err := s.doUntilSuccess(func(c redis.Conn) (interface{}, error) { res, err := s.doUntilSuccess(func(c redis.Conn) (interface{}, error) {
return queryForSlaves(c, s.MasterName) return queryForSlaveAddrs(c, s.MasterName)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -260,6 +260,34 @@ func (s *Sentinel) SlaveAddrs() ([]string, error) { ...@@ -260,6 +260,34 @@ func (s *Sentinel) SlaveAddrs() ([]string, error) {
return res.([]string), nil return res.([]string), nil
} }
// Slave represents a Redis slave instance which is known by Sentinel.
type Slave struct {
ip string
port string
flags string
}
// Addr returns an address of slave.
func (s *Slave) Addr() string {
return net.JoinHostPort(s.ip, s.port)
}
// Available returns if slave is in working state at moment based on information in slave flags.
func (s *Slave) Available() bool {
return !strings.Contains(s.flags, "disconnected") && !strings.Contains(s.flags, "s_down")
}
// Slaves returns a slice with known slaves of master instance.
func (s *Sentinel) Slaves() ([]*Slave, error) {
res, err := s.doUntilSuccess(func(c redis.Conn) (interface{}, error) {
return queryForSlaves(c, s.MasterName)
})
if err != nil {
return nil, err
}
return res.([]*Slave), nil
}
// SentinelAddrs returns a slice of known Sentinel addresses Sentinel server aware of. // SentinelAddrs returns a slice of known Sentinel addresses Sentinel server aware of.
func (s *Sentinel) SentinelAddrs() ([]string, error) { func (s *Sentinel) SentinelAddrs() ([]string, error) {
res, err := s.doUntilSuccess(func(c redis.Conn) (interface{}, error) { res, err := s.doUntilSuccess(func(c redis.Conn) (interface{}, error) {
...@@ -332,22 +360,42 @@ func queryForMaster(conn redis.Conn, masterName string) (string, error) { ...@@ -332,22 +360,42 @@ func queryForMaster(conn redis.Conn, masterName string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
masterAddr := strings.Join(res, ":") if len(res) < 2 {
return "", errors.New("redigo: malformed get-master-addr-by-name reply")
}
masterAddr := net.JoinHostPort(res[0], res[1])
return masterAddr, nil return masterAddr, nil
} }
func queryForSlaves(conn redis.Conn, masterName string) ([]string, error) { func queryForSlaveAddrs(conn redis.Conn, masterName string) ([]string, error) {
slaves, err := queryForSlaves(conn, masterName)
if err != nil {
return nil, err
}
slaveAddrs := make([]string, 0)
for _, slave := range slaves {
slaveAddrs = append(slaveAddrs, slave.Addr())
}
return slaveAddrs, nil
}
func queryForSlaves(conn redis.Conn, masterName string) ([]*Slave, error) {
res, err := redis.Values(conn.Do("SENTINEL", "slaves", masterName)) res, err := redis.Values(conn.Do("SENTINEL", "slaves", masterName))
if err != nil { if err != nil {
return nil, err return nil, err
} }
slaves := make([]string, 0) slaves := make([]*Slave, 0)
for _, a := range res { for _, a := range res {
sm, err := redis.StringMap(a, err) sm, err := redis.StringMap(a, err)
if err != nil { if err != nil {
return slaves, err return slaves, err
} }
slaves = append(slaves, fmt.Sprintf("%s:%s", sm["ip"], sm["port"])) slave := &Slave{
ip: sm["ip"],
port: sm["port"],
flags: sm["flags"],
}
slaves = append(slaves, slave)
} }
return slaves, nil return slaves, nil
} }
......
...@@ -77,15 +77,20 @@ func NewHighBiased(epsilon float64) *Stream { ...@@ -77,15 +77,20 @@ func NewHighBiased(epsilon float64) *Stream {
// is guaranteed to be within (Quantile±Epsilon). // is guaranteed to be within (Quantile±Epsilon).
// //
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error properties. // See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error properties.
func NewTargeted(targets map[float64]float64) *Stream { func NewTargeted(targetMap map[float64]float64) *Stream {
// Convert map to slice to avoid slow iterations on a map.
// ƒ is called on the hot path, so converting the map to a slice
// beforehand results in significant CPU savings.
targets := targetMapToSlice(targetMap)
ƒ := func(s *stream, r float64) float64 { ƒ := func(s *stream, r float64) float64 {
var m = math.MaxFloat64 var m = math.MaxFloat64
var f float64 var f float64
for quantile, epsilon := range targets { for _, t := range targets {
if quantile*s.n <= r { if t.quantile*s.n <= r {
f = (2 * epsilon * r) / quantile f = (2 * t.epsilon * r) / t.quantile
} else { } else {
f = (2 * epsilon * (s.n - r)) / (1 - quantile) f = (2 * t.epsilon * (s.n - r)) / (1 - t.quantile)
} }
if f < m { if f < m {
m = f m = f
...@@ -96,6 +101,25 @@ func NewTargeted(targets map[float64]float64) *Stream { ...@@ -96,6 +101,25 @@ func NewTargeted(targets map[float64]float64) *Stream {
return newStream(ƒ) return newStream(ƒ)
} }
type target struct {
quantile float64
epsilon float64
}
func targetMapToSlice(targetMap map[float64]float64) []target {
targets := make([]target, 0, len(targetMap))
for quantile, epsilon := range targetMap {
t := target{
quantile: quantile,
epsilon: epsilon,
}
targets = append(targets, t)
}
return targets
}
// Stream computes quantiles for a stream of float64s. It is not thread-safe by // Stream computes quantiles for a stream of float64s. It is not thread-safe by
// design. Take care when using across multiple goroutines. // design. Take care when using across multiple goroutines.
type Stream struct { type Stream struct {
......
This source diff could not be displayed because it is too large. You can view the blob instead.
all: build test lint
build: build: ## build and lint
go build ./... go build ./...
gometalinter \
--vendor \
--vendored-linters \
--deadline=60s \
--disable-all \
--enable=goimports \
--enable=vetshadow \
--enable=varcheck \
--enable=structcheck \
--enable=deadcode \
--enable=ineffassign \
--enable=unconvert \
--enable=goconst \
--enable=golint \
--enable=gosimple \
--enable=gofmt \
--enable=misspell \
--enable=staticcheck \
.
test: ## just test
go test -cover .
test: clean: ## cleanup
go test ./...
lint:
golint ./...
gofmt -w -s . ./example*
goimports -w . ./example*
clean:
rm -f *~ ./example*/*~
rm -f ./example1/example1 rm -f ./example1/example1
rm -f ./example2/example2 rm -f ./example2/example2
go clean ./... go clean ./...
git gc git gc
ci: build test lint # https://www.client9.com/self-documenting-makefiles/
help:
docker-ci: @awk -F ':|##' '/^[^\t].+?:.*?##/ {\
docker run --rm \ printf "\033[36m%-30s\033[0m %s\n", $$1, $$NF \
-e COVERALLS_REPO_TOKEN=$(COVERALLS_REPO_TOKEN) \ }' $(MAKEFILE_LIST)
-v $(PWD):/go/src/github.com/client9/reopen \ .DEFAULT_GOAL=help
-w /go/src/github.com/client9/reopen \ .PHONY=help
nickg/golang-dev-docker \
make ci
.PHONY: ci docker-ci
[![Build Status](https://travis-ci.org/client9/reopen.svg)](https://travis-ci.org/client9/reopen) [![Go Report Card](http://goreportcard.com/badge/client9/reopen)](http://goreportcard.com/report/client9/reopen) [![GoDoc](https://godoc.org/github.com/client9/reopen?status.svg)](https://godoc.org/github.com/client9/reopen) [![Coverage](http://gocover.io/_badge/github.com/client9/reopen)](http://gocover.io/github.com/client9/reopen) [![license](https://img.shields.io/badge/license-MIT-blue.svg?style=flat)](https://raw.githubusercontent.com/client9/reopen/master/LICENSE) [![Build Status](https://travis-ci.org/client9/reopen.svg)](https://travis-ci.org/client9/reopen) [![Go Report Card](http://goreportcard.com/badge/client9/reopen)](http://goreportcard.com/report/client9/reopen) [![GoDoc](https://godoc.org/github.com/client9/reopen?status.svg)](https://godoc.org/github.com/client9/reopen) [![license](https://img.shields.io/badge/license-MIT-blue.svg?style=flat)](https://raw.githubusercontent.com/client9/reopen/master/LICENSE)
Makes a standard os.File a "reopenable writer" and allows SIGHUP signals Makes a standard os.File a "reopenable writer" and allows SIGHUP signals
to reopen log files, as needed by to reopen log files, as needed by
......
...@@ -98,19 +98,21 @@ func NewFileWriterMode(name string, mode os.FileMode) (*FileWriter, error) { ...@@ -98,19 +98,21 @@ func NewFileWriterMode(name string, mode os.FileMode) (*FileWriter, error) {
// BufferedFileWriter is buffer writer than can be reopned // BufferedFileWriter is buffer writer than can be reopned
type BufferedFileWriter struct { type BufferedFileWriter struct {
mu sync.Mutex mu sync.Mutex
OrigWriter *FileWriter quitChan chan bool
BufWriter *bufio.Writer done bool
origWriter *FileWriter
bufWriter *bufio.Writer
} }
// Reopen implement Reopener // Reopen implement Reopener
func (bw *BufferedFileWriter) Reopen() error { func (bw *BufferedFileWriter) Reopen() error {
bw.mu.Lock() bw.mu.Lock()
bw.BufWriter.Flush() bw.bufWriter.Flush()
// use non-mutex version since we are using this one // use non-mutex version since we are using this one
err := bw.OrigWriter.reopen() err := bw.origWriter.reopen()
bw.BufWriter.Reset(io.Writer(bw.OrigWriter)) bw.bufWriter.Reset(io.Writer(bw.origWriter))
bw.mu.Unlock() bw.mu.Unlock()
return err return err
...@@ -118,9 +120,11 @@ func (bw *BufferedFileWriter) Reopen() error { ...@@ -118,9 +120,11 @@ func (bw *BufferedFileWriter) Reopen() error {
// Close flushes the internal buffer and closes the destination file // Close flushes the internal buffer and closes the destination file
func (bw *BufferedFileWriter) Close() error { func (bw *BufferedFileWriter) Close() error {
bw.quitChan <- true
bw.mu.Lock() bw.mu.Lock()
bw.BufWriter.Flush() bw.done = true
bw.OrigWriter.f.Close() bw.bufWriter.Flush()
bw.origWriter.f.Close()
bw.mu.Unlock() bw.mu.Unlock()
return nil return nil
} }
...@@ -128,27 +132,40 @@ func (bw *BufferedFileWriter) Close() error { ...@@ -128,27 +132,40 @@ func (bw *BufferedFileWriter) Close() error {
// Write implements io.Writer (and reopen.Writer) // Write implements io.Writer (and reopen.Writer)
func (bw *BufferedFileWriter) Write(p []byte) (int, error) { func (bw *BufferedFileWriter) Write(p []byte) (int, error) {
bw.mu.Lock() bw.mu.Lock()
n, err := bw.BufWriter.Write(p) n, err := bw.bufWriter.Write(p)
// Special Case... if the used space in the buffer is LESS than // Special Case... if the used space in the buffer is LESS than
// the input, then we did a flush in the middle of the line // the input, then we did a flush in the middle of the line
// and the full log line was not sent on its way. // and the full log line was not sent on its way.
if bw.BufWriter.Buffered() < len(p) { if bw.bufWriter.Buffered() < len(p) {
bw.BufWriter.Flush() bw.bufWriter.Flush()
} }
bw.mu.Unlock() bw.mu.Unlock()
return n, err return n, err
} }
// Flush flushes the buffer.
func (bw *BufferedFileWriter) Flush() {
bw.mu.Lock()
// could add check if bw.done already
// should never happen
bw.bufWriter.Flush()
bw.origWriter.f.Sync()
bw.mu.Unlock()
}
// flushDaemon periodically flushes the log file buffers. // flushDaemon periodically flushes the log file buffers.
// props to glog func (bw *BufferedFileWriter) flushDaemon(interval time.Duration) {
func (bw *BufferedFileWriter) flushDaemon() { ticker := time.NewTicker(interval)
for range time.NewTicker(flushInterval).C { for {
bw.mu.Lock() select {
bw.BufWriter.Flush() case <-bw.quitChan:
bw.OrigWriter.f.Sync() ticker.Stop()
bw.mu.Unlock() return
case <-ticker.C:
bw.Flush()
}
} }
} }
...@@ -157,13 +174,19 @@ const flushInterval = 30 * time.Second ...@@ -157,13 +174,19 @@ const flushInterval = 30 * time.Second
// NewBufferedFileWriter opens a buffered file that is periodically // NewBufferedFileWriter opens a buffered file that is periodically
// flushed. // flushed.
// TODO: allow size and interval to be passed in.
func NewBufferedFileWriter(w *FileWriter) *BufferedFileWriter { func NewBufferedFileWriter(w *FileWriter) *BufferedFileWriter {
return NewBufferedFileWriterSize(w, bufferSize, flushInterval)
}
// NewBufferedFileWriterSize opens a buffered file with the given size that is periodically
// flushed on the given interval.
func NewBufferedFileWriterSize(w *FileWriter, size int, flush time.Duration) *BufferedFileWriter {
bw := BufferedFileWriter{ bw := BufferedFileWriter{
OrigWriter: w, quitChan: make(chan bool, 1),
BufWriter: bufio.NewWriterSize(w, bufferSize), origWriter: w,
bufWriter: bufio.NewWriterSize(w, size),
} }
go bw.flushDaemon() go bw.flushDaemon(flush)
return &bw return &bw
} }
...@@ -218,8 +241,7 @@ func (nopReopenWriteCloser) Close() error { ...@@ -218,8 +241,7 @@ func (nopReopenWriteCloser) Close() error {
} }
// NopWriter turns a normal writer into a ReopenWriter // NopWriter turns a normal writer into a ReopenWriter
// by doing a NOP on Reopen // by doing a NOP on Reopen. See https://en.wikipedia.org/wiki/NOP
// TODO: better name
func NopWriter(w io.Writer) WriteCloser { func NopWriter(w io.Writer) WriteCloser {
return nopReopenWriteCloser{w} return nopReopenWriteCloser{w}
} }
......
...@@ -2,7 +2,7 @@ ISC License ...@@ -2,7 +2,7 @@ ISC License
Copyright (c) 2012-2016 Dave Collins <dave@davec.name> Copyright (c) 2012-2016 Dave Collins <dave@davec.name>
Permission to use, copy, modify, and distribute this software for any Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies. copyright notice and this permission notice appear in all copies.
......
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
// when the code is not running on Google App Engine, compiled by GopherJS, and // when the code is not running on Google App Engine, compiled by GopherJS, and
// "-tags safe" is not added to the go build command line. The "disableunsafe" // "-tags safe" is not added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used. // tag is deprecated and thus should not be used.
// +build !js,!appengine,!safe,!disableunsafe // Go versions prior to 1.4 are disabled because they use a different layout
// for interfaces which make the implementation of unsafeReflectValue more complex.
// +build !js,!appengine,!safe,!disableunsafe,go1.4
package spew package spew
...@@ -34,80 +36,49 @@ const ( ...@@ -34,80 +36,49 @@ const (
ptrSize = unsafe.Sizeof((*byte)(nil)) ptrSize = unsafe.Sizeof((*byte)(nil))
) )
type flag uintptr
var ( var (
// offsetPtr, offsetScalar, and offsetFlag are the offsets for the // flagRO indicates whether the value field of a reflect.Value
// internal reflect.Value fields. These values are valid before golang // is read-only.
// commit ecccf07e7f9d which changed the format. The are also valid flagRO flag
// after commit 82f48826c6c7 which changed the format again to mirror
// the original format. Code in the init function updates these offsets // flagAddr indicates whether the address of the reflect.Value's
// as necessary. // value may be taken.
offsetPtr = uintptr(ptrSize) flagAddr flag
offsetScalar = uintptr(0)
offsetFlag = uintptr(ptrSize * 2)
// flagKindWidth and flagKindShift indicate various bits that the
// reflect package uses internally to track kind information.
//
// flagRO indicates whether or not the value field of a reflect.Value is
// read-only.
//
// flagIndir indicates whether the value field of a reflect.Value is
// the actual data or a pointer to the data.
//
// These values are valid before golang commit 90a7c3c86944 which
// changed their positions. Code in the init function updates these
// flags as necessary.
flagKindWidth = uintptr(5)
flagKindShift = uintptr(flagKindWidth - 1)
flagRO = uintptr(1 << 0)
flagIndir = uintptr(1 << 1)
) )
func init() { // flagKindMask holds the bits that make up the kind
// Older versions of reflect.Value stored small integers directly in the // part of the flags field. In all the supported versions,
// ptr field (which is named val in the older versions). Versions // it is in the lower 5 bits.
// between commits ecccf07e7f9d and 82f48826c6c7 added a new field named const flagKindMask = flag(0x1f)
// scalar for this purpose which unfortunately came before the flag
// field, so the offset of the flag field is different for those
// versions.
//
// This code constructs a new reflect.Value from a known small integer
// and checks if the size of the reflect.Value struct indicates it has
// the scalar field. When it does, the offsets are updated accordingly.
vv := reflect.ValueOf(0xf00)
if unsafe.Sizeof(vv) == (ptrSize * 4) {
offsetScalar = ptrSize * 2
offsetFlag = ptrSize * 3
}
// Commit 90a7c3c86944 changed the flag positions such that the low // Different versions of Go have used different
// order bits are the kind. This code extracts the kind from the flags // bit layouts for the flags type. This table
// field and ensures it's the correct type. When it's not, the flag // records the known combinations.
// order has been changed to the newer format, so the flags are updated var okFlags = []struct {
// accordingly. ro, addr flag
upf := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) + offsetFlag) }{{
upfv := *(*uintptr)(upf) // From Go 1.4 to 1.5
flagKindMask := uintptr((1<<flagKindWidth - 1) << flagKindShift) ro: 1 << 5,
if (upfv&flagKindMask)>>flagKindShift != uintptr(reflect.Int) { addr: 1 << 7,
flagKindShift = 0 }, {
flagRO = 1 << 5 // Up to Go tip.
flagIndir = 1 << 6 ro: 1<<5 | 1<<6,
addr: 1 << 8,
// Commit adf9b30e5594 modified the flags to separate the }}
// flagRO flag into two bits which specifies whether or not the
// field is embedded. This causes flagIndir to move over a bit var flagValOffset = func() uintptr {
// and means that flagRO is the combination of either of the field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
// original flagRO bit and the new bit. if !ok {
// panic("reflect.Value has no flag field")
// This code detects the change by extracting what used to be
// the indirect bit to ensure it's set. When it's not, the flag
// order has been changed to the newer format, so the flags are
// updated accordingly.
if upfv&flagIndir == 0 {
flagRO = 3 << 5
flagIndir = 1 << 7
}
} }
return field.Offset
}()
// flagField returns a pointer to the flag field of a reflect.Value.
func flagField(v *reflect.Value) *flag {
return (*flag)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + flagValOffset))
} }
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses // unsafeReflectValue converts the passed reflect.Value into a one that bypasses
...@@ -119,34 +90,56 @@ func init() { ...@@ -119,34 +90,56 @@ func init() {
// This allows us to check for implementations of the Stringer and error // This allows us to check for implementations of the Stringer and error
// interfaces to be used for pretty printing ordinarily unaddressable and // interfaces to be used for pretty printing ordinarily unaddressable and
// inaccessible values such as unexported struct fields. // inaccessible values such as unexported struct fields.
func unsafeReflectValue(v reflect.Value) (rv reflect.Value) { func unsafeReflectValue(v reflect.Value) reflect.Value {
indirects := 1 if !v.IsValid() || (v.CanInterface() && v.CanAddr()) {
vt := v.Type() return v
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr)
rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag))
if rvf&flagIndir != 0 {
vt = reflect.PtrTo(v.Type())
indirects++
} else if offsetScalar != 0 {
// The value is in the scalar field when it's not one of the
// reference types.
switch vt.Kind() {
case reflect.Uintptr:
case reflect.Chan:
case reflect.Func:
case reflect.Map:
case reflect.Ptr:
case reflect.UnsafePointer:
default:
upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) +
offsetScalar)
}
} }
flagFieldPtr := flagField(&v)
*flagFieldPtr &^= flagRO
*flagFieldPtr |= flagAddr
return v
}
pv := reflect.NewAt(vt, upv) // Sanity checks against future reflect package changes
rv = pv // to the type or semantics of the Value.flag field.
for i := 0; i < indirects; i++ { func init() {
rv = rv.Elem() field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
if !ok {
panic("reflect.Value has no flag field")
}
if field.Type.Kind() != reflect.TypeOf(flag(0)).Kind() {
panic("reflect.Value flag field has changed kind")
}
type t0 int
var t struct {
A t0
// t0 will have flagEmbedRO set.
t0
// a will have flagStickyRO set
a t0
}
vA := reflect.ValueOf(t).FieldByName("A")
va := reflect.ValueOf(t).FieldByName("a")
vt0 := reflect.ValueOf(t).FieldByName("t0")
// Infer flagRO from the difference between the flags
// for the (otherwise identical) fields in t.
flagPublic := *flagField(&vA)
flagWithRO := *flagField(&va) | *flagField(&vt0)
flagRO = flagPublic ^ flagWithRO
// Infer flagAddr from the difference between a value
// taken from a pointer and not.
vPtrA := reflect.ValueOf(&t).Elem().FieldByName("A")
flagNoPtr := *flagField(&vA)
flagPtr := *flagField(&vPtrA)
flagAddr = flagNoPtr ^ flagPtr
// Check that the inferred flags tally with one of the known versions.
for _, f := range okFlags {
if flagRO == f.ro && flagAddr == f.addr {
return
}
} }
return rv panic("reflect.Value read-only flag has changed semantics")
} }
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
// when the code is running on Google App Engine, compiled by GopherJS, or // when the code is running on Google App Engine, compiled by GopherJS, or
// "-tags safe" is added to the go build command line. The "disableunsafe" // "-tags safe" is added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used. // tag is deprecated and thus should not be used.
// +build js appengine safe disableunsafe // +build js appengine safe disableunsafe !go1.4
package spew package spew
......
...@@ -180,7 +180,7 @@ func printComplex(w io.Writer, c complex128, floatPrecision int) { ...@@ -180,7 +180,7 @@ func printComplex(w io.Writer, c complex128, floatPrecision int) {
w.Write(closeParenBytes) w.Write(closeParenBytes)
} }
// printHexPtr outputs a uintptr formatted as hexidecimal with a leading '0x' // printHexPtr outputs a uintptr formatted as hexadecimal with a leading '0x'
// prefix to Writer w. // prefix to Writer w.
func printHexPtr(w io.Writer, p uintptr) { func printHexPtr(w io.Writer, p uintptr) {
// Null pointer. // Null pointer.
......
...@@ -35,16 +35,16 @@ var ( ...@@ -35,16 +35,16 @@ var (
// cCharRE is a regular expression that matches a cgo char. // cCharRE is a regular expression that matches a cgo char.
// It is used to detect character arrays to hexdump them. // It is used to detect character arrays to hexdump them.
cCharRE = regexp.MustCompile("^.*\\._Ctype_char$") cCharRE = regexp.MustCompile(`^.*\._Ctype_char$`)
// cUnsignedCharRE is a regular expression that matches a cgo unsigned // cUnsignedCharRE is a regular expression that matches a cgo unsigned
// char. It is used to detect unsigned character arrays to hexdump // char. It is used to detect unsigned character arrays to hexdump
// them. // them.
cUnsignedCharRE = regexp.MustCompile("^.*\\._Ctype_unsignedchar$") cUnsignedCharRE = regexp.MustCompile(`^.*\._Ctype_unsignedchar$`)
// cUint8tCharRE is a regular expression that matches a cgo uint8_t. // cUint8tCharRE is a regular expression that matches a cgo uint8_t.
// It is used to detect uint8_t arrays to hexdump them. // It is used to detect uint8_t arrays to hexdump them.
cUint8tCharRE = regexp.MustCompile("^.*\\._Ctype_uint8_t$") cUint8tCharRE = regexp.MustCompile(`^.*\._Ctype_uint8_t$`)
) )
// dumpState contains information about the state of a dump operation. // dumpState contains information about the state of a dump operation.
...@@ -143,10 +143,10 @@ func (d *dumpState) dumpPtr(v reflect.Value) { ...@@ -143,10 +143,10 @@ func (d *dumpState) dumpPtr(v reflect.Value) {
// Display dereferenced value. // Display dereferenced value.
d.w.Write(openParenBytes) d.w.Write(openParenBytes)
switch { switch {
case nilFound == true: case nilFound:
d.w.Write(nilAngleBytes) d.w.Write(nilAngleBytes)
case cycleFound == true: case cycleFound:
d.w.Write(circularBytes) d.w.Write(circularBytes)
default: default:
......
...@@ -182,10 +182,10 @@ func (f *formatState) formatPtr(v reflect.Value) { ...@@ -182,10 +182,10 @@ func (f *formatState) formatPtr(v reflect.Value) {
// Display dereferenced value. // Display dereferenced value.
switch { switch {
case nilFound == true: case nilFound:
f.fs.Write(nilAngleBytes) f.fs.Write(nilAngleBytes)
case cycleFound == true: case cycleFound:
f.fs.Write(circularShortBytes) f.fs.Write(circularShortBytes)
default: default:
......
...@@ -56,8 +56,9 @@ This simple parsing example: ...@@ -56,8 +56,9 @@ This simple parsing example:
is directly mapped to: is directly mapped to:
```go ```go
if token, err := request.ParseFromRequest(tokenString, request.OAuth2Extractor, req, keyLookupFunc); err == nil { if token, err := request.ParseFromRequest(req, request.OAuth2Extractor, keyLookupFunc); err == nil {
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"]) claims := token.Claims.(jwt.MapClaims)
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
} }
``` ```
......
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html) # jwt-go
[![Build Status](https://travis-ci.org/dgrijalva/jwt-go.svg?branch=master)](https://travis-ci.org/dgrijalva/jwt-go) [![Build Status](https://travis-ci.org/dgrijalva/jwt-go.svg?branch=master)](https://travis-ci.org/dgrijalva/jwt-go)
[![GoDoc](https://godoc.org/github.com/dgrijalva/jwt-go?status.svg)](https://godoc.org/github.com/dgrijalva/jwt-go)
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html)
**BREAKING CHANGES:*** Version 3.0.0 is here. It includes _a lot_ of changes including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code. **NEW VERSION COMING:** There have been a lot of improvements suggested since the version 3.0.0 released in 2016. I'm working now on cutting two different releases: 3.2.0 will contain any non-breaking changes or enhancements. 4.0.0 will follow shortly which will include breaking changes. See the 4.0.0 milestone to get an idea of what's coming. If you have other ideas, or would like to participate in 4.0.0, now's the time. If you depend on this library and don't want to be interrupted, I recommend you use your dependency mangement tool to pin to version 3.
**NOTICE:** A vulnerability in JWT was [recently published](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). As this library doesn't force users to validate the `alg` is what they expected, it's possible your usage is effected. There will be an update soon to remedy this, and it will likey require backwards-incompatible changes to the API. In the short term, please make sure your implementation verifies the `alg` is what you expect. **SECURITY NOTICE:** Some older versions of Go have a security issue in the cryotp/elliptic. Recommendation is to upgrade to at least 1.8.3. See issue #216 for more detail.
**SECURITY NOTICE:** It's important that you [validate the `alg` presented is what you expect](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). This library attempts to make it easy to do the right thing by requiring key types match the expected alg, but you should take the extra step to verify it in your usage. See the examples provided.
## What the heck is a JWT? ## What the heck is a JWT?
...@@ -25,8 +29,8 @@ This library supports the parsing and verification as well as the generation and ...@@ -25,8 +29,8 @@ This library supports the parsing and verification as well as the generation and
See [the project documentation](https://godoc.org/github.com/dgrijalva/jwt-go) for examples of usage: See [the project documentation](https://godoc.org/github.com/dgrijalva/jwt-go) for examples of usage:
* [Simple example of parsing and validating a token](https://godoc.org/github.com/dgrijalva/jwt-go#example_Parse_hmac) * [Simple example of parsing and validating a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-Parse--Hmac)
* [Simple example of building and signing a token](https://godoc.org/github.com/dgrijalva/jwt-go#example_New_hmac) * [Simple example of building and signing a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-New--Hmac)
* [Directory of Examples](https://godoc.org/github.com/dgrijalva/jwt-go#pkg-examples) * [Directory of Examples](https://godoc.org/github.com/dgrijalva/jwt-go#pkg-examples)
## Extensions ## Extensions
...@@ -37,7 +41,7 @@ Here's an example of an extension that integrates with the Google App Engine sig ...@@ -37,7 +41,7 @@ Here's an example of an extension that integrates with the Google App Engine sig
## Compliance ## Compliance
This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences: This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences:
* In order to protect against accidental use of [Unsecured JWTs](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#UnsecuredJWT), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key. * In order to protect against accidental use of [Unsecured JWTs](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#UnsecuredJWT), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key.
...@@ -47,7 +51,10 @@ This library is considered production ready. Feedback and feature requests are ...@@ -47,7 +51,10 @@ This library is considered production ready. Feedback and feature requests are
This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases). This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases).
While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v2`. It will do the right thing WRT semantic versioning. While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v3`. It will do the right thing WRT semantic versioning.
**BREAKING CHANGES:***
* Version 3.0.0 includes _a lot_ of changes from the 2.x line, including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code.
## Usage Tips ## Usage Tips
...@@ -68,18 +75,26 @@ Symmetric signing methods, such as HSA, use only a single secret. This is probab ...@@ -68,18 +75,26 @@ Symmetric signing methods, such as HSA, use only a single secret. This is probab
Asymmetric signing methods, such as RSA, use different keys for signing and verifying tokens. This makes it possible to produce tokens with a private key, and allow any consumer to access the public key for verification. Asymmetric signing methods, such as RSA, use different keys for signing and verifying tokens. This makes it possible to produce tokens with a private key, and allow any consumer to access the public key for verification.
### Signing Methods and Key Types
Each signing method expects a different object type for its signing keys. See the package documentation for details. Here are the most common ones:
* The [HMAC signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodHMAC) (`HS256`,`HS384`,`HS512`) expect `[]byte` values for signing and validation
* The [RSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodRSA) (`RS256`,`RS384`,`RS512`) expect `*rsa.PrivateKey` for signing and `*rsa.PublicKey` for validation
* The [ECDSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodECDSA) (`ES256`,`ES384`,`ES512`) expect `*ecdsa.PrivateKey` for signing and `*ecdsa.PublicKey` for validation
### JWT and OAuth ### JWT and OAuth
It's worth mentioning that OAuth and JWT are not the same thing. A JWT token is simply a signed JSON object. It can be used anywhere such a thing is useful. There is some confusion, though, as JWT is the most common type of bearer token used in OAuth2 authentication. It's worth mentioning that OAuth and JWT are not the same thing. A JWT token is simply a signed JSON object. It can be used anywhere such a thing is useful. There is some confusion, though, as JWT is the most common type of bearer token used in OAuth2 authentication.
Without going too far down the rabbit hole, here's a description of the interaction of these technologies: Without going too far down the rabbit hole, here's a description of the interaction of these technologies:
* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth. * OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth.
* OAuth defines several options for passing around authentication data. One popular method is called a "bearer token". A bearer token is simply a string that _should_ only be held by an authenticated user. Thus, simply presenting this token proves your identity. You can probably derive from here why a JWT might make a good bearer token. * OAuth defines several options for passing around authentication data. One popular method is called a "bearer token". A bearer token is simply a string that _should_ only be held by an authenticated user. Thus, simply presenting this token proves your identity. You can probably derive from here why a JWT might make a good bearer token.
* Because bearer tokens are used for authentication, it's important they're kept secret. This is why transactions that use bearer tokens typically happen over SSL. * Because bearer tokens are used for authentication, it's important they're kept secret. This is why transactions that use bearer tokens typically happen over SSL.
## More ## More
Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go). Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go).
The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in to documentation. The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in the documentation.
## `jwt-go` Version History ## `jwt-go` Version History
#### 3.2.0
* Added method `ParseUnverified` to allow users to split up the tasks of parsing and validation
* HMAC signing method returns `ErrInvalidKeyType` instead of `ErrInvalidKey` where appropriate
* Added options to `request.ParseFromRequest`, which allows for an arbitrary list of modifiers to parsing behavior. Initial set include `WithClaims` and `WithParser`. Existing usage of this function will continue to work as before.
* Deprecated `ParseFromRequestWithClaims` to simplify API in the future.
#### 3.1.0
* Improvements to `jwt` command line tool
* Added `SkipClaimsValidation` option to `Parser`
* Documentation updates
#### 3.0.0 #### 3.0.0
* **Compatibility Breaking Changes**: See MIGRATION_GUIDE.md for tips on updating your code * **Compatibility Breaking Changes**: See MIGRATION_GUIDE.md for tips on updating your code
......
...@@ -14,6 +14,7 @@ var ( ...@@ -14,6 +14,7 @@ var (
) )
// Implements the ECDSA family of signing methods signing methods // Implements the ECDSA family of signing methods signing methods
// Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification
type SigningMethodECDSA struct { type SigningMethodECDSA struct {
Name string Name string
Hash crypto.Hash Hash crypto.Hash
......
...@@ -51,13 +51,9 @@ func (e ValidationError) Error() string { ...@@ -51,13 +51,9 @@ func (e ValidationError) Error() string {
} else { } else {
return "token is invalid" return "token is invalid"
} }
return e.Inner.Error()
} }
// No errors // No errors
func (e *ValidationError) valid() bool { func (e *ValidationError) valid() bool {
if e.Errors > 0 { return e.Errors == 0
return false
}
return true
} }
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
) )
// Implements the HMAC-SHA family of signing methods signing methods // Implements the HMAC-SHA family of signing methods signing methods
// Expects key type of []byte for both signing and validation
type SigningMethodHMAC struct { type SigningMethodHMAC struct {
Name string Name string
Hash crypto.Hash Hash crypto.Hash
...@@ -90,5 +91,5 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, ...@@ -90,5 +91,5 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string,
return EncodeSegment(hasher.Sum(nil)), nil return EncodeSegment(hasher.Sum(nil)), nil
} }
return "", ErrInvalidKey return "", ErrInvalidKeyType
} }
...@@ -8,8 +8,9 @@ import ( ...@@ -8,8 +8,9 @@ import (
) )
type Parser struct { type Parser struct {
ValidMethods []string // If populated, only these methods will be considered valid ValidMethods []string // If populated, only these methods will be considered valid
UseJSONNumber bool // Use JSON Number format in JSON decoder UseJSONNumber bool // Use JSON Number format in JSON decoder
SkipClaimsValidation bool // Skip claims validation during token parsing
} }
// Parse, validate, and return a token. // Parse, validate, and return a token.
...@@ -20,55 +21,9 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { ...@@ -20,55 +21,9 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
} }
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) { func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
parts := strings.Split(tokenString, ".") token, parts, err := p.ParseUnverified(tokenString, claims)
if len(parts) != 3 {
return nil, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
var err error
token := &Token{Raw: tokenString}
// parse Header
var headerBytes []byte
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
return token, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
}
return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// parse Claims
var claimBytes []byte
token.Claims = claims
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.UseJSONNumber {
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
// Handle decode error
if err != nil { if err != nil {
return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} return token, err
}
// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
return token, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
}
} else {
return token, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
} }
// Verify signing method is in the required set // Verify signing method is in the required set
...@@ -95,20 +50,25 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf ...@@ -95,20 +50,25 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
} }
if key, err = keyFunc(token); err != nil { if key, err = keyFunc(token); err != nil {
// keyFunc returned an error // keyFunc returned an error
if ve, ok := err.(*ValidationError); ok {
return token, ve
}
return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable} return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable}
} }
vErr := &ValidationError{} vErr := &ValidationError{}
// Validate Claims // Validate Claims
if err := token.Claims.Valid(); err != nil { if !p.SkipClaimsValidation {
if err := token.Claims.Valid(); err != nil {
// If the Claims Valid returned an error, check if it is a validation error,
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set // If the Claims Valid returned an error, check if it is a validation error,
if e, ok := err.(*ValidationError); !ok { // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid} if e, ok := err.(*ValidationError); !ok {
} else { vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid}
vErr = e } else {
vErr = e
}
} }
} }
...@@ -126,3 +86,63 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf ...@@ -126,3 +86,63 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
return token, vErr return token, vErr
} }
// WARNING: Don't use this method unless you know what you're doing
//
// This method parses the token but doesn't validate the signature. It's only
// ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from
// it.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
token = &Token{Raw: tokenString}
// parse Header
var headerBytes []byte
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
}
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// parse Claims
var claimBytes []byte
token.Claims = claims
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.UseJSONNumber {
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
// Handle decode error
if err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
}
} else {
return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
}
return token, parts, nil
}
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
) )
// Implements the RSA family of signing methods signing methods // Implements the RSA family of signing methods signing methods
// Expects *rsa.PrivateKey for signing and *rsa.PublicKey for validation
type SigningMethodRSA struct { type SigningMethodRSA struct {
Name string Name string
Hash crypto.Hash Hash crypto.Hash
...@@ -44,7 +45,7 @@ func (m *SigningMethodRSA) Alg() string { ...@@ -44,7 +45,7 @@ func (m *SigningMethodRSA) Alg() string {
} }
// Implements the Verify method from SigningMethod // Implements the Verify method from SigningMethod
// For this signing method, must be an rsa.PublicKey structure. // For this signing method, must be an *rsa.PublicKey structure.
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error { func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
var err error var err error
...@@ -73,7 +74,7 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface ...@@ -73,7 +74,7 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface
} }
// Implements the Sign method from SigningMethod // Implements the Sign method from SigningMethod
// For this signing method, must be an rsa.PrivateKey structure. // For this signing method, must be an *rsa.PrivateKey structure.
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) { func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
var rsaKey *rsa.PrivateKey var rsaKey *rsa.PrivateKey
var ok bool var ok bool
......
...@@ -39,6 +39,38 @@ func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) { ...@@ -39,6 +39,38 @@ func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
return pkey, nil return pkey, nil
} }
// Parse PEM encoded PKCS1 or PKCS8 private key protected with password
func ParseRSAPrivateKeyFromPEMWithPassword(key []byte, password string) (*rsa.PrivateKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
var parsedKey interface{}
var blockDecrypted []byte
if blockDecrypted, err = x509.DecryptPEMBlock(block, []byte(password)); err != nil {
return nil, err
}
if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil {
return nil, err
}
}
var pkey *rsa.PrivateKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
return nil, ErrNotRSAPrivateKey
}
return pkey, nil
}
// Parse PEM encoded PKCS1 or PKCS8 public key // Parse PEM encoded PKCS1 or PKCS8 public key
func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) { func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
var err error var err error
......
# raven [![Build Status](https://travis-ci.org/getsentry/raven-go.png?branch=master)](https://travis-ci.org/getsentry/raven-go) # raven
[![Build Status](https://api.travis-ci.org/getsentry/raven-go.svg?branch=master)](https://travis-ci.org/getsentry/raven-go)
[![Go Report Card](https://goreportcard.com/badge/github.com/getsentry/raven-go)](https://goreportcard.com/report/github.com/getsentry/raven-go)
[![GoDoc](https://godoc.org/github.com/getsentry/raven-go?status.svg)](https://godoc.org/github.com/getsentry/raven-go)
raven is a Go client for the [Sentry](https://github.com/getsentry/sentry) raven is a Go client for the [Sentry](https://github.com/getsentry/sentry)
event/error logging system. event/error logging system.
- [**API Documentation**](https://godoc.org/github.com/getsentry/raven-go) - [**API Documentation**](https://godoc.org/github.com/getsentry/raven-go)
- [**Usage and Examples**](https://docs.getsentry.com/hosted/clients/go/) - [**Usage and Examples**](https://docs.sentry.io/clients/go/)
## Installation ## Installation
......
package raven
type causer interface {
Cause() error
}
type errWrappedWithExtra struct {
err error
extraInfo map[string]interface{}
}
func (ewx *errWrappedWithExtra) Error() string {
return ewx.err.Error()
}
func (ewx *errWrappedWithExtra) Cause() error {
return ewx.err
}
func (ewx *errWrappedWithExtra) ExtraInfo() Extra {
return ewx.extraInfo
}
// Adds extra data to an error before reporting to Sentry
func WrapWithExtra(err error, extraInfo map[string]interface{}) error {
return &errWrappedWithExtra{
err: err,
extraInfo: extraInfo,
}
}
type ErrWithExtra interface {
Error() string
Cause() error
ExtraInfo() Extra
}
// Iteratively fetches all the Extra data added to an error,
// and it's underlying errors. Extra data defined first is
// respected, and is not overridden when extracting.
func extractExtra(err error) Extra {
extra := Extra{}
currentErr := err
for currentErr != nil {
if errWithExtra, ok := currentErr.(ErrWithExtra); ok {
for k, v := range errWithExtra.ExtraInfo() {
extra[k] = v
}
}
if errWithCause, ok := currentErr.(causer); ok {
currentErr = errWithCause.Cause()
} else {
currentErr = nil
}
}
return extra
}
...@@ -39,3 +39,12 @@ func (e *Exception) Culprit() string { ...@@ -39,3 +39,12 @@ func (e *Exception) Culprit() string {
} }
return e.Stacktrace.Culprit() return e.Stacktrace.Culprit()
} }
// Exceptions allows for chained errors
// https://docs.sentry.io/clientdev/interfaces/exception/
type Exceptions struct {
// Required
Values []*Exception `json:"values"`
}
func (es Exceptions) Class() string { return "exception" }
...@@ -28,6 +28,7 @@ func NewHttp(req *http.Request) *Http { ...@@ -28,6 +28,7 @@ func NewHttp(req *http.Request) *Http {
for k, v := range req.Header { for k, v := range req.Header {
h.Headers[k] = strings.Join(v, ",") h.Headers[k] = strings.Join(v, ",")
} }
h.Headers["Host"] = req.Host
return h return h
} }
...@@ -68,17 +69,31 @@ func (h *Http) Class() string { return "request" } ...@@ -68,17 +69,31 @@ func (h *Http) Class() string { return "request" }
// ... // ...
// })) // }))
func RecoveryHandler(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func RecoveryHandler(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return Recoverer(http.HandlerFunc(handler)).ServeHTTP
}
// Recovery handler to wrap the stdlib net/http Mux.
// Example:
// mux := http.NewServeMux
// ...
// http.Handle("/", raven.Recoverer(mux))
func Recoverer(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() { defer func() {
if rval := recover(); rval != nil { if rval := recover(); rval != nil {
debug.PrintStack() debug.PrintStack()
rvalStr := fmt.Sprint(rval) rvalStr := fmt.Sprint(rval)
packet := NewPacket(rvalStr, NewException(errors.New(rvalStr), NewStacktrace(2, 3, nil)), NewHttp(r)) var packet *Packet
if err, ok := rval.(error); ok {
packet = NewPacket(rvalStr, NewException(errors.New(rvalStr), GetOrNewStacktrace(err, 2, 3, nil)), NewHttp(r))
} else {
packet = NewPacket(rvalStr, NewException(errors.New(rvalStr), NewStacktrace(2, 3, nil)), NewHttp(r))
}
Capture(packet, nil) Capture(packet, nil)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }
}() }()
handler(w, r) handler.ServeHTTP(w, r)
} })
} }
...@@ -9,10 +9,13 @@ import ( ...@@ -9,10 +9,13 @@ import (
"bytes" "bytes"
"go/build" "go/build"
"io/ioutil" "io/ioutil"
"net/url"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"github.com/pkg/errors"
) )
// https://docs.getsentry.com/hosted/clientdev/interfaces/#failure-interfaces // https://docs.getsentry.com/hosted/clientdev/interfaces/#failure-interfaces
...@@ -49,6 +52,34 @@ type StacktraceFrame struct { ...@@ -49,6 +52,34 @@ type StacktraceFrame struct {
InApp bool `json:"in_app"` InApp bool `json:"in_app"`
} }
// Try to get stacktrace from err as an interface of github.com/pkg/errors, or else NewStacktrace()
func GetOrNewStacktrace(err error, skip int, context int, appPackagePrefixes []string) *Stacktrace {
stacktracer, errHasStacktrace := err.(interface {
StackTrace() errors.StackTrace
})
if errHasStacktrace {
var frames []*StacktraceFrame
for _, f := range stacktracer.StackTrace() {
pc := uintptr(f) - 1
fn := runtime.FuncForPC(pc)
var file string
var line int
if fn != nil {
file, line = fn.FileLine(pc)
} else {
file = "unknown"
}
frame := NewStacktraceFrame(pc, file, line, context, appPackagePrefixes)
if frame != nil {
frames = append([]*StacktraceFrame{frame}, frames...)
}
}
return &Stacktrace{Frames: frames}
} else {
return NewStacktrace(skip+1, context, appPackagePrefixes)
}
}
// Intialize and populate a new stacktrace, skipping skip frames. // Intialize and populate a new stacktrace, skipping skip frames.
// //
// context is the number of surrounding lines that should be included for context. // context is the number of surrounding lines that should be included for context.
...@@ -113,7 +144,7 @@ func NewStacktraceFrame(pc uintptr, file string, line, context int, appPackagePr ...@@ -113,7 +144,7 @@ func NewStacktraceFrame(pc uintptr, file string, line, context int, appPackagePr
} }
if context > 0 { if context > 0 {
contextLines, lineIdx := fileContext(file, line, context) contextLines, lineIdx := sourceCodeLoader.Load(file, line, context)
if len(contextLines) > 0 { if len(contextLines) > 0 {
for i, line := range contextLines { for i, line := range contextLines {
switch { switch {
...@@ -127,7 +158,7 @@ func NewStacktraceFrame(pc uintptr, file string, line, context int, appPackagePr ...@@ -127,7 +158,7 @@ func NewStacktraceFrame(pc uintptr, file string, line, context int, appPackagePr
} }
} }
} else if context == -1 { } else if context == -1 {
contextLine, _ := fileContext(file, line, 0) contextLine, _ := sourceCodeLoader.Load(file, line, 0)
if len(contextLine) > 0 { if len(contextLine) > 0 {
frame.ContextLine = string(contextLine[0]) frame.ContextLine = string(contextLine[0])
} }
...@@ -136,40 +167,72 @@ func NewStacktraceFrame(pc uintptr, file string, line, context int, appPackagePr ...@@ -136,40 +167,72 @@ func NewStacktraceFrame(pc uintptr, file string, line, context int, appPackagePr
} }
// Retrieve the name of the package and function containing the PC. // Retrieve the name of the package and function containing the PC.
func functionName(pc uintptr) (pack string, name string) { func functionName(pc uintptr) (string, string) {
fn := runtime.FuncForPC(pc) fn := runtime.FuncForPC(pc)
if fn == nil { if fn == nil {
return return "", ""
} }
name = fn.Name()
// We get this: return splitFunctionName(fn.Name())
// runtime/debug.*T·ptrmethod
// and want this:
// pack = runtime/debug
// name = *T.ptrmethod
if idx := strings.LastIndex(name, "."); idx != -1 {
pack = name[:idx]
name = name[idx+1:]
}
name = strings.Replace(name, "·", ".", -1)
return
} }
var fileCacheLock sync.Mutex func splitFunctionName(name string) (string, string) {
var fileCache = make(map[string][][]byte) var pack string
if pos := strings.LastIndex(name, "/"); pos != -1 {
pack = name[:pos+1]
name = name[pos+1:]
}
func fileContext(filename string, line, context int) ([][]byte, int) { if pos := strings.Index(name, "."); pos != -1 {
fileCacheLock.Lock() pack += name[:pos]
defer fileCacheLock.Unlock() name = name[pos+1:]
lines, ok := fileCache[filename] }
if p, err := url.QueryUnescape(pack); err == nil {
pack = p
}
return pack, name
}
type SourceCodeLoader interface {
Load(filename string, line, context int) ([][]byte, int)
}
var sourceCodeLoader SourceCodeLoader = &fsLoader{cache: make(map[string][][]byte)}
func SetSourceCodeLoader(loader SourceCodeLoader) {
sourceCodeLoader = loader
}
type fsLoader struct {
mu sync.Mutex
cache map[string][][]byte
}
func (fs *fsLoader) Load(filename string, line, context int) ([][]byte, int) {
fs.mu.Lock()
defer fs.mu.Unlock()
lines, ok := fs.cache[filename]
if !ok { if !ok {
data, err := ioutil.ReadFile(filename) data, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
// cache errors as nil slice: code below handles it correctly
// otherwise when missing the source or running as a different user, we try
// reading the file on each error which is unnecessary
fs.cache[filename] = nil
return nil, 0 return nil, 0
} }
lines = bytes.Split(data, []byte{'\n'}) lines = bytes.Split(data, []byte{'\n'})
fileCache[filename] = lines fs.cache[filename] = lines
} }
if lines == nil {
// cached error from ReadFile: return no lines
return nil, 0
}
line-- // stack trace lines are 1-indexed line-- // stack trace lines are 1-indexed
start := line - context start := line - context
var idx int var idx int
......
...@@ -37,24 +37,9 @@ package proto ...@@ -37,24 +37,9 @@ package proto
import ( import (
"errors" "errors"
"fmt"
"reflect" "reflect"
) )
// RequiredNotSetError is an error type returned by either Marshal or Unmarshal.
// Marshal reports this when a required field is not initialized.
// Unmarshal reports this when a required field is missing from the wire data.
type RequiredNotSetError struct {
field string
}
func (e *RequiredNotSetError) Error() string {
if e.field == "" {
return fmt.Sprintf("proto: required field not set")
}
return fmt.Sprintf("proto: required field %q not set", e.field)
}
var ( var (
// errRepeatedHasNil is the error returned if Marshal is called with // errRepeatedHasNil is the error returned if Marshal is called with
// a struct with a repeated field containing a nil element. // a struct with a repeated field containing a nil element.
......
...@@ -265,7 +265,6 @@ package proto ...@@ -265,7 +265,6 @@ package proto
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"log" "log"
"reflect" "reflect"
...@@ -274,7 +273,66 @@ import ( ...@@ -274,7 +273,66 @@ import (
"sync" "sync"
) )
var errInvalidUTF8 = errors.New("proto: invalid UTF-8 string") // RequiredNotSetError is an error type returned by either Marshal or Unmarshal.
// Marshal reports this when a required field is not initialized.
// Unmarshal reports this when a required field is missing from the wire data.
type RequiredNotSetError struct{ field string }
func (e *RequiredNotSetError) Error() string {
if e.field == "" {
return fmt.Sprintf("proto: required field not set")
}
return fmt.Sprintf("proto: required field %q not set", e.field)
}
func (e *RequiredNotSetError) RequiredNotSet() bool {
return true
}
type invalidUTF8Error struct{ field string }
func (e *invalidUTF8Error) Error() string {
if e.field == "" {
return "proto: invalid UTF-8 detected"
}
return fmt.Sprintf("proto: field %q contains invalid UTF-8", e.field)
}
func (e *invalidUTF8Error) InvalidUTF8() bool {
return true
}
// errInvalidUTF8 is a sentinel error to identify fields with invalid UTF-8.
// This error should not be exposed to the external API as such errors should
// be recreated with the field information.
var errInvalidUTF8 = &invalidUTF8Error{}
// isNonFatal reports whether the error is either a RequiredNotSet error
// or a InvalidUTF8 error.
func isNonFatal(err error) bool {
if re, ok := err.(interface{ RequiredNotSet() bool }); ok && re.RequiredNotSet() {
return true
}
if re, ok := err.(interface{ InvalidUTF8() bool }); ok && re.InvalidUTF8() {
return true
}
return false
}
type nonFatal struct{ E error }
// Merge merges err into nf and reports whether it was successful.
// Otherwise it returns false for any fatal non-nil errors.
func (nf *nonFatal) Merge(err error) (ok bool) {
if err == nil {
return true // not an error
}
if !isNonFatal(err) {
return false // fatal error
}
if nf.E == nil {
nf.E = err // store first instance of non-fatal error
}
return true
}
// Message is implemented by generated protocol buffer messages. // Message is implemented by generated protocol buffer messages.
type Message interface { type Message interface {
......
...@@ -231,7 +231,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte ...@@ -231,7 +231,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
return b, err return b, err
} }
var err, errreq error var err, errLater error
// The old marshaler encodes extensions at beginning. // The old marshaler encodes extensions at beginning.
if u.extensions.IsValid() { if u.extensions.IsValid() {
e := ptr.offset(u.extensions).toExtensions() e := ptr.offset(u.extensions).toExtensions()
...@@ -252,11 +252,13 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte ...@@ -252,11 +252,13 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
} }
} }
for _, f := range u.fields { for _, f := range u.fields {
if f.required && errreq == nil { if f.required {
if ptr.offset(f.field).getPointer().isNil() { if ptr.offset(f.field).getPointer().isNil() {
// Required field is not set. // Required field is not set.
// We record the error but keep going, to give a complete marshaling. // We record the error but keep going, to give a complete marshaling.
errreq = &RequiredNotSetError{f.name} if errLater == nil {
errLater = &RequiredNotSetError{f.name}
}
continue continue
} }
} }
...@@ -269,8 +271,8 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte ...@@ -269,8 +271,8 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
if err1, ok := err.(*RequiredNotSetError); ok { if err1, ok := err.(*RequiredNotSetError); ok {
// Required field in submessage is not set. // Required field in submessage is not set.
// We record the error but keep going, to give a complete marshaling. // We record the error but keep going, to give a complete marshaling.
if errreq == nil { if errLater == nil {
errreq = &RequiredNotSetError{f.name + "." + err1.field} errLater = &RequiredNotSetError{f.name + "." + err1.field}
} }
continue continue
} }
...@@ -278,8 +280,11 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte ...@@ -278,8 +280,11 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
err = errors.New("proto: repeated field " + f.name + " has nil element") err = errors.New("proto: repeated field " + f.name + " has nil element")
} }
if err == errInvalidUTF8 { if err == errInvalidUTF8 {
fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name if errLater == nil {
err = fmt.Errorf("proto: string field %q contains invalid UTF-8", fullName) fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
errLater = &invalidUTF8Error{fullName}
}
continue
} }
return b, err return b, err
} }
...@@ -288,7 +293,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte ...@@ -288,7 +293,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
s := *ptr.offset(u.unrecognized).toBytes() s := *ptr.offset(u.unrecognized).toBytes()
b = append(b, s...) b = append(b, s...)
} }
return b, errreq return b, errLater
} }
// computeMarshalInfo initializes the marshal info. // computeMarshalInfo initializes the marshal info.
...@@ -2038,52 +2043,68 @@ func appendStringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, e ...@@ -2038,52 +2043,68 @@ func appendStringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, e
return b, nil return b, nil
} }
func appendUTF8StringValue(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { func appendUTF8StringValue(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
var invalidUTF8 bool
v := *ptr.toString() v := *ptr.toString()
if !utf8.ValidString(v) { if !utf8.ValidString(v) {
return nil, errInvalidUTF8 invalidUTF8 = true
} }
b = appendVarint(b, wiretag) b = appendVarint(b, wiretag)
b = appendVarint(b, uint64(len(v))) b = appendVarint(b, uint64(len(v)))
b = append(b, v...) b = append(b, v...)
if invalidUTF8 {
return b, errInvalidUTF8
}
return b, nil return b, nil
} }
func appendUTF8StringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { func appendUTF8StringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
var invalidUTF8 bool
v := *ptr.toString() v := *ptr.toString()
if v == "" { if v == "" {
return b, nil return b, nil
} }
if !utf8.ValidString(v) { if !utf8.ValidString(v) {
return nil, errInvalidUTF8 invalidUTF8 = true
} }
b = appendVarint(b, wiretag) b = appendVarint(b, wiretag)
b = appendVarint(b, uint64(len(v))) b = appendVarint(b, uint64(len(v)))
b = append(b, v...) b = append(b, v...)
if invalidUTF8 {
return b, errInvalidUTF8
}
return b, nil return b, nil
} }
func appendUTF8StringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { func appendUTF8StringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
var invalidUTF8 bool
p := *ptr.toStringPtr() p := *ptr.toStringPtr()
if p == nil { if p == nil {
return b, nil return b, nil
} }
v := *p v := *p
if !utf8.ValidString(v) { if !utf8.ValidString(v) {
return nil, errInvalidUTF8 invalidUTF8 = true
} }
b = appendVarint(b, wiretag) b = appendVarint(b, wiretag)
b = appendVarint(b, uint64(len(v))) b = appendVarint(b, uint64(len(v)))
b = append(b, v...) b = append(b, v...)
if invalidUTF8 {
return b, errInvalidUTF8
}
return b, nil return b, nil
} }
func appendUTF8StringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { func appendUTF8StringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
var invalidUTF8 bool
s := *ptr.toStringSlice() s := *ptr.toStringSlice()
for _, v := range s { for _, v := range s {
if !utf8.ValidString(v) { if !utf8.ValidString(v) {
return nil, errInvalidUTF8 invalidUTF8 = true
} }
b = appendVarint(b, wiretag) b = appendVarint(b, wiretag)
b = appendVarint(b, uint64(len(v))) b = appendVarint(b, uint64(len(v)))
b = append(b, v...) b = append(b, v...)
} }
if invalidUTF8 {
return b, errInvalidUTF8
}
return b, nil return b, nil
} }
func appendBytes(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { func appendBytes(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
...@@ -2162,7 +2183,8 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) { ...@@ -2162,7 +2183,8 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
}, },
func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) { func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) {
s := ptr.getPointerSlice() s := ptr.getPointerSlice()
var err, errreq error var err error
var nerr nonFatal
for _, v := range s { for _, v := range s {
if v.isNil() { if v.isNil() {
return b, errRepeatedHasNil return b, errRepeatedHasNil
...@@ -2170,22 +2192,14 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) { ...@@ -2170,22 +2192,14 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
b = appendVarint(b, wiretag) // start group b = appendVarint(b, wiretag) // start group
b, err = u.marshal(b, v, deterministic) b, err = u.marshal(b, v, deterministic)
b = appendVarint(b, wiretag+(WireEndGroup-WireStartGroup)) // end group b = appendVarint(b, wiretag+(WireEndGroup-WireStartGroup)) // end group
if err != nil { if !nerr.Merge(err) {
if _, ok := err.(*RequiredNotSetError); ok {
// Required field in submessage is not set.
// We record the error but keep going, to give a complete marshaling.
if errreq == nil {
errreq = err
}
continue
}
if err == ErrNil { if err == ErrNil {
err = errRepeatedHasNil err = errRepeatedHasNil
} }
return b, err return b, err
} }
} }
return b, errreq return b, nerr.E
} }
} }
...@@ -2229,7 +2243,8 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) { ...@@ -2229,7 +2243,8 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
}, },
func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) { func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) {
s := ptr.getPointerSlice() s := ptr.getPointerSlice()
var err, errreq error var err error
var nerr nonFatal
for _, v := range s { for _, v := range s {
if v.isNil() { if v.isNil() {
return b, errRepeatedHasNil return b, errRepeatedHasNil
...@@ -2239,22 +2254,14 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) { ...@@ -2239,22 +2254,14 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
b = appendVarint(b, uint64(siz)) b = appendVarint(b, uint64(siz))
b, err = u.marshal(b, v, deterministic) b, err = u.marshal(b, v, deterministic)
if err != nil { if !nerr.Merge(err) {
if _, ok := err.(*RequiredNotSetError); ok {
// Required field in submessage is not set.
// We record the error but keep going, to give a complete marshaling.
if errreq == nil {
errreq = err
}
continue
}
if err == ErrNil { if err == ErrNil {
err = errRepeatedHasNil err = errRepeatedHasNil
} }
return b, err return b, err
} }
} }
return b, errreq return b, nerr.E
} }
} }
...@@ -2317,6 +2324,8 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) { ...@@ -2317,6 +2324,8 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
if len(keys) > 1 && deterministic { if len(keys) > 1 && deterministic {
sort.Sort(mapKeys(keys)) sort.Sort(mapKeys(keys))
} }
var nerr nonFatal
for _, k := range keys { for _, k := range keys {
ki := k.Interface() ki := k.Interface()
vi := m.MapIndex(k).Interface() vi := m.MapIndex(k).Interface()
...@@ -2326,15 +2335,15 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) { ...@@ -2326,15 +2335,15 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
siz := keySizer(kaddr, 1) + valCachedSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1) siz := keySizer(kaddr, 1) + valCachedSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
b = appendVarint(b, uint64(siz)) b = appendVarint(b, uint64(siz))
b, err = keyMarshaler(b, kaddr, keyWireTag, deterministic) b, err = keyMarshaler(b, kaddr, keyWireTag, deterministic)
if err != nil { if !nerr.Merge(err) {
return b, err return b, err
} }
b, err = valMarshaler(b, vaddr, valWireTag, deterministic) b, err = valMarshaler(b, vaddr, valWireTag, deterministic)
if err != nil && err != ErrNil { // allow nil value in map if err != ErrNil && !nerr.Merge(err) { // allow nil value in map
return b, err return b, err
} }
} }
return b, nil return b, nerr.E
} }
} }
...@@ -2407,6 +2416,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de ...@@ -2407,6 +2416,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
defer mu.Unlock() defer mu.Unlock()
var err error var err error
var nerr nonFatal
// Fast-path for common cases: zero or one extensions. // Fast-path for common cases: zero or one extensions.
// Don't bother sorting the keys. // Don't bother sorting the keys.
...@@ -2426,11 +2436,11 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de ...@@ -2426,11 +2436,11 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
v := e.value v := e.value
p := toAddrPointer(&v, ei.isptr) p := toAddrPointer(&v, ei.isptr)
b, err = ei.marshaler(b, p, ei.wiretag, deterministic) b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
if err != nil { if !nerr.Merge(err) {
return b, err return b, err
} }
} }
return b, nil return b, nerr.E
} }
// Sort the keys to provide a deterministic encoding. // Sort the keys to provide a deterministic encoding.
...@@ -2457,11 +2467,11 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de ...@@ -2457,11 +2467,11 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
v := e.value v := e.value
p := toAddrPointer(&v, ei.isptr) p := toAddrPointer(&v, ei.isptr)
b, err = ei.marshaler(b, p, ei.wiretag, deterministic) b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
if err != nil { if !nerr.Merge(err) {
return b, err return b, err
} }
} }
return b, nil return b, nerr.E
} }
// message set format is: // message set format is:
...@@ -2518,6 +2528,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de ...@@ -2518,6 +2528,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
defer mu.Unlock() defer mu.Unlock()
var err error var err error
var nerr nonFatal
// Fast-path for common cases: zero or one extensions. // Fast-path for common cases: zero or one extensions.
// Don't bother sorting the keys. // Don't bother sorting the keys.
...@@ -2544,12 +2555,12 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de ...@@ -2544,12 +2555,12 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
v := e.value v := e.value
p := toAddrPointer(&v, ei.isptr) p := toAddrPointer(&v, ei.isptr)
b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic) b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
if err != nil { if !nerr.Merge(err) {
return b, err return b, err
} }
b = append(b, 1<<3|WireEndGroup) b = append(b, 1<<3|WireEndGroup)
} }
return b, nil return b, nerr.E
} }
// Sort the keys to provide a deterministic encoding. // Sort the keys to provide a deterministic encoding.
...@@ -2583,11 +2594,11 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de ...@@ -2583,11 +2594,11 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
p := toAddrPointer(&v, ei.isptr) p := toAddrPointer(&v, ei.isptr)
b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic) b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
b = append(b, 1<<3|WireEndGroup) b = append(b, 1<<3|WireEndGroup)
if err != nil { if !nerr.Merge(err) {
return b, err return b, err
} }
} }
return b, nil return b, nerr.E
} }
// sizeV1Extensions computes the size of encoded data for a V1-API extension field. // sizeV1Extensions computes the size of encoded data for a V1-API extension field.
...@@ -2630,6 +2641,7 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ ...@@ -2630,6 +2641,7 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
sort.Ints(keys) sort.Ints(keys)
var err error var err error
var nerr nonFatal
for _, k := range keys { for _, k := range keys {
e := m[int32(k)] e := m[int32(k)]
if e.value == nil || e.desc == nil { if e.value == nil || e.desc == nil {
...@@ -2646,11 +2658,11 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ ...@@ -2646,11 +2658,11 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
v := e.value v := e.value
p := toAddrPointer(&v, ei.isptr) p := toAddrPointer(&v, ei.isptr)
b, err = ei.marshaler(b, p, ei.wiretag, deterministic) b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
if err != nil { if !nerr.Merge(err) {
return b, err return b, err
} }
} }
return b, nil return b, nerr.E
} }
// newMarshaler is the interface representing objects that can marshal themselves. // newMarshaler is the interface representing objects that can marshal themselves.
...@@ -2715,11 +2727,6 @@ func Marshal(pb Message) ([]byte, error) { ...@@ -2715,11 +2727,6 @@ func Marshal(pb Message) ([]byte, error) {
// a Buffer for most applications. // a Buffer for most applications.
func (p *Buffer) Marshal(pb Message) error { func (p *Buffer) Marshal(pb Message) error {
var err error var err error
if p.deterministic {
if _, ok := pb.(Marshaler); ok {
return fmt.Errorf("proto: deterministic not supported by the Marshal method of %T", pb)
}
}
if m, ok := pb.(newMarshaler); ok { if m, ok := pb.(newMarshaler); ok {
siz := m.XXX_Size() siz := m.XXX_Size()
p.grow(siz) // make sure buf has enough capacity p.grow(siz) // make sure buf has enough capacity
......
...@@ -138,8 +138,8 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { ...@@ -138,8 +138,8 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
if u.isMessageSet { if u.isMessageSet {
return UnmarshalMessageSet(b, m.offset(u.extensions).toExtensions()) return UnmarshalMessageSet(b, m.offset(u.extensions).toExtensions())
} }
var reqMask uint64 // bitmask of required fields we've seen. var reqMask uint64 // bitmask of required fields we've seen.
var rnse *RequiredNotSetError // an instance of a RequiredNotSetError returned by a submessage. var errLater error
for len(b) > 0 { for len(b) > 0 {
// Read tag and wire type. // Read tag and wire type.
// Special case 1 and 2 byte varints. // Special case 1 and 2 byte varints.
...@@ -178,14 +178,19 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { ...@@ -178,14 +178,19 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
if r, ok := err.(*RequiredNotSetError); ok { if r, ok := err.(*RequiredNotSetError); ok {
// Remember this error, but keep parsing. We need to produce // Remember this error, but keep parsing. We need to produce
// a full parse even if a required field is missing. // a full parse even if a required field is missing.
rnse = r if errLater == nil {
errLater = r
}
reqMask |= f.reqMask reqMask |= f.reqMask
continue continue
} }
if err != errInternalBadWireType { if err != errInternalBadWireType {
if err == errInvalidUTF8 { if err == errInvalidUTF8 {
fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name if errLater == nil {
err = fmt.Errorf("proto: string field %q contains invalid UTF-8", fullName) fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
errLater = &invalidUTF8Error{fullName}
}
continue
} }
return err return err
} }
...@@ -245,20 +250,16 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { ...@@ -245,20 +250,16 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
emap[int32(tag)] = e emap[int32(tag)] = e
} }
} }
if rnse != nil { if reqMask != u.reqMask && errLater == nil {
// A required field of a submessage/group is missing. Return that error.
return rnse
}
if reqMask != u.reqMask {
// A required field of this message is missing. // A required field of this message is missing.
for _, n := range u.reqFields { for _, n := range u.reqFields {
if reqMask&1 == 0 { if reqMask&1 == 0 {
return &RequiredNotSetError{n} errLater = &RequiredNotSetError{n}
} }
reqMask >>= 1 reqMask >>= 1
} }
} }
return nil return errLater
} }
// computeUnmarshalInfo fills in u with information for use // computeUnmarshalInfo fills in u with information for use
...@@ -1529,10 +1530,10 @@ func unmarshalUTF8StringValue(b []byte, f pointer, w int) ([]byte, error) { ...@@ -1529,10 +1530,10 @@ func unmarshalUTF8StringValue(b []byte, f pointer, w int) ([]byte, error) {
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
} }
v := string(b[:x]) v := string(b[:x])
*f.toString() = v
if !utf8.ValidString(v) { if !utf8.ValidString(v) {
return nil, errInvalidUTF8 return b[x:], errInvalidUTF8
} }
*f.toString() = v
return b[x:], nil return b[x:], nil
} }
...@@ -1549,10 +1550,10 @@ func unmarshalUTF8StringPtr(b []byte, f pointer, w int) ([]byte, error) { ...@@ -1549,10 +1550,10 @@ func unmarshalUTF8StringPtr(b []byte, f pointer, w int) ([]byte, error) {
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
} }
v := string(b[:x]) v := string(b[:x])
*f.toStringPtr() = &v
if !utf8.ValidString(v) { if !utf8.ValidString(v) {
return nil, errInvalidUTF8 return b[x:], errInvalidUTF8
} }
*f.toStringPtr() = &v
return b[x:], nil return b[x:], nil
} }
...@@ -1569,11 +1570,11 @@ func unmarshalUTF8StringSlice(b []byte, f pointer, w int) ([]byte, error) { ...@@ -1569,11 +1570,11 @@ func unmarshalUTF8StringSlice(b []byte, f pointer, w int) ([]byte, error) {
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
} }
v := string(b[:x]) v := string(b[:x])
if !utf8.ValidString(v) {
return nil, errInvalidUTF8
}
s := f.toStringSlice() s := f.toStringSlice()
*s = append(*s, v) *s = append(*s, v)
if !utf8.ValidString(v) {
return b[x:], errInvalidUTF8
}
return b[x:], nil return b[x:], nil
} }
...@@ -1755,6 +1756,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler { ...@@ -1755,6 +1756,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
// Maps will be somewhat slow. Oh well. // Maps will be somewhat slow. Oh well.
// Read key and value from data. // Read key and value from data.
var nerr nonFatal
k := reflect.New(kt) k := reflect.New(kt)
v := reflect.New(vt) v := reflect.New(vt)
for len(b) > 0 { for len(b) > 0 {
...@@ -1775,7 +1777,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler { ...@@ -1775,7 +1777,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
err = errInternalBadWireType // skip unknown tag err = errInternalBadWireType // skip unknown tag
} }
if err == nil { if nerr.Merge(err) {
continue continue
} }
if err != errInternalBadWireType { if err != errInternalBadWireType {
...@@ -1798,7 +1800,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler { ...@@ -1798,7 +1800,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
// Insert into map. // Insert into map.
m.SetMapIndex(k.Elem(), v.Elem()) m.SetMapIndex(k.Elem(), v.Elem())
return r, nil return r, nerr.E
} }
} }
...@@ -1824,15 +1826,16 @@ func makeUnmarshalOneof(typ, ityp reflect.Type, unmarshal unmarshaler) unmarshal ...@@ -1824,15 +1826,16 @@ func makeUnmarshalOneof(typ, ityp reflect.Type, unmarshal unmarshaler) unmarshal
// Unmarshal data into holder. // Unmarshal data into holder.
// We unmarshal into the first field of the holder object. // We unmarshal into the first field of the holder object.
var err error var err error
var nerr nonFatal
b, err = unmarshal(b, valToPointer(v).offset(field0), w) b, err = unmarshal(b, valToPointer(v).offset(field0), w)
if err != nil { if !nerr.Merge(err) {
return nil, err return nil, err
} }
// Write pointer to holder into target field. // Write pointer to holder into target field.
f.asPointerTo(ityp).Elem().Set(v) f.asPointerTo(ityp).Elem().Set(v)
return b, nil return b, nerr.E
} }
} }
......
...@@ -130,10 +130,12 @@ func UnmarshalAny(any *any.Any, pb proto.Message) error { ...@@ -130,10 +130,12 @@ func UnmarshalAny(any *any.Any, pb proto.Message) error {
// Is returns true if any value contains a given message type. // Is returns true if any value contains a given message type.
func Is(any *any.Any, pb proto.Message) bool { func Is(any *any.Any, pb proto.Message) bool {
aname, err := AnyMessageName(any) // The following is equivalent to AnyMessageName(any) == proto.MessageName(pb),
if err != nil { // but it avoids scanning TypeUrl for the slash.
if any == nil {
return false return false
} }
name := proto.MessageName(pb)
return aname == proto.MessageName(pb) prefix := len(any.TypeUrl) - len(name)
return prefix >= 1 && any.TypeUrl[prefix-1] == '/' && any.TypeUrl[prefix:] == name
} }
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// source: google/protobuf/any.proto // source: google/protobuf/any.proto
/* package any // import "github.com/golang/protobuf/ptypes/any"
Package any is a generated protocol buffer package.
It is generated from these files:
google/protobuf/any.proto
It has these top-level messages:
Any
*/
package any
import proto "github.com/golang/protobuf/proto" import proto "github.com/golang/protobuf/proto"
import fmt "fmt" import fmt "fmt"
...@@ -130,16 +121,38 @@ type Any struct { ...@@ -130,16 +121,38 @@ type Any struct {
// Schemes other than `http`, `https` (or the empty scheme) might be // Schemes other than `http`, `https` (or the empty scheme) might be
// used with implementation specific semantics. // used with implementation specific semantics.
// //
TypeUrl string `protobuf:"bytes,1,opt,name=type_url,json=typeUrl" json:"type_url,omitempty"` TypeUrl string `protobuf:"bytes,1,opt,name=type_url,json=typeUrl,proto3" json:"type_url,omitempty"`
// Must be a valid serialized protocol buffer of the above specified type. // Must be a valid serialized protocol buffer of the above specified type.
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Any) Reset() { *m = Any{} }
func (m *Any) String() string { return proto.CompactTextString(m) }
func (*Any) ProtoMessage() {}
func (*Any) Descriptor() ([]byte, []int) {
return fileDescriptor_any_744b9ca530f228db, []int{0}
}
func (*Any) XXX_WellKnownType() string { return "Any" }
func (m *Any) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Any.Unmarshal(m, b)
}
func (m *Any) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Any.Marshal(b, m, deterministic)
}
func (dst *Any) XXX_Merge(src proto.Message) {
xxx_messageInfo_Any.Merge(dst, src)
}
func (m *Any) XXX_Size() int {
return xxx_messageInfo_Any.Size(m)
}
func (m *Any) XXX_DiscardUnknown() {
xxx_messageInfo_Any.DiscardUnknown(m)
} }
func (m *Any) Reset() { *m = Any{} } var xxx_messageInfo_Any proto.InternalMessageInfo
func (m *Any) String() string { return proto.CompactTextString(m) }
func (*Any) ProtoMessage() {}
func (*Any) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (*Any) XXX_WellKnownType() string { return "Any" }
func (m *Any) GetTypeUrl() string { func (m *Any) GetTypeUrl() string {
if m != nil { if m != nil {
...@@ -159,9 +172,9 @@ func init() { ...@@ -159,9 +172,9 @@ func init() {
proto.RegisterType((*Any)(nil), "google.protobuf.Any") proto.RegisterType((*Any)(nil), "google.protobuf.Any")
} }
func init() { proto.RegisterFile("google/protobuf/any.proto", fileDescriptor0) } func init() { proto.RegisterFile("google/protobuf/any.proto", fileDescriptor_any_744b9ca530f228db) }
var fileDescriptor0 = []byte{ var fileDescriptor_any_744b9ca530f228db = []byte{
// 185 bytes of a gzipped FileDescriptorProto // 185 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4c, 0xcf, 0xcf, 0x4f, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4c, 0xcf, 0xcf, 0x4f,
0xcf, 0x49, 0xd5, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0x4f, 0x2a, 0x4d, 0xd3, 0x4f, 0xcc, 0xab, 0xd4, 0xcf, 0x49, 0xd5, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0x4f, 0x2a, 0x4d, 0xd3, 0x4f, 0xcc, 0xab, 0xd4,
......
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// source: google/protobuf/duration.proto // source: google/protobuf/duration.proto
/* package duration // import "github.com/golang/protobuf/ptypes/duration"
Package duration is a generated protocol buffer package.
It is generated from these files:
google/protobuf/duration.proto
It has these top-level messages:
Duration
*/
package duration
import proto "github.com/golang/protobuf/proto" import proto "github.com/golang/protobuf/proto"
import fmt "fmt" import fmt "fmt"
...@@ -91,21 +82,43 @@ type Duration struct { ...@@ -91,21 +82,43 @@ type Duration struct {
// Signed seconds of the span of time. Must be from -315,576,000,000 // Signed seconds of the span of time. Must be from -315,576,000,000
// to +315,576,000,000 inclusive. Note: these bounds are computed from: // to +315,576,000,000 inclusive. Note: these bounds are computed from:
// 60 sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years // 60 sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
Seconds int64 `protobuf:"varint,1,opt,name=seconds" json:"seconds,omitempty"` Seconds int64 `protobuf:"varint,1,opt,name=seconds,proto3" json:"seconds,omitempty"`
// Signed fractions of a second at nanosecond resolution of the span // Signed fractions of a second at nanosecond resolution of the span
// of time. Durations less than one second are represented with a 0 // of time. Durations less than one second are represented with a 0
// `seconds` field and a positive or negative `nanos` field. For durations // `seconds` field and a positive or negative `nanos` field. For durations
// of one second or more, a non-zero value for the `nanos` field must be // of one second or more, a non-zero value for the `nanos` field must be
// of the same sign as the `seconds` field. Must be from -999,999,999 // of the same sign as the `seconds` field. Must be from -999,999,999
// to +999,999,999 inclusive. // to +999,999,999 inclusive.
Nanos int32 `protobuf:"varint,2,opt,name=nanos" json:"nanos,omitempty"` Nanos int32 `protobuf:"varint,2,opt,name=nanos,proto3" json:"nanos,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
} }
func (m *Duration) Reset() { *m = Duration{} } func (m *Duration) Reset() { *m = Duration{} }
func (m *Duration) String() string { return proto.CompactTextString(m) } func (m *Duration) String() string { return proto.CompactTextString(m) }
func (*Duration) ProtoMessage() {} func (*Duration) ProtoMessage() {}
func (*Duration) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } func (*Duration) Descriptor() ([]byte, []int) {
func (*Duration) XXX_WellKnownType() string { return "Duration" } return fileDescriptor_duration_e7d612259e3f0613, []int{0}
}
func (*Duration) XXX_WellKnownType() string { return "Duration" }
func (m *Duration) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Duration.Unmarshal(m, b)
}
func (m *Duration) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Duration.Marshal(b, m, deterministic)
}
func (dst *Duration) XXX_Merge(src proto.Message) {
xxx_messageInfo_Duration.Merge(dst, src)
}
func (m *Duration) XXX_Size() int {
return xxx_messageInfo_Duration.Size(m)
}
func (m *Duration) XXX_DiscardUnknown() {
xxx_messageInfo_Duration.DiscardUnknown(m)
}
var xxx_messageInfo_Duration proto.InternalMessageInfo
func (m *Duration) GetSeconds() int64 { func (m *Duration) GetSeconds() int64 {
if m != nil { if m != nil {
...@@ -125,9 +138,11 @@ func init() { ...@@ -125,9 +138,11 @@ func init() {
proto.RegisterType((*Duration)(nil), "google.protobuf.Duration") proto.RegisterType((*Duration)(nil), "google.protobuf.Duration")
} }
func init() { proto.RegisterFile("google/protobuf/duration.proto", fileDescriptor0) } func init() {
proto.RegisterFile("google/protobuf/duration.proto", fileDescriptor_duration_e7d612259e3f0613)
}
var fileDescriptor0 = []byte{ var fileDescriptor_duration_e7d612259e3f0613 = []byte{
// 190 bytes of a gzipped FileDescriptorProto // 190 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4b, 0xcf, 0xcf, 0x4f, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4b, 0xcf, 0xcf, 0x4f,
0xcf, 0x49, 0xd5, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0x4f, 0x2a, 0x4d, 0xd3, 0x4f, 0x29, 0x2d, 0x4a, 0xcf, 0x49, 0xd5, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0x4f, 0x2a, 0x4d, 0xd3, 0x4f, 0x29, 0x2d, 0x4a,
......
#!/bin/bash -e
#
# This script fetches and rebuilds the "well-known types" protocol buffers.
# To run this you will need protoc and goprotobuf installed;
# see https://github.com/golang/protobuf for instructions.
# You also need Go and Git installed.
PKG=github.com/golang/protobuf/ptypes
UPSTREAM=https://github.com/google/protobuf
UPSTREAM_SUBDIR=src/google/protobuf
PROTO_FILES=(any duration empty struct timestamp wrappers)
function die() {
echo 1>&2 $*
exit 1
}
# Sanity check that the right tools are accessible.
for tool in go git protoc protoc-gen-go; do
q=$(which $tool) || die "didn't find $tool"
echo 1>&2 "$tool: $q"
done
tmpdir=$(mktemp -d -t regen-wkt.XXXXXX)
trap 'rm -rf $tmpdir' EXIT
echo -n 1>&2 "finding package dir... "
pkgdir=$(go list -f '{{.Dir}}' $PKG)
echo 1>&2 $pkgdir
base=$(echo $pkgdir | sed "s,/$PKG\$,,")
echo 1>&2 "base: $base"
cd "$base"
echo 1>&2 "fetching latest protos... "
git clone -q $UPSTREAM $tmpdir
for file in ${PROTO_FILES[@]}; do
echo 1>&2 "* $file"
protoc --go_out=. -I$tmpdir/src $tmpdir/src/google/protobuf/$file.proto || die
cp $tmpdir/src/google/protobuf/$file.proto $PKG/$file
done
echo 1>&2 "All OK"
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// source: google/protobuf/timestamp.proto // source: google/protobuf/timestamp.proto
/* package timestamp // import "github.com/golang/protobuf/ptypes/timestamp"
Package timestamp is a generated protocol buffer package.
It is generated from these files:
google/protobuf/timestamp.proto
It has these top-level messages:
Timestamp
*/
package timestamp
import proto "github.com/golang/protobuf/proto" import proto "github.com/golang/protobuf/proto"
import fmt "fmt" import fmt "fmt"
...@@ -101,7 +92,7 @@ const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package ...@@ -101,7 +92,7 @@ const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// to this format using [`strftime`](https://docs.python.org/2/library/time.html#time.strftime) // to this format using [`strftime`](https://docs.python.org/2/library/time.html#time.strftime)
// with the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one // with the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one
// can use the Joda Time's [`ISODateTimeFormat.dateTime()`]( // can use the Joda Time's [`ISODateTimeFormat.dateTime()`](
// http://joda-time.sourceforge.net/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime()) // http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime--)
// to obtain a formatter capable of generating timestamps in this format. // to obtain a formatter capable of generating timestamps in this format.
// //
// //
...@@ -109,19 +100,41 @@ type Timestamp struct { ...@@ -109,19 +100,41 @@ type Timestamp struct {
// Represents seconds of UTC time since Unix epoch // Represents seconds of UTC time since Unix epoch
// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to // 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to
// 9999-12-31T23:59:59Z inclusive. // 9999-12-31T23:59:59Z inclusive.
Seconds int64 `protobuf:"varint,1,opt,name=seconds" json:"seconds,omitempty"` Seconds int64 `protobuf:"varint,1,opt,name=seconds,proto3" json:"seconds,omitempty"`
// Non-negative fractions of a second at nanosecond resolution. Negative // Non-negative fractions of a second at nanosecond resolution. Negative
// second values with fractions must still have non-negative nanos values // second values with fractions must still have non-negative nanos values
// that count forward in time. Must be from 0 to 999,999,999 // that count forward in time. Must be from 0 to 999,999,999
// inclusive. // inclusive.
Nanos int32 `protobuf:"varint,2,opt,name=nanos" json:"nanos,omitempty"` Nanos int32 `protobuf:"varint,2,opt,name=nanos,proto3" json:"nanos,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
} }
func (m *Timestamp) Reset() { *m = Timestamp{} } func (m *Timestamp) Reset() { *m = Timestamp{} }
func (m *Timestamp) String() string { return proto.CompactTextString(m) } func (m *Timestamp) String() string { return proto.CompactTextString(m) }
func (*Timestamp) ProtoMessage() {} func (*Timestamp) ProtoMessage() {}
func (*Timestamp) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } func (*Timestamp) Descriptor() ([]byte, []int) {
func (*Timestamp) XXX_WellKnownType() string { return "Timestamp" } return fileDescriptor_timestamp_b826e8e5fba671a8, []int{0}
}
func (*Timestamp) XXX_WellKnownType() string { return "Timestamp" }
func (m *Timestamp) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Timestamp.Unmarshal(m, b)
}
func (m *Timestamp) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Timestamp.Marshal(b, m, deterministic)
}
func (dst *Timestamp) XXX_Merge(src proto.Message) {
xxx_messageInfo_Timestamp.Merge(dst, src)
}
func (m *Timestamp) XXX_Size() int {
return xxx_messageInfo_Timestamp.Size(m)
}
func (m *Timestamp) XXX_DiscardUnknown() {
xxx_messageInfo_Timestamp.DiscardUnknown(m)
}
var xxx_messageInfo_Timestamp proto.InternalMessageInfo
func (m *Timestamp) GetSeconds() int64 { func (m *Timestamp) GetSeconds() int64 {
if m != nil { if m != nil {
...@@ -141,9 +154,11 @@ func init() { ...@@ -141,9 +154,11 @@ func init() {
proto.RegisterType((*Timestamp)(nil), "google.protobuf.Timestamp") proto.RegisterType((*Timestamp)(nil), "google.protobuf.Timestamp")
} }
func init() { proto.RegisterFile("google/protobuf/timestamp.proto", fileDescriptor0) } func init() {
proto.RegisterFile("google/protobuf/timestamp.proto", fileDescriptor_timestamp_b826e8e5fba671a8)
}
var fileDescriptor0 = []byte{ var fileDescriptor_timestamp_b826e8e5fba671a8 = []byte{
// 191 bytes of a gzipped FileDescriptorProto // 191 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4f, 0xcf, 0xcf, 0x4f, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4f, 0xcf, 0xcf, 0x4f,
0xcf, 0x49, 0xd5, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0x4f, 0x2a, 0x4d, 0xd3, 0x2f, 0xc9, 0xcc, 0x4d, 0xcf, 0x49, 0xd5, 0x2f, 0x28, 0xca, 0x2f, 0xc9, 0x4f, 0x2a, 0x4d, 0xd3, 0x2f, 0xc9, 0xcc, 0x4d,
......
...@@ -114,7 +114,7 @@ option objc_class_prefix = "GPB"; ...@@ -114,7 +114,7 @@ option objc_class_prefix = "GPB";
// to this format using [`strftime`](https://docs.python.org/2/library/time.html#time.strftime) // to this format using [`strftime`](https://docs.python.org/2/library/time.html#time.strftime)
// with the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one // with the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one
// can use the Joda Time's [`ISODateTimeFormat.dateTime()`]( // can use the Joda Time's [`ISODateTimeFormat.dateTime()`](
// http://joda-time.sourceforge.net/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime()) // http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime--)
// to obtain a formatter capable of generating timestamps in this format. // to obtain a formatter capable of generating timestamps in this format.
// //
// //
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// License for the specific language governing permissions and limitations // License for the specific language governing permissions and limitations
// under the License. // under the License.
package internal package internal // import "github.com/gomodule/redigo/internal"
import ( import (
"strings" "strings"
......
...@@ -29,9 +29,12 @@ import ( ...@@ -29,9 +29,12 @@ import (
"time" "time"
) )
var (
_ ConnWithTimeout = (*conn)(nil)
)
// conn is the low-level implementation of Conn // conn is the low-level implementation of Conn
type conn struct { type conn struct {
// Shared // Shared
mu sync.Mutex mu sync.Mutex
pending int pending int
...@@ -73,10 +76,11 @@ type DialOption struct { ...@@ -73,10 +76,11 @@ type DialOption struct {
type dialOptions struct { type dialOptions struct {
readTimeout time.Duration readTimeout time.Duration
writeTimeout time.Duration writeTimeout time.Duration
dialer *net.Dialer
dial func(network, addr string) (net.Conn, error) dial func(network, addr string) (net.Conn, error)
db int db int
password string password string
dialTLS bool useTLS bool
skipVerify bool skipVerify bool
tlsConfig *tls.Config tlsConfig *tls.Config
} }
...@@ -95,17 +99,27 @@ func DialWriteTimeout(d time.Duration) DialOption { ...@@ -95,17 +99,27 @@ func DialWriteTimeout(d time.Duration) DialOption {
}} }}
} }
// DialConnectTimeout specifies the timeout for connecting to the Redis server. // DialConnectTimeout specifies the timeout for connecting to the Redis server when
// no DialNetDial option is specified.
func DialConnectTimeout(d time.Duration) DialOption { func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) { return DialOption{func(do *dialOptions) {
dialer := net.Dialer{Timeout: d} do.dialer.Timeout = d
do.dial = dialer.Dial }}
}
// DialKeepAlive specifies the keep-alive period for TCP connections to the Redis server
// when no DialNetDial option is specified.
// If zero, keep-alives are not enabled. If no DialKeepAlive option is specified then
// the default of 5 minutes is used to ensure that half-closed TCP sessions are detected.
func DialKeepAlive(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.dialer.KeepAlive = d
}} }}
} }
// DialNetDial specifies a custom dial function for creating TCP // DialNetDial specifies a custom dial function for creating TCP
// connections. If this option is left out, then net.Dial is // connections, otherwise a net.Dialer customized via the other options is used.
// used. DialNetDial overrides DialConnectTimeout. // DialNetDial overrides DialConnectTimeout and DialKeepAlive.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption { func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) { return DialOption{func(do *dialOptions) {
do.dial = dial do.dial = dial
...@@ -128,38 +142,56 @@ func DialPassword(password string) DialOption { ...@@ -128,38 +142,56 @@ func DialPassword(password string) DialOption {
} }
// DialTLSConfig specifies the config to use when a TLS connection is dialed. // DialTLSConfig specifies the config to use when a TLS connection is dialed.
// Has no effect when not dialing a TLS connection. // Has no effect when not dialing a TLS connection.
func DialTLSConfig(c *tls.Config) DialOption { func DialTLSConfig(c *tls.Config) DialOption {
return DialOption{func(do *dialOptions) { return DialOption{func(do *dialOptions) {
do.tlsConfig = c do.tlsConfig = c
}} }}
} }
// DialTLSSkipVerify to disable server name verification when connecting // DialTLSSkipVerify disables server name verification when connecting over
// over TLS. Has no effect when not dialing a TLS connection. // TLS. Has no effect when not dialing a TLS connection.
func DialTLSSkipVerify(skip bool) DialOption { func DialTLSSkipVerify(skip bool) DialOption {
return DialOption{func(do *dialOptions) { return DialOption{func(do *dialOptions) {
do.skipVerify = skip do.skipVerify = skip
}} }}
} }
// DialUseTLS specifies whether TLS should be used when connecting to the
// server. This option is ignore by DialURL.
func DialUseTLS(useTLS bool) DialOption {
return DialOption{func(do *dialOptions) {
do.useTLS = useTLS
}}
}
// Dial connects to the Redis server at the given network and // Dial connects to the Redis server at the given network and
// address using the specified options. // address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) { func Dial(network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{ do := dialOptions{
dial: net.Dial, dialer: &net.Dialer{
KeepAlive: time.Minute * 5,
},
} }
for _, option := range options { for _, option := range options {
option.f(&do) option.f(&do)
} }
if do.dial == nil {
do.dial = do.dialer.Dial
}
netConn, err := do.dial(network, address) netConn, err := do.dial(network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if do.dialTLS { if do.useTLS {
tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify) var tlsConfig *tls.Config
if do.tlsConfig == nil {
tlsConfig = &tls.Config{InsecureSkipVerify: do.skipVerify}
} else {
tlsConfig = cloneTLSConfig(do.tlsConfig)
}
if tlsConfig.ServerName == "" { if tlsConfig.ServerName == "" {
host, _, err := net.SplitHostPort(address) host, _, err := net.SplitHostPort(address)
if err != nil { if err != nil {
...@@ -202,10 +234,6 @@ func Dial(network, address string, options ...DialOption) (Conn, error) { ...@@ -202,10 +234,6 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
return c, nil return c, nil
} }
func dialTLS(do *dialOptions) {
do.dialTLS = true
}
var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`) var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
// DialURL connects to a Redis server at the given URL using the Redis // DialURL connects to a Redis server at the given URL using the Redis
...@@ -257,9 +285,7 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) { ...@@ -257,9 +285,7 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
return nil, fmt.Errorf("invalid database: %s", u.Path[1:]) return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
} }
if u.Scheme == "rediss" { options = append(options, DialUseTLS(u.Scheme == "rediss"))
options = append([]DialOption{{dialTLS}}, options...)
}
return Dial("tcp", address, options...) return Dial("tcp", address, options...)
} }
...@@ -344,39 +370,55 @@ func (c *conn) writeFloat64(n float64) error { ...@@ -344,39 +370,55 @@ func (c *conn) writeFloat64(n float64) error {
return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)) return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
} }
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { func (c *conn) writeCommand(cmd string, args []interface{}) error {
c.writeLen('*', 1+len(args)) c.writeLen('*', 1+len(args))
err = c.writeString(cmd) if err := c.writeString(cmd); err != nil {
return err
}
for _, arg := range args { for _, arg := range args {
if err != nil { if err := c.writeArg(arg, true); err != nil {
break return err
} }
switch arg := arg.(type) { }
case string: return nil
err = c.writeString(arg) }
case []byte:
err = c.writeBytes(arg) func (c *conn) writeArg(arg interface{}, argumentTypeOK bool) (err error) {
case int: switch arg := arg.(type) {
err = c.writeInt64(int64(arg)) case string:
case int64: return c.writeString(arg)
err = c.writeInt64(arg) case []byte:
case float64: return c.writeBytes(arg)
err = c.writeFloat64(arg) case int:
case bool: return c.writeInt64(int64(arg))
if arg { case int64:
err = c.writeString("1") return c.writeInt64(arg)
} else { case float64:
err = c.writeString("0") return c.writeFloat64(arg)
} case bool:
case nil: if arg {
err = c.writeString("") return c.writeString("1")
default: } else {
var buf bytes.Buffer return c.writeString("0")
fmt.Fprint(&buf, arg) }
err = c.writeBytes(buf.Bytes()) case nil:
return c.writeString("")
case Argument:
if argumentTypeOK {
return c.writeArg(arg.RedisArg(), false)
} }
// See comment in default clause below.
var buf bytes.Buffer
fmt.Fprint(&buf, arg)
return c.writeBytes(buf.Bytes())
default:
// This default clause is intended to handle builtin numeric types.
// The function should return an error for other types, but this is not
// done for compatibility with previous versions of the package.
var buf bytes.Buffer
fmt.Fprint(&buf, arg)
return c.writeBytes(buf.Bytes())
} }
return err
} }
type protocolError string type protocolError string
...@@ -538,10 +580,17 @@ func (c *conn) Flush() error { ...@@ -538,10 +580,17 @@ func (c *conn) Flush() error {
return nil return nil
} }
func (c *conn) Receive() (reply interface{}, err error) { func (c *conn) Receive() (interface{}, error) {
if c.readTimeout != 0 { return c.ReceiveWithTimeout(c.readTimeout)
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) }
func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
var deadline time.Time
if timeout != 0 {
deadline = time.Now().Add(timeout)
} }
c.conn.SetReadDeadline(deadline)
if reply, err = c.readReply(); err != nil { if reply, err = c.readReply(); err != nil {
return nil, c.fatal(err) return nil, c.fatal(err)
} }
...@@ -564,6 +613,10 @@ func (c *conn) Receive() (reply interface{}, err error) { ...@@ -564,6 +613,10 @@ func (c *conn) Receive() (reply interface{}, err error) {
} }
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return c.DoWithTimeout(c.readTimeout, cmd, args...)
}
func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock() c.mu.Lock()
pending := c.pending pending := c.pending
c.pending = 0 c.pending = 0
...@@ -587,9 +640,11 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { ...@@ -587,9 +640,11 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return nil, c.fatal(err) return nil, c.fatal(err)
} }
if c.readTimeout != 0 { var deadline time.Time
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) if readTimeout != 0 {
deadline = time.Now().Add(readTimeout)
} }
c.conn.SetReadDeadline(deadline)
if cmd == "" { if cmd == "" {
reply := make([]interface{}, pending) reply := make([]interface{}, pending)
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
// Package redis is a client for the Redis database. // Package redis is a client for the Redis database.
// //
// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more // The Redigo FAQ (https://github.com/gomodule/redigo/wiki/FAQ) contains more
// documentation about this package. // documentation about this package.
// //
// Connections // Connections
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
// //
// n, err := conn.Do("APPEND", "key", "value") // n, err := conn.Do("APPEND", "key", "value")
// //
// The Do method converts command arguments to binary strings for transmission // The Do method converts command arguments to bulk strings for transmission
// to the server as follows: // to the server as follows:
// //
// Go Type Conversion // Go Type Conversion
...@@ -48,7 +48,7 @@ ...@@ -48,7 +48,7 @@
// float64 strconv.FormatFloat(v, 'g', -1, 64) // float64 strconv.FormatFloat(v, 'g', -1, 64)
// bool true -> "1", false -> "0" // bool true -> "1", false -> "0"
// nil "" // nil ""
// all other types fmt.Print(v) // all other types fmt.Fprint(w, v)
// //
// Redis command reply types are represented using the following Go types: // Redis command reply types are represented using the following Go types:
// //
...@@ -174,4 +174,4 @@ ...@@ -174,4 +174,4 @@
// non-recoverable error such as a network error or protocol parsing error. If // non-recoverable error such as a network error or protocol parsing error. If
// Err() returns a non-nil value, then the connection is not usable and should // Err() returns a non-nil value, then the connection is not usable and should
// be closed. // be closed.
package redis package redis // import "github.com/gomodule/redigo/redis"
...@@ -4,11 +4,7 @@ package redis ...@@ -4,11 +4,7 @@ package redis
import "crypto/tls" import "crypto/tls"
// similar cloneTLSClientConfig in the stdlib, but also honor skipVerify for the nil case func cloneTLSConfig(cfg *tls.Config) *tls.Config {
func cloneTLSClientConfig(cfg *tls.Config, skipVerify bool) *tls.Config {
if cfg == nil {
return &tls.Config{InsecureSkipVerify: skipVerify}
}
return &tls.Config{ return &tls.Config{
Rand: cfg.Rand, Rand: cfg.Rand,
Time: cfg.Time, Time: cfg.Time,
......
// +build go1.7 // +build go1.7,!go1.8
package redis package redis
import "crypto/tls" import "crypto/tls"
// similar cloneTLSClientConfig in the stdlib, but also honor skipVerify for the nil case func cloneTLSConfig(cfg *tls.Config) *tls.Config {
func cloneTLSClientConfig(cfg *tls.Config, skipVerify bool) *tls.Config {
if cfg == nil {
return &tls.Config{InsecureSkipVerify: skipVerify}
}
return &tls.Config{ return &tls.Config{
Rand: cfg.Rand, Rand: cfg.Rand,
Time: cfg.Time, Time: cfg.Time,
......
// +build go1.8
package redis
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
return cfg.Clone()
}
...@@ -18,6 +18,11 @@ import ( ...@@ -18,6 +18,11 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"log" "log"
"time"
)
var (
_ ConnWithTimeout = (*loggingConn)(nil)
) )
// NewLoggingConn returns a logging wrapper around a connection. // NewLoggingConn returns a logging wrapper around a connection.
...@@ -104,6 +109,12 @@ func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, ...@@ -104,6 +109,12 @@ func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{},
return reply, err return reply, err
} }
func (c *loggingConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (interface{}, error) {
reply, err := DoWithTimeout(c.Conn, timeout, commandName, args...)
c.print("DoWithTimeout", commandName, args, reply, err)
return reply, err
}
func (c *loggingConn) Send(commandName string, args ...interface{}) error { func (c *loggingConn) Send(commandName string, args ...interface{}) error {
err := c.Conn.Send(commandName, args...) err := c.Conn.Send(commandName, args...)
c.print("Send", commandName, args, nil, err) c.print("Send", commandName, args, nil, err)
...@@ -115,3 +126,9 @@ func (c *loggingConn) Receive() (interface{}, error) { ...@@ -115,3 +126,9 @@ func (c *loggingConn) Receive() (interface{}, error) {
c.print("Receive", "", nil, reply, err) c.print("Receive", "", nil, reply, err)
return reply, err return reply, err
} }
func (c *loggingConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
reply, err := ReceiveWithTimeout(c.Conn, timeout)
c.print("ReceiveWithTimeout", "", nil, reply, err)
return reply, err
}
// Copyright 2012 Gary Burd // Copyright 2018 Gary Burd
// //
// Licensed under the Apache License, Version 2.0 (the "License"): you may // Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain // not use this file except in compliance with the License. You may obtain
...@@ -12,30 +12,24 @@ ...@@ -12,30 +12,24 @@
// License for the specific language governing permissions and limitations // License for the specific language governing permissions and limitations
// under the License. // under the License.
package redis // +build go1.7
// Error represents an error returned in a command reply.
type Error string
func (err Error) Error() string { return string(err) }
// Conn represents a connection to a Redis server.
type Conn interface {
// Close closes the connection.
Close() error
// Err returns a non-nil value when the connection is not usable.
Err() error
// Do sends a command to the server and returns the received reply. package redis
Do(commandName string, args ...interface{}) (reply interface{}, err error)
// Send writes the command to the client's output buffer.
Send(commandName string, args ...interface{}) error
// Flush flushes the output buffer to the Redis server. import "context"
Flush() error
// Receive receives a single reply from the Redis server // GetContext gets a connection using the provided context.
Receive() (reply interface{}, err error) //
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Any expiration on the context
// will not affect the returned connection.
//
// If the function completes without error, then the application must close the
// returned connection.
func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
pc, err := p.get(ctx)
if err != nil {
return errorConn{err}, err
}
return &activeConn{p: p, pc: pc}, nil
} }
...@@ -14,11 +14,13 @@ ...@@ -14,11 +14,13 @@
package redis package redis
import "errors" import (
"errors"
"time"
)
// Subscription represents a subscribe or unsubscribe notification. // Subscription represents a subscribe or unsubscribe notification.
type Subscription struct { type Subscription struct {
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe" // Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
Kind string Kind string
...@@ -31,23 +33,12 @@ type Subscription struct { ...@@ -31,23 +33,12 @@ type Subscription struct {
// Message represents a message notification. // Message represents a message notification.
type Message struct { type Message struct {
// The originating channel. // The originating channel.
Channel string Channel string
// The message data. // The matched pattern, if any
Data []byte
}
// PMessage represents a pmessage notification.
type PMessage struct {
// The matched pattern.
Pattern string Pattern string
// The originating channel.
Channel string
// The message data. // The message data.
Data []byte Data []byte
} }
...@@ -94,16 +85,29 @@ func (c PubSubConn) PUnsubscribe(channel ...interface{}) error { ...@@ -94,16 +85,29 @@ func (c PubSubConn) PUnsubscribe(channel ...interface{}) error {
} }
// Ping sends a PING to the server with the specified data. // Ping sends a PING to the server with the specified data.
//
// The connection must be subscribed to at least one channel or pattern when
// calling this method.
func (c PubSubConn) Ping(data string) error { func (c PubSubConn) Ping(data string) error {
c.Conn.Send("PING", data) c.Conn.Send("PING", data)
return c.Conn.Flush() return c.Conn.Flush()
} }
// Receive returns a pushed message as a Subscription, Message, PMessage, Pong // Receive returns a pushed message as a Subscription, Message, Pong or error.
// or error. The return value is intended to be used directly in a type switch // The return value is intended to be used directly in a type switch as
// as illustrated in the PubSubConn example. // illustrated in the PubSubConn example.
func (c PubSubConn) Receive() interface{} { func (c PubSubConn) Receive() interface{} {
reply, err := Values(c.Conn.Receive()) return c.receiveInternal(c.Conn.Receive())
}
// ReceiveWithTimeout is like Receive, but it allows the application to
// override the connection's default timeout.
func (c PubSubConn) ReceiveWithTimeout(timeout time.Duration) interface{} {
return c.receiveInternal(ReceiveWithTimeout(c.Conn, timeout))
}
func (c PubSubConn) receiveInternal(replyArg interface{}, errArg error) interface{} {
reply, err := Values(replyArg, errArg)
if err != nil { if err != nil {
return err return err
} }
...@@ -122,11 +126,11 @@ func (c PubSubConn) Receive() interface{} { ...@@ -122,11 +126,11 @@ func (c PubSubConn) Receive() interface{} {
} }
return m return m
case "pmessage": case "pmessage":
var pm PMessage var m Message
if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil { if _, err := Scan(reply, &m.Pattern, &m.Channel, &m.Data); err != nil {
return err return err
} }
return pm return m
case "subscribe", "psubscribe", "unsubscribe", "punsubscribe": case "subscribe", "psubscribe", "unsubscribe", "punsubscribe":
s := Subscription{Kind: kind} s := Subscription{Kind: kind}
if _, err := Scan(reply, &s.Channel, &s.Count); err != nil { if _, err := Scan(reply, &s.Channel, &s.Count); err != nil {
......
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
"time"
)
// Error represents an error returned in a command reply.
type Error string
func (err Error) Error() string { return string(err) }
// Conn represents a connection to a Redis server.
type Conn interface {
// Close closes the connection.
Close() error
// Err returns a non-nil value when the connection is not usable.
Err() error
// Do sends a command to the server and returns the received reply.
Do(commandName string, args ...interface{}) (reply interface{}, err error)
// Send writes the command to the client's output buffer.
Send(commandName string, args ...interface{}) error
// Flush flushes the output buffer to the Redis server.
Flush() error
// Receive receives a single reply from the Redis server
Receive() (reply interface{}, err error)
}
// Argument is the interface implemented by an object which wants to control how
// the object is converted to Redis bulk strings.
type Argument interface {
// RedisArg returns a value to be encoded as a bulk string per the
// conversions listed in the section 'Executing Commands'.
// Implementations should typically return a []byte or string.
RedisArg() interface{}
}
// Scanner is implemented by an object which wants to control its value is
// interpreted when read from Redis.
type Scanner interface {
// RedisScan assigns a value from a Redis value. The argument src is one of
// the reply types listed in the section `Executing Commands`.
//
// An error should be returned if the value cannot be stored without
// loss of information.
RedisScan(src interface{}) error
}
// ConnWithTimeout is an optional interface that allows the caller to override
// a connection's default read timeout. This interface is useful for executing
// the BLPOP, BRPOP, BRPOPLPUSH, XREAD and other commands that block at the
// server.
//
// A connection's default read timeout is set with the DialReadTimeout dial
// option. Applications should rely on the default timeout for commands that do
// not block at the server.
//
// All of the Conn implementations in this package satisfy the ConnWithTimeout
// interface.
//
// Use the DoWithTimeout and ReceiveWithTimeout helper functions to simplify
// use of this interface.
type ConnWithTimeout interface {
Conn
// Do sends a command to the server and returns the received reply.
// The timeout overrides the read timeout set when dialing the
// connection.
DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error)
// Receive receives a single reply from the Redis server. The timeout
// overrides the read timeout set when dialing the connection.
ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error)
}
var errTimeoutNotSupported = errors.New("redis: connection does not support ConnWithTimeout")
// DoWithTimeout executes a Redis command with the specified read timeout. If
// the connection does not satisfy the ConnWithTimeout interface, then an error
// is returned.
func DoWithTimeout(c Conn, timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
cwt, ok := c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
return cwt.DoWithTimeout(timeout, cmd, args...)
}
// ReceiveWithTimeout receives a reply with the specified read timeout. If the
// connection does not satisfy the ConnWithTimeout interface, then an error is
// returned.
func ReceiveWithTimeout(c Conn, timeout time.Duration) (interface{}, error) {
cwt, ok := c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
return cwt.ReceiveWithTimeout(timeout)
}
...@@ -243,34 +243,67 @@ func Values(reply interface{}, err error) ([]interface{}, error) { ...@@ -243,34 +243,67 @@ func Values(reply interface{}, err error) ([]interface{}, error) {
return nil, fmt.Errorf("redigo: unexpected type for Values, got type %T", reply) return nil, fmt.Errorf("redigo: unexpected type for Values, got type %T", reply)
} }
// Strings is a helper that converts an array command reply to a []string. If func sliceHelper(reply interface{}, err error, name string, makeSlice func(int), assign func(int, interface{}) error) error {
// err is not equal to nil, then Strings returns nil, err. Nil array items are
// converted to "" in the output slice. Strings returns an error if an array
// item is not a bulk string or nil.
func Strings(reply interface{}, err error) ([]string, error) {
if err != nil { if err != nil {
return nil, err return err
} }
switch reply := reply.(type) { switch reply := reply.(type) {
case []interface{}: case []interface{}:
result := make([]string, len(reply)) makeSlice(len(reply))
for i := range reply { for i := range reply {
if reply[i] == nil { if reply[i] == nil {
continue continue
} }
p, ok := reply[i].([]byte) if err := assign(i, reply[i]); err != nil {
if !ok { return err
return nil, fmt.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i])
} }
result[i] = string(p)
} }
return result, nil return nil
case nil: case nil:
return nil, ErrNil return ErrNil
case Error: case Error:
return nil, reply return reply
} }
return nil, fmt.Errorf("redigo: unexpected type for Strings, got type %T", reply) return fmt.Errorf("redigo: unexpected type for %s, got type %T", name, reply)
}
// Float64s is a helper that converts an array command reply to a []float64. If
// err is not equal to nil, then Float64s returns nil, err. Nil array items are
// converted to 0 in the output slice. Floats64 returns an error if an array
// item is not a bulk string or nil.
func Float64s(reply interface{}, err error) ([]float64, error) {
var result []float64
err = sliceHelper(reply, err, "Float64s", func(n int) { result = make([]float64, n) }, func(i int, v interface{}) error {
p, ok := v.([]byte)
if !ok {
return fmt.Errorf("redigo: unexpected element type for Floats64, got type %T", v)
}
f, err := strconv.ParseFloat(string(p), 64)
result[i] = f
return err
})
return result, err
}
// Strings is a helper that converts an array command reply to a []string. If
// err is not equal to nil, then Strings returns nil, err. Nil array items are
// converted to "" in the output slice. Strings returns an error if an array
// item is not a bulk string or nil.
func Strings(reply interface{}, err error) ([]string, error) {
var result []string
err = sliceHelper(reply, err, "Strings", func(n int) { result = make([]string, n) }, func(i int, v interface{}) error {
switch v := v.(type) {
case string:
result[i] = v
return nil
case []byte:
result[i] = string(v)
return nil
default:
return fmt.Errorf("redigo: unexpected element type for Strings, got type %T", v)
}
})
return result, err
} }
// ByteSlices is a helper that converts an array command reply to a [][]byte. // ByteSlices is a helper that converts an array command reply to a [][]byte.
...@@ -278,43 +311,64 @@ func Strings(reply interface{}, err error) ([]string, error) { ...@@ -278,43 +311,64 @@ func Strings(reply interface{}, err error) ([]string, error) {
// items are stay nil. ByteSlices returns an error if an array item is not a // items are stay nil. ByteSlices returns an error if an array item is not a
// bulk string or nil. // bulk string or nil.
func ByteSlices(reply interface{}, err error) ([][]byte, error) { func ByteSlices(reply interface{}, err error) ([][]byte, error) {
if err != nil { var result [][]byte
return nil, err err = sliceHelper(reply, err, "ByteSlices", func(n int) { result = make([][]byte, n) }, func(i int, v interface{}) error {
} p, ok := v.([]byte)
switch reply := reply.(type) { if !ok {
case []interface{}: return fmt.Errorf("redigo: unexpected element type for ByteSlices, got type %T", v)
result := make([][]byte, len(reply))
for i := range reply {
if reply[i] == nil {
continue
}
p, ok := reply[i].([]byte)
if !ok {
return nil, fmt.Errorf("redigo: unexpected element type for ByteSlices, got type %T", reply[i])
}
result[i] = p
} }
return result, nil result[i] = p
case nil: return nil
return nil, ErrNil })
case Error: return result, err
return nil, reply }
}
return nil, fmt.Errorf("redigo: unexpected type for ByteSlices, got type %T", reply) // Int64s is a helper that converts an array command reply to a []int64.
// If err is not equal to nil, then Int64s returns nil, err. Nil array
// items are stay nil. Int64s returns an error if an array item is not a
// bulk string or nil.
func Int64s(reply interface{}, err error) ([]int64, error) {
var result []int64
err = sliceHelper(reply, err, "Int64s", func(n int) { result = make([]int64, n) }, func(i int, v interface{}) error {
switch v := v.(type) {
case int64:
result[i] = v
return nil
case []byte:
n, err := strconv.ParseInt(string(v), 10, 64)
result[i] = n
return err
default:
return fmt.Errorf("redigo: unexpected element type for Int64s, got type %T", v)
}
})
return result, err
} }
// Ints is a helper that converts an array command reply to a []int. If // Ints is a helper that converts an array command reply to a []in.
// err is not equal to nil, then Ints returns nil, err. // If err is not equal to nil, then Ints returns nil, err. Nil array
// items are stay nil. Ints returns an error if an array item is not a
// bulk string or nil.
func Ints(reply interface{}, err error) ([]int, error) { func Ints(reply interface{}, err error) ([]int, error) {
var ints []int var result []int
values, err := Values(reply, err) err = sliceHelper(reply, err, "Ints", func(n int) { result = make([]int, n) }, func(i int, v interface{}) error {
if err != nil { switch v := v.(type) {
return ints, err case int64:
} n := int(v)
if err := ScanSlice(values, &ints); err != nil { if int64(n) != v {
return ints, err return strconv.ErrRange
} }
return ints, nil result[i] = n
return nil
case []byte:
n, err := strconv.Atoi(string(v))
result[i] = n
return err
default:
return fmt.Errorf("redigo: unexpected element type for Ints, got type %T", v)
}
})
return result, err
} }
// StringMap is a helper that converts an array of strings (alternating key, value) // StringMap is a helper that converts an array of strings (alternating key, value)
...@@ -333,7 +387,7 @@ func StringMap(result interface{}, err error) (map[string]string, error) { ...@@ -333,7 +387,7 @@ func StringMap(result interface{}, err error) (map[string]string, error) {
key, okKey := values[i].([]byte) key, okKey := values[i].([]byte)
value, okValue := values[i+1].([]byte) value, okValue := values[i+1].([]byte)
if !okKey || !okValue { if !okKey || !okValue {
return nil, errors.New("redigo: ScanMap key not a bulk string value") return nil, errors.New("redigo: StringMap key not a bulk string value")
} }
m[string(key)] = string(value) m[string(key)] = string(value)
} }
...@@ -355,7 +409,7 @@ func IntMap(result interface{}, err error) (map[string]int, error) { ...@@ -355,7 +409,7 @@ func IntMap(result interface{}, err error) (map[string]int, error) {
for i := 0; i < len(values); i += 2 { for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte) key, ok := values[i].([]byte)
if !ok { if !ok {
return nil, errors.New("redigo: ScanMap key not a bulk string value") return nil, errors.New("redigo: IntMap key not a bulk string value")
} }
value, err := Int(values[i+1], nil) value, err := Int(values[i+1], nil)
if err != nil { if err != nil {
...@@ -381,7 +435,7 @@ func Int64Map(result interface{}, err error) (map[string]int64, error) { ...@@ -381,7 +435,7 @@ func Int64Map(result interface{}, err error) (map[string]int64, error) {
for i := 0; i < len(values); i += 2 { for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte) key, ok := values[i].([]byte)
if !ok { if !ok {
return nil, errors.New("redigo: ScanMap key not a bulk string value") return nil, errors.New("redigo: Int64Map key not a bulk string value")
} }
value, err := Int64(values[i+1], nil) value, err := Int64(values[i+1], nil)
if err != nil { if err != nil {
...@@ -391,3 +445,35 @@ func Int64Map(result interface{}, err error) (map[string]int64, error) { ...@@ -391,3 +445,35 @@ func Int64Map(result interface{}, err error) (map[string]int64, error) {
} }
return m, nil return m, nil
} }
// Positions is a helper that converts an array of positions (lat, long)
// into a [][2]float64. The GEOPOS command returns replies in this format.
func Positions(result interface{}, err error) ([]*[2]float64, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
positions := make([]*[2]float64, len(values))
for i := range values {
if values[i] == nil {
continue
}
p, ok := values[i].([]interface{})
if !ok {
return nil, fmt.Errorf("redigo: unexpected element type for interface slice, got type %T", values[i])
}
if len(p) != 2 {
return nil, fmt.Errorf("redigo: unexpected number of values for a member position, got %d", len(p))
}
lat, err := Float64(p[0], nil)
if err != nil {
return nil, err
}
long, err := Float64(p[1], nil)
if err != nil {
return nil, err
}
positions[i] = &[2]float64{lat, long}
}
return positions, nil
}
...@@ -110,6 +110,25 @@ func convertAssignInt(d reflect.Value, s int64) (err error) { ...@@ -110,6 +110,25 @@ func convertAssignInt(d reflect.Value, s int64) (err error) {
} }
func convertAssignValue(d reflect.Value, s interface{}) (err error) { func convertAssignValue(d reflect.Value, s interface{}) (err error) {
if d.Kind() != reflect.Ptr {
if d.CanAddr() {
d2 := d.Addr()
if d2.CanInterface() {
if scanner, ok := d2.Interface().(Scanner); ok {
return scanner.RedisScan(s)
}
}
}
} else if d.CanInterface() {
// Already a reflect.Ptr
if d.IsNil() {
d.Set(reflect.New(d.Type().Elem()))
}
if scanner, ok := d.Interface().(Scanner); ok {
return scanner.RedisScan(s)
}
}
switch s := s.(type) { switch s := s.(type) {
case []byte: case []byte:
err = convertAssignBulkString(d, s) err = convertAssignBulkString(d, s)
...@@ -135,11 +154,15 @@ func convertAssignArray(d reflect.Value, s []interface{}) error { ...@@ -135,11 +154,15 @@ func convertAssignArray(d reflect.Value, s []interface{}) error {
} }
func convertAssign(d interface{}, s interface{}) (err error) { func convertAssign(d interface{}, s interface{}) (err error) {
if scanner, ok := d.(Scanner); ok {
return scanner.RedisScan(s)
}
// Handle the most common destination types using type switches and // Handle the most common destination types using type switches and
// fall back to reflection for all other types. // fall back to reflection for all other types.
switch s := s.(type) { switch s := s.(type) {
case nil: case nil:
// ingore // ignore
case []byte: case []byte:
switch d := d.(type) { switch d := d.(type) {
case *string: case *string:
...@@ -186,7 +209,11 @@ func convertAssign(d interface{}, s interface{}) (err error) { ...@@ -186,7 +209,11 @@ func convertAssign(d interface{}, s interface{}) (err error) {
case string: case string:
switch d := d.(type) { switch d := d.(type) {
case *string: case *string:
*d = string(s) *d = s
case *interface{}:
*d = s
case nil:
// skip value
default: default:
err = cannotConvert(reflect.ValueOf(d), s) err = cannotConvert(reflect.ValueOf(d), s)
} }
...@@ -215,6 +242,8 @@ func convertAssign(d interface{}, s interface{}) (err error) { ...@@ -215,6 +242,8 @@ func convertAssign(d interface{}, s interface{}) (err error) {
// Scan copies from src to the values pointed at by dest. // Scan copies from src to the values pointed at by dest.
// //
// Scan uses RedisScan if available otherwise:
//
// The values pointed at by dest must be an integer, float, boolean, string, // The values pointed at by dest must be an integer, float, boolean, string,
// []byte, interface{} or slices of these types. Scan uses the standard strconv // []byte, interface{} or slices of these types. Scan uses the standard strconv
// package to convert bulk strings to numeric and boolean types. // package to convert bulk strings to numeric and boolean types.
...@@ -355,6 +384,7 @@ var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil po ...@@ -355,6 +384,7 @@ var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil po
// //
// Fields with the tag redis:"-" are ignored. // Fields with the tag redis:"-" are ignored.
// //
// Each field uses RedisScan if available otherwise:
// Integer, float, boolean, string and []byte fields are supported. Scan uses the // Integer, float, boolean, string and []byte fields are supported. Scan uses the
// standard strconv package to convert bulk string values to numeric and // standard strconv package to convert bulk string values to numeric and
// boolean types. // boolean types.
......
...@@ -55,6 +55,11 @@ func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} { ...@@ -55,6 +55,11 @@ func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} {
return args return args
} }
// Hash returns the script hash.
func (s *Script) Hash() string {
return s.hash
}
// Do evaluates the script. Under the covers, Do optimistically evaluates the // Do evaluates the script. Under the covers, Do optimistically evaluates the
// script using the EVALSHA command. If the command fails because the script is // script using the EVALSHA command. If the command fails because the script is
// not loaded, then Do evaluates the script using the EVAL command (thus // not loaded, then Do evaluates the script using the EVAL command (thus
......
...@@ -4,5 +4,6 @@ ...@@ -4,5 +4,6 @@
# Please keep the list sorted. # Please keep the list sorted.
Gary Burd <gary@beagledreams.com> Gary Burd <gary@beagledreams.com>
Google LLC (https://opensource.google.com/)
Joachim Bauch <mail@joachim-bauch.de> Joachim Bauch <mail@joachim-bauch.de>
...@@ -51,7 +51,7 @@ subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn ...@@ -51,7 +51,7 @@ subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr> <tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table> </table>
Notes: Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). 1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing 2. The application can get the type of a received data message by implementing
......
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
import "crypto/tls"
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// +build appengine
package websocket
func maskBytes(key [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment