Commit c9b90c9d authored by Rob Pike's avatar Rob Pike

gob: protect against pure recursive types.

There are further changes required for things like
recursive map types.  Recursive struct types work
but the mechanism needs generalization.  The
case handled in this CL is pathological since it
cannot be represented at all by gob, so it should
be handled separately. (Prior to this CL, encode
would recur forever.)

R=rsc
CC=golang-dev
https://golang.org/cl/4206041
parent da8e6eec
...@@ -973,17 +973,31 @@ func TestIgnoredFields(t *testing.T) { ...@@ -973,17 +973,31 @@ func TestIgnoredFields(t *testing.T) {
} }
} }
func TestBadRecursiveType(t *testing.T) {
type Rec ***Rec
var rec Rec
b := new(bytes.Buffer)
err := NewEncoder(b).Encode(&rec)
if err == nil {
t.Error("expected error; got none")
} else if strings.Index(err.String(), "recursive") < 0 {
t.Error("expected recursive type error; got", err)
}
// Can't test decode easily because we can't encode one, so we can't pass one to a Decoder.
}
type Bad0 struct { type Bad0 struct {
ch chan int CH chan int
c float64 C float64
} }
var nilEncoder *Encoder
func TestInvalidField(t *testing.T) { func TestInvalidField(t *testing.T) {
var bad0 Bad0 var bad0 Bad0
bad0.ch = make(chan int) bad0.CH = make(chan int)
b := new(bytes.Buffer) b := new(bytes.Buffer)
var nilEncoder *Encoder
err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0))) err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
if err == nil { if err == nil {
t.Error("expected error; got none") t.Error("expected error; got none")
......
...@@ -410,7 +410,6 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { ...@@ -410,7 +410,6 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
} }
func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) { func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) {
defer catchError(&err)
p = allocate(ut.base, p, ut.indir) p = allocate(ut.base, p, ut.indir)
state := newDecodeState(dec, &dec.buf) state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField state.fieldnum = singletonField
...@@ -433,7 +432,6 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) ...@@ -433,7 +432,6 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr)
// This state cannot arise for decodeSingle, which is called directly // This state cannot arise for decodeSingle, which is called directly
// from the user's value, not from the innards of an engine. // from the user's value, not from the innards of an engine.
func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) { func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) {
defer catchError(&err)
p = allocate(ut.base.(*reflect.StructType), p, indir) p = allocate(ut.base.(*reflect.StructType), p, indir)
state := newDecodeState(dec, &dec.buf) state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1 state.fieldnum = -1
...@@ -463,7 +461,6 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, ...@@ -463,7 +461,6 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr,
} }
func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
defer catchError(&err)
state := newDecodeState(dec, &dec.buf) state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1 state.fieldnum = -1
for state.b.Len() > 0 { for state.b.Len() > 0 {
...@@ -486,7 +483,6 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { ...@@ -486,7 +483,6 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
} }
func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
defer catchError(&err)
state := newDecodeState(dec, &dec.buf) state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField state.fieldnum = singletonField
delta := int(state.decodeUint()) delta := int(state.decodeUint())
...@@ -937,7 +933,6 @@ func isExported(name string) bool { ...@@ -937,7 +933,6 @@ func isExported(name string) bool {
} }
func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
defer catchError(&err)
srt, ok := rt.(*reflect.StructType) srt, ok := rt.(*reflect.StructType)
if !ok { if !ok {
return dec.compileSingle(remoteId, rt) return dec.compileSingle(remoteId, rt)
...@@ -1026,7 +1021,8 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er ...@@ -1026,7 +1021,8 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
return return
} }
func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error { func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) {
defer catchError(&err)
// If the value is nil, it means we should just ignore this item. // If the value is nil, it means we should just ignore this item.
if val == nil { if val == nil {
return dec.decodeIgnoredValue(wireId) return dec.decodeIgnoredValue(wireId)
......
...@@ -200,9 +200,12 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { ...@@ -200,9 +200,12 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
// Remove any nested writers remaining due to previous errors. // Remove any nested writers remaining due to previous errors.
enc.w = enc.w[0:1] enc.w = enc.w[0:1]
enc.err = nil ut, err := validUserType(value.Type())
ut := userType(value.Type()) if err != nil {
return err
}
enc.err = nil
state := newEncoderState(enc, new(bytes.Buffer)) state := newEncoderState(enc, new(bytes.Buffer))
enc.sendTypeDescriptor(enc.writer(), state, ut) enc.sendTypeDescriptor(enc.writer(), state, ut)
...@@ -212,7 +215,7 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { ...@@ -212,7 +215,7 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
} }
// Encode the object. // Encode the object.
err := enc.encode(state.b, value, ut) err = enc.encode(state.b, value, ut)
if err != nil { if err != nil {
enc.setError(err) enc.setError(err)
} else { } else {
......
...@@ -27,28 +27,63 @@ var ( ...@@ -27,28 +27,63 @@ var (
userTypeCache = make(map[reflect.Type]*userTypeInfo) userTypeCache = make(map[reflect.Type]*userTypeInfo)
) )
// userType returns, and saves, the information associated with user-provided type rt // validType returns, and saves, the information associated with user-provided type rt.
func userType(rt reflect.Type) *userTypeInfo { // If the user type is not valid, err will be non-nil. To be used when the error handler
// is not set up.
func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
userTypeLock.RLock() userTypeLock.RLock()
ut := userTypeCache[rt] ut = userTypeCache[rt]
userTypeLock.RUnlock() userTypeLock.RUnlock()
if ut != nil { if ut != nil {
return ut return
} }
// Now set the value under the write lock. // Now set the value under the write lock.
userTypeLock.Lock() userTypeLock.Lock()
defer userTypeLock.Unlock() defer userTypeLock.Unlock()
if ut = userTypeCache[rt]; ut != nil { if ut = userTypeCache[rt]; ut != nil {
// Lost the race; not a problem. // Lost the race; not a problem.
return ut return
} }
ut = new(userTypeInfo) ut = new(userTypeInfo)
ut.base = rt
ut.user = rt ut.user = rt
ut.base, ut.indir = indirect(rt) // A type that is just a cycle of pointers (such as type T *T) cannot
// be represented in gobs, which need some concrete data. We use a
// cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
// pp 539-540. As we step through indirections, run another type at
// half speed. If they meet up, there's a cycle.
// TODO: still need to deal with self-referential non-structs such
// as type T map[string]T but that is a larger undertaking - and can
// be useful, not always erroneous.
slowpoke := ut.base // walks half as fast as ut.base
for {
pt, ok := ut.base.(*reflect.PtrType)
if !ok {
break
}
ut.base = pt.Elem()
if ut.base == slowpoke { // ut.base lapped slowpoke
// recursive pointer type.
return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String())
}
if ut.indir%2 == 0 {
slowpoke = slowpoke.(*reflect.PtrType).Elem()
}
ut.indir++
}
userTypeCache[rt] = ut userTypeCache[rt] = ut
return ut return
} }
// userType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, it calls error.
func userType(rt reflect.Type) *userTypeInfo {
ut, err := validUserType(rt)
if err != nil {
error(err)
}
return ut
}
// A typeId represents a gob Type as an integer that can be passed on the wire. // A typeId represents a gob Type as an integer that can be passed on the wire.
// Internally, typeIds are used as keys to a map to recover the underlying type info. // Internally, typeIds are used as keys to a map to recover the underlying type info.
type typeId int32 type typeId int32
...@@ -273,21 +308,6 @@ func newStructType(name string) *structType { ...@@ -273,21 +308,6 @@ func newStructType(name string) *structType {
return s return s
} }
// Step through the indirections on a type to discover the base type.
// Return the base type and the number of indirections.
func indirect(t reflect.Type) (rt reflect.Type, count int) {
rt = t
for {
pt, ok := rt.(*reflect.PtrType)
if !ok {
break
}
rt = pt.Elem()
count++
}
return
}
func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
switch t := rt.(type) { switch t := rt.(type) {
// All basic types are easy: they are predefined. // All basic types are easy: they are predefined.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment