Commit c54b5d03 authored by Rob Pike's avatar Rob Pike

gob: make recursive map and slice types work.

Before this fix, types such as
        type T map[string]T
caused infinite recursion in the gob implementation.
Now they just work.

Fixes #1518.

R=rsc
CC=golang-dev
https://golang.org/cl/4230045
parent 89563177
......@@ -342,7 +342,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct {
a int
}
instr := &decInstr{decOpMap[reflect.Int], 6, 0, 0, ovfl}
instr := &decInstr{decOpTable[reflect.Int], 6, 0, 0, ovfl}
state := newDecodeStateFromData(signedResult)
execDec("int", instr, state, t, unsafe.Pointer(&data))
if data.a != 17 {
......@@ -355,7 +355,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct {
a uint
}
instr := &decInstr{decOpMap[reflect.Uint], 6, 0, 0, ovfl}
instr := &decInstr{decOpTable[reflect.Uint], 6, 0, 0, ovfl}
state := newDecodeStateFromData(unsignedResult)
execDec("uint", instr, state, t, unsafe.Pointer(&data))
if data.a != 17 {
......@@ -446,7 +446,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct {
a uintptr
}
instr := &decInstr{decOpMap[reflect.Uintptr], 6, 0, 0, ovfl}
instr := &decInstr{decOpTable[reflect.Uintptr], 6, 0, 0, ovfl}
state := newDecodeStateFromData(unsignedResult)
execDec("uintptr", instr, state, t, unsafe.Pointer(&data))
if data.a != 17 {
......@@ -511,7 +511,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct {
a complex64
}
instr := &decInstr{decOpMap[reflect.Complex64], 6, 0, 0, ovfl}
instr := &decInstr{decOpTable[reflect.Complex64], 6, 0, 0, ovfl}
state := newDecodeStateFromData(complexResult)
execDec("complex", instr, state, t, unsafe.Pointer(&data))
if data.a != 17+19i {
......@@ -524,7 +524,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct {
a complex128
}
instr := &decInstr{decOpMap[reflect.Complex128], 6, 0, 0, ovfl}
instr := &decInstr{decOpTable[reflect.Complex128], 6, 0, 0, ovfl}
state := newDecodeStateFromData(complexResult)
execDec("complex", instr, state, t, unsafe.Pointer(&data))
if data.a != 17+19i {
......
......@@ -671,7 +671,7 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
}
// Index by Go types.
var decOpMap = []decOp{
var decOpTable = [...]decOp{
reflect.Bool: decBool,
reflect.Int8: decInt8,
reflect.Int16: decInt16,
......@@ -701,37 +701,43 @@ var decIgnoreOpMap = map[typeId]decOp{
// Return the decoding op for the base type under rt and
// the indirection count to reach it.
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) {
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) {
ut := userType(rt)
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr, ut.indir
}
typ := ut.base
indir := ut.indir
var op decOp
k := typ.Kind()
if int(k) < len(decOpMap) {
op = decOpMap[k]
if int(k) < len(decOpTable) {
op = decOpTable[k]
}
if op == nil {
inProgress[rt] = &op
// Special cases
switch t := typ.(type) {
case *reflect.ArrayType:
name = "element of " + name
elemId := dec.wireType[wireId].ArrayT.Elem
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.dec.decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
}
case *reflect.MapType:
name = "element of " + name
keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
up := unsafe.Pointer(p)
state.dec.decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl)
state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl)
}
case *reflect.SliceType:
......@@ -746,10 +752,10 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
} else {
elemId = dec.wireType[wireId].SliceT.Elem
}
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.dec.decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
}
case *reflect.StructType:
......@@ -774,7 +780,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
if op == nil {
errorf("gob: decode can't handle type %s", rt.String())
}
return op, indir
return &op, indir
}
// Return the decoding op for a field that has no destination.
......@@ -838,11 +844,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
// Are these two gob Types compatible?
// Answers the question for basic types, arrays, and slices.
// Structs are considered ok; fields will be checked later.
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
if rhs, ok := inProgress[fr]; ok {
return rhs == fw
}
inProgress[fr] = fw
fr = userType(fr).base
switch t := fr.(type) {
default:
// map, chan, etc: cannot handle.
// chan, etc: cannot handle.
return false
case *reflect.BoolType:
return fw == tBool
......@@ -864,14 +874,14 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
return false
}
array := wire.ArrayT
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem)
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
case *reflect.MapType:
wire, ok := dec.wireType[fw]
if !ok || wire.MapT == nil {
return false
}
MapType := wire.MapT
return dec.compatibleType(t.Key(), MapType.Key) && dec.compatibleType(t.Elem(), MapType.Elem)
return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress)
case *reflect.SliceType:
// Is it an array of bytes?
if t.Elem().Kind() == reflect.Uint8 {
......@@ -885,7 +895,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
sw = dec.wireType[fw].SliceT
}
elem := userType(t.Elem()).base
return sw != nil && dec.compatibleType(elem, sw.Elem)
return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
case *reflect.StructType:
return true
}
......@@ -906,12 +916,12 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item
name := rt.String() // best we can do
if !dec.compatibleType(rt, remoteId) {
if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) {
return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId))
}
op, indir := dec.decOpFor(remoteId, rt, name)
op, indir := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp))
ovfl := os.ErrorString(`value for "` + name + `" out of range`)
engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl}
engine.instr[singletonField] = decInstr{*op, singletonField, indir, 0, ovfl}
engine.numInstr = 1
return
}
......@@ -954,6 +964,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
}
engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.Field))
seen := make(map[reflect.Type]*decOp)
// Loop over the fields of the wire type.
for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ {
wireField := wireStruct.Field[fieldnum]
......@@ -969,11 +980,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl}
continue
}
if !dec.compatibleType(localField.Type, wireField.Id) {
if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) {
errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name)
}
op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name)
engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl}
op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen)
engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl}
engine.numInstr++
}
return
......@@ -1070,8 +1081,8 @@ func init() {
default:
panic("gob: unknown size of int/uint")
}
decOpMap[reflect.Int] = iop
decOpMap[reflect.Uint] = uop
decOpTable[reflect.Int] = iop
decOpTable[reflect.Uint] = uop
// Finally uintptr
switch reflect.Typeof(uintptr(0)).Bits() {
......@@ -1082,5 +1093,5 @@ func init() {
default:
panic("gob: unknown size of uintptr")
}
decOpMap[reflect.Uintptr] = uop
decOpTable[reflect.Uintptr] = uop
}
......@@ -414,7 +414,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
}
}
var encOpMap = []encOp{
var encOpTable = [...]encOp{
reflect.Bool: encBool,
reflect.Int: encInt,
reflect.Int8: encInt8,
......@@ -434,18 +434,24 @@ var encOpMap = []encOp{
reflect.String: encString,
}
// Return the encoding op for the base type under rt and
// Return (a pointer to) the encoding op for the base type under rt and
// the indirection count to reach it.
func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) {
ut := userType(rt)
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr, ut.indir
}
typ := ut.base
indir := ut.indir
var op encOp
k := typ.Kind()
if int(k) < len(encOpMap) {
op = encOpMap[k]
var op encOp
if int(k) < len(encOpTable) {
op = encOpTable[k]
}
if op == nil {
inProgress[rt] = &op
// Special cases
switch t := typ.(type) {
case *reflect.SliceType:
......@@ -454,25 +460,25 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
break
}
// Slices have a header; we decode it to find the underlying array.
elemOp, indir := enc.encOpFor(t.Elem())
elemOp, indir := enc.encOpFor(t.Elem(), inProgress)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
slice := (*reflect.SliceHeader)(p)
if !state.sendZero && slice.Len == 0 {
return
}
state.update(i)
state.enc.encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len))
state.enc.encodeArray(state.b, slice.Data, *elemOp, t.Elem().Size(), indir, int(slice.Len))
}
case *reflect.ArrayType:
// True arrays have size in the type.
elemOp, indir := enc.encOpFor(t.Elem())
elemOp, indir := enc.encOpFor(t.Elem(), inProgress)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i)
state.enc.encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
state.enc.encodeArray(state.b, uintptr(p), *elemOp, t.Elem().Size(), indir, t.Len())
}
case *reflect.MapType:
keyOp, keyIndir := enc.encOpFor(t.Key())
elemOp, elemIndir := enc.encOpFor(t.Elem())
keyOp, keyIndir := enc.encOpFor(t.Key(), inProgress)
elemOp, elemIndir := enc.encOpFor(t.Elem(), inProgress)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
// Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for
......@@ -483,7 +489,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
return
}
state.update(i)
state.enc.encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir)
state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir)
}
case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type.
......@@ -511,21 +517,22 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
if op == nil {
errorf("gob enc: can't happen: encode type %s", rt.String())
}
return op, indir
return &op, indir
}
// The local Type was compiled from the actual value, so we know it's compatible.
func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
srt, isStruct := rt.(*reflect.StructType)
engine := new(encEngine)
seen := make(map[reflect.Type]*encOp)
if isStruct {
for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ {
f := srt.Field(fieldNum)
if !isExported(f.Name) {
continue
}
op, indir := enc.encOpFor(f.Type)
engine.instr = append(engine.instr, encInstr{op, fieldNum, indir, uintptr(f.Offset)})
op, indir := enc.encOpFor(f.Type, seen)
engine.instr = append(engine.instr, encInstr{*op, fieldNum, indir, uintptr(f.Offset)})
}
if srt.NumField() > 0 && len(engine.instr) == 0 {
errorf("type %s has no exported fields", rt)
......@@ -533,8 +540,8 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0})
} else {
engine.instr = make([]encInstr, 1)
op, indir := enc.encOpFor(rt)
engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero
op, indir := enc.encOpFor(rt, seen)
engine.instr[0] = encInstr{*op, singletonField, indir, 0} // offset is zero
}
return engine
}
......
......@@ -249,6 +249,24 @@ func TestArray(t *testing.T) {
}
}
func TestRecursiveMapType(t *testing.T) {
type recursiveMap map[string]recursiveMap
r1 := recursiveMap{"A": recursiveMap{"B": nil, "C": nil}, "D": nil}
r2 := make(recursiveMap)
if err := encAndDec(r1, &r2); err != nil {
t.Error(err)
}
}
func TestRecursiveSliceType(t *testing.T) {
type recursiveSlice []recursiveSlice
r1 := recursiveSlice{0: recursiveSlice{0: nil}, 1: nil}
r2 := make(recursiveSlice, 0)
if err := encAndDec(r1, &r2); err != nil {
t.Error(err)
}
}
// Regression test for bug: must send zero values inside arrays
func TestDefaultsInArray(t *testing.T) {
type Type7 struct {
......
......@@ -52,9 +52,6 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
// cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
// pp 539-540. As we step through indirections, run another type at
// half speed. If they meet up, there's a cycle.
// TODO: still need to deal with self-referential non-structs such
// as type T map[string]T but that is a larger undertaking - and can
// be useful, not always erroneous.
slowpoke := ut.base // walks half as fast as ut.base
for {
pt, ok := ut.base.(*reflect.PtrType)
......@@ -210,12 +207,18 @@ type arrayType struct {
Len int
}
func newArrayType(name string, elem gobType, length int) *arrayType {
a := &arrayType{CommonType{Name: name}, elem.id(), length}
setTypeId(a)
func newArrayType(name string) *arrayType {
a := &arrayType{CommonType{Name: name}, 0, 0}
return a
}
func (a *arrayType) init(elem gobType, len int) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(a)
a.Elem = elem.id()
a.Len = len
}
func (a *arrayType) safeString(seen map[typeId]bool) string {
if seen[a.Id] {
return a.Name
......@@ -233,12 +236,18 @@ type mapType struct {
Elem typeId
}
func newMapType(name string, key, elem gobType) *mapType {
m := &mapType{CommonType{Name: name}, key.id(), elem.id()}
setTypeId(m)
func newMapType(name string) *mapType {
m := &mapType{CommonType{Name: name}, 0, 0}
return m
}
func (m *mapType) init(key, elem gobType) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(m)
m.Key = key.id()
m.Elem = elem.id()
}
func (m *mapType) safeString(seen map[typeId]bool) string {
if seen[m.Id] {
return m.Name
......@@ -257,12 +266,17 @@ type sliceType struct {
Elem typeId
}
func newSliceType(name string, elem gobType) *sliceType {
s := &sliceType{CommonType{Name: name}, elem.id()}
setTypeId(s)
func newSliceType(name string) *sliceType {
s := &sliceType{CommonType{Name: name}, 0}
return s
}
func (s *sliceType) init(elem gobType) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(s)
s.Elem = elem.id()
}
func (s *sliceType) safeString(seen map[typeId]bool) string {
if seen[s.Id] {
return s.Name
......@@ -304,11 +318,26 @@ func (s *structType) string() string { return s.safeString(make(map[typeId]bool)
func newStructType(name string) *structType {
s := &structType{CommonType{Name: name}, nil}
// For historical reasons we set the id here rather than init.
// Se the comment in newTypeObject for details.
setTypeId(s)
return s
}
func (s *structType) init(field []*fieldType) {
s.Field = field
}
func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
var err os.Error
var type0, type1 gobType
defer func() {
if err != nil {
types[rt] = nil, false
}
}()
// Install the top-level type before the subtypes (e.g. struct before
// fields) so recursive types can be constructed safely.
switch t := rt.(type) {
// All basic types are easy: they are predefined.
case *reflect.BoolType:
......@@ -333,40 +362,55 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
return tInterface.gobType(), nil
case *reflect.ArrayType:
gt, err := getType("", t.Elem())
at := newArrayType(name)
types[rt] = at
type0, err = getType("", t.Elem())
if err != nil {
return nil, err
}
return newArrayType(name, gt, t.Len()), nil
// Historical aside:
// For arrays, maps, and slices, we set the type id after the elements
// are constructed. This is to retain the order of type id allocation after
// a fix made to handle recursive types, which changed the order in
// which types are built. Delaying the setting in this way preserves
// type ids while allowing recursive types to be described. Structs,
// done below, were already handling recursion correctly so they
// assign the top-level id before those of the field.
at.init(type0, t.Len())
return at, nil
case *reflect.MapType:
kt, err := getType("", t.Key())
mt := newMapType(name)
types[rt] = mt
type0, err = getType("", t.Key())
if err != nil {
return nil, err
}
vt, err := getType("", t.Elem())
type1, err = getType("", t.Elem())
if err != nil {
return nil, err
}
return newMapType(name, kt, vt), nil
mt.init(type0, type1)
return mt, nil
case *reflect.SliceType:
// []byte == []uint8 is a special case
if t.Elem().Kind() == reflect.Uint8 {
return tBytes.gobType(), nil
}
gt, err := getType(t.Elem().Name(), t.Elem())
st := newSliceType(name)
types[rt] = st
type0, err = getType(t.Elem().Name(), t.Elem())
if err != nil {
return nil, err
}
return newSliceType(name, gt), nil
st.init(type0)
return st, nil
case *reflect.StructType:
// Install the struct type itself before the fields so recursive
// structures can be constructed safely.
strType := newStructType(name)
types[rt] = strType
idToType[strType.id()] = strType
st := newStructType(name)
types[rt] = st
idToType[st.id()] = st
field := make([]*fieldType, t.NumField())
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
......@@ -382,8 +426,8 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
}
field[i] = &fieldType{f.Name, gt.id()}
}
strType.Field = field
return strType, nil
st.init(field)
return st, nil
default:
return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String())
......@@ -435,7 +479,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
// For bootstrapping purposes, we assume that the recipient knows how
// to decode a wireType; it is exactly the wireType struct here, interpreted
// using the gob rules for sending a structure, except that we assume the
// ids for wireType and structType are known. The relevant pieces
// ids for wireType and structType etc. are known. The relevant pieces
// are built in encode.go's init() function.
// To maintain binary compatibility, if you extend this type, always put
// the new fields last.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment