Commit c9b90c9d authored by Rob Pike's avatar Rob Pike

gob: protect against pure recursive types.

There are further changes required for things like
recursive map types.  Recursive struct types work
but the mechanism needs generalization.  The
case handled in this CL is pathological since it
cannot be represented at all by gob, so it should
be handled separately. (Prior to this CL, encode
would recur forever.)

R=rsc
CC=golang-dev
https://golang.org/cl/4206041
parent da8e6eec
......@@ -973,17 +973,31 @@ func TestIgnoredFields(t *testing.T) {
}
}
func TestBadRecursiveType(t *testing.T) {
type Rec ***Rec
var rec Rec
b := new(bytes.Buffer)
err := NewEncoder(b).Encode(&rec)
if err == nil {
t.Error("expected error; got none")
} else if strings.Index(err.String(), "recursive") < 0 {
t.Error("expected recursive type error; got", err)
}
// Can't test decode easily because we can't encode one, so we can't pass one to a Decoder.
}
type Bad0 struct {
ch chan int
c float64
CH chan int
C float64
}
var nilEncoder *Encoder
func TestInvalidField(t *testing.T) {
var bad0 Bad0
bad0.ch = make(chan int)
bad0.CH = make(chan int)
b := new(bytes.Buffer)
var nilEncoder *Encoder
err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
if err == nil {
t.Error("expected error; got none")
......
......@@ -410,7 +410,6 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
}
func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) {
defer catchError(&err)
p = allocate(ut.base, p, ut.indir)
state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField
......@@ -433,7 +432,6 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr)
// This state cannot arise for decodeSingle, which is called directly
// from the user's value, not from the innards of an engine.
func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) {
defer catchError(&err)
p = allocate(ut.base.(*reflect.StructType), p, indir)
state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1
......@@ -463,7 +461,6 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr,
}
func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
defer catchError(&err)
state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1
for state.b.Len() > 0 {
......@@ -486,7 +483,6 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
}
func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
defer catchError(&err)
state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField
delta := int(state.decodeUint())
......@@ -937,7 +933,6 @@ func isExported(name string) bool {
}
func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
defer catchError(&err)
srt, ok := rt.(*reflect.StructType)
if !ok {
return dec.compileSingle(remoteId, rt)
......@@ -1026,7 +1021,8 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
return
}
func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error {
func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) {
defer catchError(&err)
// If the value is nil, it means we should just ignore this item.
if val == nil {
return dec.decodeIgnoredValue(wireId)
......
......@@ -200,9 +200,12 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
// Remove any nested writers remaining due to previous errors.
enc.w = enc.w[0:1]
enc.err = nil
ut := userType(value.Type())
ut, err := validUserType(value.Type())
if err != nil {
return err
}
enc.err = nil
state := newEncoderState(enc, new(bytes.Buffer))
enc.sendTypeDescriptor(enc.writer(), state, ut)
......@@ -212,7 +215,7 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
}
// Encode the object.
err := enc.encode(state.b, value, ut)
err = enc.encode(state.b, value, ut)
if err != nil {
enc.setError(err)
} else {
......
......@@ -27,28 +27,63 @@ var (
userTypeCache = make(map[reflect.Type]*userTypeInfo)
)
// userType returns, and saves, the information associated with user-provided type rt
func userType(rt reflect.Type) *userTypeInfo {
// validType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, err will be non-nil. To be used when the error handler
// is not set up.
func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
userTypeLock.RLock()
ut := userTypeCache[rt]
ut = userTypeCache[rt]
userTypeLock.RUnlock()
if ut != nil {
return ut
return
}
// Now set the value under the write lock.
userTypeLock.Lock()
defer userTypeLock.Unlock()
if ut = userTypeCache[rt]; ut != nil {
// Lost the race; not a problem.
return ut
return
}
ut = new(userTypeInfo)
ut.base = rt
ut.user = rt
ut.base, ut.indir = indirect(rt)
// A type that is just a cycle of pointers (such as type T *T) cannot
// be represented in gobs, which need some concrete data. We use a
// cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
// pp 539-540. As we step through indirections, run another type at
// half speed. If they meet up, there's a cycle.
// TODO: still need to deal with self-referential non-structs such
// as type T map[string]T but that is a larger undertaking - and can
// be useful, not always erroneous.
slowpoke := ut.base // walks half as fast as ut.base
for {
pt, ok := ut.base.(*reflect.PtrType)
if !ok {
break
}
ut.base = pt.Elem()
if ut.base == slowpoke { // ut.base lapped slowpoke
// recursive pointer type.
return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String())
}
if ut.indir%2 == 0 {
slowpoke = slowpoke.(*reflect.PtrType).Elem()
}
ut.indir++
}
userTypeCache[rt] = ut
return ut
return
}
// userType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, it calls error.
func userType(rt reflect.Type) *userTypeInfo {
ut, err := validUserType(rt)
if err != nil {
error(err)
}
return ut
}
// A typeId represents a gob Type as an integer that can be passed on the wire.
// Internally, typeIds are used as keys to a map to recover the underlying type info.
type typeId int32
......@@ -273,21 +308,6 @@ func newStructType(name string) *structType {
return s
}
// Step through the indirections on a type to discover the base type.
// Return the base type and the number of indirections.
func indirect(t reflect.Type) (rt reflect.Type, count int) {
rt = t
for {
pt, ok := rt.(*reflect.PtrType)
if !ok {
break
}
rt = pt.Elem()
count++
}
return
}
func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
switch t := rt.(type) {
// All basic types are easy: they are predefined.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment