Commit 7861da73 authored by Rob Pike's avatar Rob Pike

gob: add support for maps.

Because maps are mostly a hidden type, they must be
implemented using reflection values and will not be as
efficient as arrays and slices.

R=rsc
CC=golang-dev
https://golang.org/cl/1127041
parent 46152bb9
...@@ -572,6 +572,7 @@ func TestEndToEnd(t *testing.T) { ...@@ -572,6 +572,7 @@ func TestEndToEnd(t *testing.T) {
s2 := "string2" s2 := "string2"
type T1 struct { type T1 struct {
a, b, c int a, b, c int
m map[string]*float
n *[3]float n *[3]float
strs *[2]string strs *[2]string
int64s *[]int64 int64s *[]int64
...@@ -579,10 +580,13 @@ func TestEndToEnd(t *testing.T) { ...@@ -579,10 +580,13 @@ func TestEndToEnd(t *testing.T) {
y []byte y []byte
t *T2 t *T2
} }
pi := 3.14159
e := 2.71828
t1 := &T1{ t1 := &T1{
a: 17, a: 17,
b: 18, b: 18,
c: -5, c: -5,
m: map[string]*float{"pi": &pi, "e": &e},
n: &[3]float{1.5, 2.5, 3.5}, n: &[3]float{1.5, 2.5, 3.5},
strs: &[2]string{s1, s2}, strs: &[2]string{s1, s2},
int64s: &[]int64{77, 89, 123412342134}, int64s: &[]int64{77, 89, 123412342134},
...@@ -921,6 +925,7 @@ type IT0 struct { ...@@ -921,6 +925,7 @@ type IT0 struct {
ignore_g string ignore_g string
ignore_h []byte ignore_h []byte
ignore_i *RT1 ignore_i *RT1
ignore_m map[string]int
c float c float
} }
...@@ -937,6 +942,7 @@ func TestIgnoredFields(t *testing.T) { ...@@ -937,6 +942,7 @@ func TestIgnoredFields(t *testing.T) {
it0.ignore_g = "pay no attention" it0.ignore_g = "pay no attention"
it0.ignore_h = []byte("to the curtain") it0.ignore_h = []byte("to the curtain")
it0.ignore_i = &RT1{3.1, "hi", 7, "hello"} it0.ignore_i = &RT1{3.1, "hi", 7, "hello"}
it0.ignore_m = map[string]int{"one": 1, "two": 2}
b := new(bytes.Buffer) b := new(bytes.Buffer)
NewEncoder(b).Encode(it0) NewEncoder(b).Encode(it0)
......
...@@ -447,6 +447,49 @@ func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp ...@@ -447,6 +447,49 @@ func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp
return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl)
} }
func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
instr := &decInstr{op, 0, indir, 0, ovfl}
up := unsafe.Pointer(v.Addr())
if indir > 1 {
up = decIndirect(up, indir)
}
op(instr, state, up)
return v
}
func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error {
if indir > 0 {
up := unsafe.Pointer(p)
if *(*unsafe.Pointer)(up) == nil {
// Allocate object.
*(*unsafe.Pointer)(up) = unsafe.New(mtyp)
}
p = *(*uintptr)(up)
}
up := unsafe.Pointer(p)
if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime
// Allocate map.
*(*unsafe.Pointer)(up) = unsafe.Pointer(reflect.MakeMap(mtyp).Get())
}
// Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for
// the iteration.
v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue)
n := int(decodeUint(state))
for i := 0; i < n && state.err == nil; i++ {
key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl)
if state.err != nil {
break
}
elem := decodeIntoValue(state, elemOp, elemIndir, reflect.MakeZero(mtyp.Elem()), ovfl)
if state.err != nil {
break
}
v.SetElem(key, elem)
}
return state.err
}
func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error { func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error {
instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
for i := 0; i < length && state.err == nil; i++ { for i := 0; i < length && state.err == nil; i++ {
...@@ -462,6 +505,18 @@ func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error { ...@@ -462,6 +505,18 @@ func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error {
return ignoreArrayHelper(state, elemOp, length) return ignoreArrayHelper(state, elemOp, length)
} }
func ignoreMap(state *decodeState, keyOp, elemOp decOp) os.Error {
n := int(decodeUint(state))
keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")}
elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
for i := 0; i < n && state.err == nil; i++ {
keyOp(keyInstr, state, nil)
elemOp(elemInstr, state, nil)
}
return state.err
}
func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error { func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error {
n := int(uintptr(decodeUint(state))) n := int(uintptr(decodeUint(state)))
if indir > 0 { if indir > 0 {
...@@ -517,17 +572,25 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ...@@ -517,17 +572,25 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
if !ok { if !ok {
// Special cases // Special cases
switch t := typ.(type) { switch t := typ.(type) {
case *reflect.SliceType: case *reflect.ArrayType:
name = "element of " + name name = "element of " + name
if _, ok := t.Elem().(*reflect.Uint8Type); ok { elemId := dec.wireType[wireId].arrayT.Elem
op = decUint8Array elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
break if err != nil {
return nil, 0, err
} }
var elemId typeId ovfl := overflow(name)
if tt, ok := builtinIdToType[wireId]; ok { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
elemId = tt.(*sliceType).Elem state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
} else { }
elemId = dec.wireType[wireId].slice.Elem
case *reflect.MapType:
name = "element of " + name
keyId := dec.wireType[wireId].mapT.Key
elemId := dec.wireType[wireId].mapT.Elem
keyOp, keyIndir, err := dec.decOpFor(keyId, t.Key(), name)
if err != nil {
return nil, 0, err
} }
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
if err != nil { if err != nil {
...@@ -535,19 +598,32 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ...@@ -535,19 +598,32 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
} }
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) up := unsafe.Pointer(p)
if indir > 1 {
up = decIndirect(up, indir)
}
state.err = decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl)
} }
case *reflect.ArrayType: case *reflect.SliceType:
name = "element of " + name name = "element of " + name
elemId := dec.wireType[wireId].array.Elem if _, ok := t.Elem().(*reflect.Uint8Type); ok {
op = decUint8Array
break
}
var elemId typeId
if tt, ok := builtinIdToType[wireId]; ok {
elemId = tt.(*sliceType).Elem
} else {
elemId = dec.wireType[wireId].sliceT.Elem
}
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
} }
case *reflect.StructType: case *reflect.StructType:
...@@ -575,18 +651,33 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { ...@@ -575,18 +651,33 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
// Special cases // Special cases
wire := dec.wireType[wireId] wire := dec.wireType[wireId]
switch { switch {
case wire.array != nil: case wire.arrayT != nil:
elemId := wire.array.Elem elemId := wire.arrayT.Elem
elemOp, err := dec.decIgnoreOpFor(elemId) elemOp, err := dec.decIgnoreOpFor(elemId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.err = ignoreArray(state, elemOp, wire.array.Len) state.err = ignoreArray(state, elemOp, wire.arrayT.Len)
} }
case wire.slice != nil: case wire.mapT != nil:
elemId := wire.slice.Elem keyId := dec.wireType[wireId].mapT.Key
elemId := dec.wireType[wireId].mapT.Elem
keyOp, err := dec.decIgnoreOpFor(keyId)
if err != nil {
return nil, err
}
elemOp, err := dec.decIgnoreOpFor(elemId)
if err != nil {
return nil, err
}
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.err = ignoreMap(state, keyOp, elemOp)
}
case wire.sliceT != nil:
elemId := wire.sliceT.Elem
elemOp, err := dec.decIgnoreOpFor(elemId) elemOp, err := dec.decIgnoreOpFor(elemId)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -595,7 +686,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { ...@@ -595,7 +686,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
state.err = ignoreSlice(state, elemOp) state.err = ignoreSlice(state, elemOp)
} }
case wire.strct != nil: case wire.structT != nil:
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
enginePtr, err := dec.getIgnoreEnginePtr(wireId) enginePtr, err := dec.getIgnoreEnginePtr(wireId)
if err != nil { if err != nil {
...@@ -640,11 +731,18 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { ...@@ -640,11 +731,18 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
return fw == tString return fw == tString
case *reflect.ArrayType: case *reflect.ArrayType:
wire, ok := dec.wireType[fw] wire, ok := dec.wireType[fw]
if !ok || wire.array == nil { if !ok || wire.arrayT == nil {
return false
}
array := wire.arrayT
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem)
case *reflect.MapType:
wire, ok := dec.wireType[fw]
if !ok || wire.mapT == nil {
return false return false
} }
array := wire.array mapType := wire.mapT
return ok && t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) return dec.compatibleType(t.Key(), mapType.Key) && dec.compatibleType(t.Elem(), mapType.Elem)
case *reflect.SliceType: case *reflect.SliceType:
// Is it an array of bytes? // Is it an array of bytes?
et := t.Elem() et := t.Elem()
...@@ -656,7 +754,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { ...@@ -656,7 +754,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
if tt, ok := builtinIdToType[fw]; ok { if tt, ok := builtinIdToType[fw]; ok {
sw = tt.(*sliceType) sw = tt.(*sliceType)
} else { } else {
sw = dec.wireType[fw].slice sw = dec.wireType[fw].sliceT
} }
elem, _ := indirect(t.Elem()) elem, _ := indirect(t.Elem())
return sw != nil && dec.compatibleType(elem, sw.Elem) return sw != nil && dec.compatibleType(elem, sw.Elem)
...@@ -677,7 +775,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng ...@@ -677,7 +775,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
if !ok1 || !ok2 { if !ok1 || !ok2 {
return nil, errNotStruct return nil, errNotStruct
} }
wireStruct = w.strct wireStruct = w.structT
} }
engine = new(decEngine) engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.field)) engine.instr = make([]decInstr, len(wireStruct.field))
...@@ -760,7 +858,7 @@ func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error { ...@@ -760,7 +858,7 @@ func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
return err return err
} }
engine := *enginePtr engine := *enginePtr
if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].strct.field) > 0 { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 {
name := rt.Name() name := rt.Name()
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
} }
......
...@@ -22,7 +22,7 @@ const uint64Size = unsafe.Sizeof(uint64(0)) ...@@ -22,7 +22,7 @@ const uint64Size = unsafe.Sizeof(uint64(0))
type encoderState struct { type encoderState struct {
b *bytes.Buffer b *bytes.Buffer
err os.Error // error encountered during encoding. err os.Error // error encountered during encoding.
inArray bool // encoding an array element inArray bool // encoding an array element or map key/value pair
fieldnum int // the last field number written. fieldnum int // the last field number written.
buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation. buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation.
} }
...@@ -297,7 +297,7 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error { ...@@ -297,7 +297,7 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
return state.err return state.err
} }
func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length int, elemIndir int) os.Error { func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) os.Error {
state := new(encoderState) state := new(encoderState)
state.b = b state.b = b
state.fieldnum = -1 state.fieldnum = -1
...@@ -319,6 +319,39 @@ func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length i ...@@ -319,6 +319,39 @@ func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length i
return state.err return state.err
} }
func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) {
for i := 0; i < indir && v != nil; i++ {
v = reflect.Indirect(v)
}
if v == nil {
state.err = os.ErrorString("gob: encodeMap: nil element")
return
}
op(nil, state, unsafe.Pointer(v.Addr()))
}
func encodeMap(b *bytes.Buffer, rt reflect.Type, p uintptr, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error {
state := new(encoderState)
state.b = b
state.fieldnum = -1
state.inArray = true
// Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for
// the iteration.
v := reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer((p))))
mv := reflect.Indirect(v).(*reflect.MapValue)
keys := mv.Keys()
encodeUint(state, uint64(len(keys)))
for _, key := range keys {
if state.err != nil {
break
}
encodeReflectValue(state, key, keyOp, keyIndir)
encodeReflectValue(state, mv.Elem(key), elemOp, elemIndir)
}
return state.err
}
var encOpMap = map[reflect.Type]encOp{ var encOpMap = map[reflect.Type]encOp{
valueKind(false): encBool, valueKind(false): encBool,
valueKind(int(0)): encInt, valueKind(int(0)): encInt,
...@@ -344,7 +377,6 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { ...@@ -344,7 +377,6 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
typ, indir := indirect(rt) typ, indir := indirect(rt)
op, ok := encOpMap[reflect.Typeof(typ)] op, ok := encOpMap[reflect.Typeof(typ)]
if !ok { if !ok {
typ, _ := indirect(rt)
// Special cases // Special cases
switch t := typ.(type) { switch t := typ.(type) {
case *reflect.SliceType: case *reflect.SliceType:
...@@ -363,7 +395,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { ...@@ -363,7 +395,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
return return
} }
state.update(i) state.update(i)
state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), int(slice.Len), indir) state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len))
} }
case *reflect.ArrayType: case *reflect.ArrayType:
// True arrays have size in the type. // True arrays have size in the type.
...@@ -373,7 +405,20 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { ...@@ -373,7 +405,20 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
} }
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i) state.update(i)
state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), t.Len(), indir) state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
}
case *reflect.MapType:
keyOp, keyIndir, err := encOpFor(t.Key())
if err != nil {
return nil, 0, err
}
elemOp, elemIndir, err := encOpFor(t.Elem())
if err != nil {
return nil, 0, err
}
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i)
state.err = encodeMap(state.b, typ, uintptr(p), keyOp, elemOp, keyIndir, elemIndir)
} }
case *reflect.StructType: case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
......
...@@ -71,9 +71,8 @@ ...@@ -71,9 +71,8 @@
Structs, arrays and slices are also supported. Strings and arrays of bytes are Structs, arrays and slices are also supported. Strings and arrays of bytes are
supported with a special, efficient representation (see below). supported with a special, efficient representation (see below).
Maps are not supported yet, but they will be. Interfaces, functions, and channels Interfaces, functions, and channels cannot be sent in a gob. Attempting
cannot be sent in a gob. Attempting to encode a value that contains one will to encode a value that contains one will fail.
fail.
The rest of this comment documents the encoding, details that are not important The rest of this comment documents the encoding, details that are not important
for most users. Details are presented bottom-up. for most users. Details are presented bottom-up.
...@@ -263,10 +262,13 @@ func (enc *Encoder) sendType(origt reflect.Type) { ...@@ -263,10 +262,13 @@ func (enc *Encoder) sendType(origt reflect.Type) {
case *reflect.ArrayType: case *reflect.ArrayType:
// arrays must be sent so we know their lengths and element types. // arrays must be sent so we know their lengths and element types.
break break
case *reflect.MapType:
// maps must be sent so we know their lengths and key/value types.
break
case *reflect.StructType: case *reflect.StructType:
// structs must be sent so we know their fields. // structs must be sent so we know their fields.
break break
case *reflect.ChanType, *reflect.FuncType, *reflect.MapType, *reflect.InterfaceType: case *reflect.ChanType, *reflect.FuncType, *reflect.InterfaceType:
// Probably a bad field in a struct. // Probably a bad field in a struct.
enc.badType(rt) enc.badType(rt)
return return
......
...@@ -142,6 +142,31 @@ func (a *arrayType) safeString(seen map[typeId]bool) string { ...@@ -142,6 +142,31 @@ func (a *arrayType) safeString(seen map[typeId]bool) string {
func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) } func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
// Map type
type mapType struct {
commonType
Key typeId
Elem typeId
}
func newMapType(name string, key, elem gobType) *mapType {
m := &mapType{commonType{name: name}, key.id(), elem.id()}
setTypeId(m)
return m
}
func (m *mapType) safeString(seen map[typeId]bool) string {
if seen[m._id] {
return m.name
}
seen[m._id] = true
key := m.Key.gobType().safeString(seen)
elem := m.Elem.gobType().safeString(seen)
return fmt.Sprintf("map[%s]%s", key, elem)
}
func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) }
// Slice type // Slice type
type sliceType struct { type sliceType struct {
commonType commonType
...@@ -239,6 +264,17 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { ...@@ -239,6 +264,17 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
} }
return newArrayType(name, gt, t.Len()), nil return newArrayType(name, gt, t.Len()), nil
case *reflect.MapType:
kt, err := getType("", t.Key())
if err != nil {
return nil, err
}
vt, err := getType("", t.Elem())
if err != nil {
return nil, err
}
return newMapType(name, kt, vt), nil
case *reflect.SliceType: case *reflect.SliceType:
// []byte == []uint8 is a special case // []byte == []uint8 is a special case
if _, ok := t.Elem().(*reflect.Uint8Type); ok { if _, ok := t.Elem().(*reflect.Uint8Type); ok {
...@@ -330,16 +366,18 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { ...@@ -330,16 +366,18 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
// using the gob rules for sending a structure, except that we assume the // using the gob rules for sending a structure, except that we assume the
// ids for wireType and structType are known. The relevant pieces // ids for wireType and structType are known. The relevant pieces
// are built in encode.go's init() function. // are built in encode.go's init() function.
// To maintain binary compatibility, if you extend this type, always put
// the new fields last.
type wireType struct { type wireType struct {
array *arrayType arrayT *arrayType
slice *sliceType sliceT *sliceType
strct *structType structT *structType
mapT *mapType
} }
func (w *wireType) name() string { func (w *wireType) name() string {
if w.strct != nil { if w.structT != nil {
return w.strct.name return w.structT.name
} }
return "unknown" return "unknown"
} }
...@@ -370,14 +408,16 @@ func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) { ...@@ -370,14 +408,16 @@ func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) {
t := info.id.gobType() t := info.id.gobType()
switch typ := rt.(type) { switch typ := rt.(type) {
case *reflect.ArrayType: case *reflect.ArrayType:
info.wire = &wireType{array: t.(*arrayType)} info.wire = &wireType{arrayT: t.(*arrayType)}
case *reflect.MapType:
info.wire = &wireType{mapT: t.(*mapType)}
case *reflect.SliceType: case *reflect.SliceType:
// []byte == []uint8 is a special case handled separately // []byte == []uint8 is a special case handled separately
if _, ok := typ.Elem().(*reflect.Uint8Type); !ok { if _, ok := typ.Elem().(*reflect.Uint8Type); !ok {
info.wire = &wireType{slice: t.(*sliceType)} info.wire = &wireType{sliceT: t.(*sliceType)}
} }
case *reflect.StructType: case *reflect.StructType:
info.wire = &wireType{strct: t.(*structType)} info.wire = &wireType{structT: t.(*structType)}
} }
typeInfoMap[rt] = info typeInfoMap[rt] = info
} }
......
...@@ -105,6 +105,26 @@ func TestSliceType(t *testing.T) { ...@@ -105,6 +105,26 @@ func TestSliceType(t *testing.T) {
} }
} }
func TestMapType(t *testing.T) {
var m map[string]int
mapStringInt := getTypeUnlocked("map", reflect.Typeof(m))
var newm map[string]int
newMapStringInt := getTypeUnlocked("map1", reflect.Typeof(newm))
if mapStringInt != newMapStringInt {
t.Errorf("second registration of map[string]int creates new type")
}
var b map[string]bool
mapStringBool := getTypeUnlocked("", reflect.Typeof(b))
if mapStringBool == mapStringInt {
t.Errorf("registration of map[string]bool creates same type as map[string]int")
}
str := mapStringBool.string()
expected := "map[string]bool"
if str != expected {
t.Errorf("map printed as %q; expected %q", str, expected)
}
}
type Bar struct { type Bar struct {
x string x string
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment