Commit c91daefb authored by Rob Pike's avatar Rob Pike

gob: beginning of support for GobEncoder/GobDecoder interfaces.

This allows a data item that can marshal itself to be transmitted by its
own encoding, enabling some types to be handled that cannot be
normally, plus providing a way to use gobs on data with unexported
fields.

In this CL, the necessary methods are protected by leading _, so only
package gob can use the facilities (in its tests, of course); this
code is not ready for real use yet.  I could be talked into enabling
it for experimentation, though.  The main drawback is that the
methods must be implemented by the actual type passed through,
not by an indirection from it.  For instance, if *T implements
GobEncoder, you must send a *T, not a T.  This will be addressed
in due course.

Also there is improved commentary and a couple of unrelated
minor bug fixes.

R=rsc
CC=golang-dev
https://golang.org/cl/4243056
parent 7b563be5
...@@ -303,7 +303,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -303,7 +303,7 @@ func TestScalarEncInstructions(t *testing.T) {
} }
} }
func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) { func execDec(typ string, instr *decInstr, state *decoderState, t *testing.T, p unsafe.Pointer) {
defer testError(t) defer testError(t)
v := int(state.decodeUint()) v := int(state.decodeUint())
if v+state.fieldnum != 6 { if v+state.fieldnum != 6 {
...@@ -313,7 +313,7 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un ...@@ -313,7 +313,7 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un
state.fieldnum = 6 state.fieldnum = 6
} }
func newDecodeStateFromData(data []byte) *decodeState { func newDecodeStateFromData(data []byte) *decoderState {
b := bytes.NewBuffer(data) b := bytes.NewBuffer(data)
state := newDecodeState(nil, b) state := newDecodeState(nil, b)
state.fieldnum = -1 state.fieldnum = -1
......
...@@ -155,6 +155,16 @@ func (deb *debugger) dump(format string, args ...interface{}) { ...@@ -155,6 +155,16 @@ func (deb *debugger) dump(format string, args ...interface{}) {
// Debug prints a human-readable representation of the gob data read from r. // Debug prints a human-readable representation of the gob data read from r.
func Debug(r io.Reader) { func Debug(r io.Reader) {
err := debug(r)
if err != nil {
fmt.Fprintf(os.Stderr, "gob debug: %s\n", err)
}
}
// debug implements Debug, but catches panics and returns
// them as errors to be printed by Debug.
func debug(r io.Reader) (err os.Error) {
defer catchError(&err)
fmt.Fprintln(os.Stderr, "Start of debugging") fmt.Fprintln(os.Stderr, "Start of debugging")
deb := &debugger{ deb := &debugger{
r: newPeekReader(r), r: newPeekReader(r),
...@@ -166,6 +176,7 @@ func Debug(r io.Reader) { ...@@ -166,6 +176,7 @@ func Debug(r io.Reader) {
deb.remainingKnown = true deb.remainingKnown = true
} }
deb.gobStream() deb.gobStream()
return
} }
// note that we've consumed some bytes // note that we've consumed some bytes
...@@ -386,11 +397,15 @@ func (deb *debugger) typeDefinition(indent tab, id typeId) { ...@@ -386,11 +397,15 @@ func (deb *debugger) typeDefinition(indent tab, id typeId) {
// Field number 1 is type Id of key // Field number 1 is type Id of key
deb.delta(1) deb.delta(1)
keyId := deb.typeId() keyId := deb.typeId()
wire.SliceT = &sliceType{com, id}
// Field number 2 is type Id of elem // Field number 2 is type Id of elem
deb.delta(1) deb.delta(1)
elemId := deb.typeId() elemId := deb.typeId()
wire.MapT = &mapType{com, keyId, elemId} wire.MapT = &mapType{com, keyId, elemId}
case 4: // GobEncoder type, one field of {{Common}}
// Field number 0 is CommonType
deb.delta(1)
com := deb.common()
wire.GobEncoderT = &gobEncoderType{com}
default: default:
errorf("bad field in type %d", fieldNum) errorf("bad field in type %d", fieldNum)
} }
...@@ -507,6 +522,8 @@ func (deb *debugger) printWireType(indent tab, wire *wireType) { ...@@ -507,6 +522,8 @@ func (deb *debugger) printWireType(indent tab, wire *wireType) {
for i, field := range wire.StructT.Field { for i, field := range wire.StructT.Field {
fmt.Fprintf(os.Stderr, "%sfield %d:\t%s\tid=%d\n", indent+1, i, field.Name, field.Id) fmt.Fprintf(os.Stderr, "%sfield %d:\t%s\tid=%d\n", indent+1, i, field.Name, field.Id)
} }
case wire.GobEncoderT != nil:
deb.printCommonType(indent, "GobEncoder", &wire.GobEncoderT.CommonType)
} }
indent-- indent--
fmt.Fprintf(os.Stderr, "%s}\n", indent) fmt.Fprintf(os.Stderr, "%s}\n", indent)
...@@ -538,6 +555,8 @@ func (deb *debugger) fieldValue(indent tab, id typeId) { ...@@ -538,6 +555,8 @@ func (deb *debugger) fieldValue(indent tab, id typeId) {
deb.sliceValue(indent, wire) deb.sliceValue(indent, wire)
case wire.StructT != nil: case wire.StructT != nil:
deb.structValue(indent, id) deb.structValue(indent, id)
case wire.GobEncoderT != nil:
deb.gobEncoderValue(indent, id)
default: default:
panic("bad wire type for field") panic("bad wire type for field")
} }
...@@ -654,3 +673,17 @@ func (deb *debugger) structValue(indent tab, id typeId) { ...@@ -654,3 +673,17 @@ func (deb *debugger) structValue(indent tab, id typeId) {
fmt.Fprintf(os.Stderr, "%s} // end %s struct\n", indent, id.name()) fmt.Fprintf(os.Stderr, "%s} // end %s struct\n", indent, id.name())
deb.dump(">> End of struct value of type %d %q", id, id.name()) deb.dump(">> End of struct value of type %d %q", id, id.name())
} }
// GobEncoderValue:
// uint(n) byte*n
func (deb *debugger) gobEncoderValue(indent tab, id typeId) {
len := deb.uint64()
deb.dump("GobEncoder value of %q id=%d, length %d\n", id.name(), id, len)
fmt.Fprintf(os.Stderr, "%s%s (implements GobEncoder)\n", indent, id.name())
data := make([]byte, len)
_, err := deb.r.Read(data)
if err != nil {
errorf("gobEncoder data read: %s", err)
}
fmt.Fprintf(os.Stderr, "%s[% .2x]\n", indent+1, data)
}
...@@ -24,9 +24,9 @@ var ( ...@@ -24,9 +24,9 @@ var (
errRange = os.ErrorString("gob: internal error: field numbers out of bounds") errRange = os.ErrorString("gob: internal error: field numbers out of bounds")
) )
// The execution state of an instance of the decoder. A new state // decoderState is the execution state of an instance of the decoder. A new state
// is created for nested objects. // is created for nested objects.
type decodeState struct { type decoderState struct {
dec *Decoder dec *Decoder
// The buffer is stored with an extra indirection because it may be replaced // The buffer is stored with an extra indirection because it may be replaced
// if we load a type during decode (when reading an interface value). // if we load a type during decode (when reading an interface value).
...@@ -37,8 +37,8 @@ type decodeState struct { ...@@ -37,8 +37,8 @@ type decodeState struct {
// We pass the bytes.Buffer separately for easier testing of the infrastructure // We pass the bytes.Buffer separately for easier testing of the infrastructure
// without requiring a full Decoder. // without requiring a full Decoder.
func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decodeState { func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decoderState {
d := new(decodeState) d := new(decoderState)
d.dec = dec d.dec = dec
d.b = buf d.b = buf
d.buf = make([]byte, uint64Size) d.buf = make([]byte, uint64Size)
...@@ -85,7 +85,7 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err os.Erro ...@@ -85,7 +85,7 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err os.Erro
// decodeUint reads an encoded unsigned integer from state.r. // decodeUint reads an encoded unsigned integer from state.r.
// Does not check for overflow. // Does not check for overflow.
func (state *decodeState) decodeUint() (x uint64) { func (state *decoderState) decodeUint() (x uint64) {
b, err := state.b.ReadByte() b, err := state.b.ReadByte()
if err != nil { if err != nil {
error(err) error(err)
...@@ -112,7 +112,7 @@ func (state *decodeState) decodeUint() (x uint64) { ...@@ -112,7 +112,7 @@ func (state *decodeState) decodeUint() (x uint64) {
// decodeInt reads an encoded signed integer from state.r. // decodeInt reads an encoded signed integer from state.r.
// Does not check for overflow. // Does not check for overflow.
func (state *decodeState) decodeInt() int64 { func (state *decoderState) decodeInt() int64 {
x := state.decodeUint() x := state.decodeUint()
if x&1 != 0 { if x&1 != 0 {
return ^int64(x >> 1) return ^int64(x >> 1)
...@@ -120,7 +120,8 @@ func (state *decodeState) decodeInt() int64 { ...@@ -120,7 +120,8 @@ func (state *decodeState) decodeInt() int64 {
return int64(x >> 1) return int64(x >> 1)
} }
type decOp func(i *decInstr, state *decodeState, p unsafe.Pointer) // decOp is the signature of a decoding operator for a given type.
type decOp func(i *decInstr, state *decoderState, p unsafe.Pointer)
// The 'instructions' of the decoding machine // The 'instructions' of the decoding machine
type decInstr struct { type decInstr struct {
...@@ -150,26 +151,31 @@ func decIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { ...@@ -150,26 +151,31 @@ func decIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
return p return p
} }
func ignoreUint(i *decInstr, state *decodeState, p unsafe.Pointer) { // ignoreUint discards a uint value with no destination.
func ignoreUint(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.decodeUint() state.decodeUint()
} }
func ignoreTwoUints(i *decInstr, state *decodeState, p unsafe.Pointer) { // ignoreTwoUints discards a uint value with no destination. It's used to skip
// complex values.
func ignoreTwoUints(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.decodeUint() state.decodeUint()
state.decodeUint() state.decodeUint()
} }
func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) { // decBool decodes a uiint and stores it as a boolean through p.
func decBool(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(bool)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(bool))
} }
p = *(*unsafe.Pointer)(p) p = *(*unsafe.Pointer)(p)
} }
*(*bool)(p) = state.decodeInt() != 0 *(*bool)(p) = state.decodeUint() != 0
} }
func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { // decInt8 decodes an integer and stores it as an int8 through p.
func decInt8(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(int8)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int8))
...@@ -184,7 +190,8 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -184,7 +190,8 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) {
} }
} }
func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) { // decUint8 decodes an unsigned integer and stores it as a uint8 through p.
func decUint8(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint8)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint8))
...@@ -199,7 +206,8 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -199,7 +206,8 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) {
} }
} }
func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) { // decInt16 decodes an integer and stores it as an int16 through p.
func decInt16(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(int16)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int16))
...@@ -214,7 +222,8 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -214,7 +222,8 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) {
} }
} }
func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) { // decUint16 decodes an unsigned integer and stores it as a uint16 through p.
func decUint16(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint16)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint16))
...@@ -229,7 +238,8 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -229,7 +238,8 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) {
} }
} }
func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) { // decInt32 decodes an integer and stores it as an int32 through p.
func decInt32(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(int32)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int32))
...@@ -244,7 +254,8 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -244,7 +254,8 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) {
} }
} }
func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) { // decUint32 decodes an unsigned integer and stores it as a uint32 through p.
func decUint32(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint32)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint32))
...@@ -259,7 +270,8 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -259,7 +270,8 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) {
} }
} }
func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) { // decInt64 decodes an integer and stores it as an int64 through p.
func decInt64(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(int64)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int64))
...@@ -269,7 +281,8 @@ func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -269,7 +281,8 @@ func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*int64)(p) = int64(state.decodeInt()) *(*int64)(p) = int64(state.decodeInt())
} }
func decUint64(i *decInstr, state *decodeState, p unsafe.Pointer) { // decUint64 decodes an unsigned integer and stores it as a uint64 through p.
func decUint64(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint64)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint64))
...@@ -294,7 +307,9 @@ func floatFromBits(u uint64) float64 { ...@@ -294,7 +307,9 @@ func floatFromBits(u uint64) float64 {
return math.Float64frombits(v) return math.Float64frombits(v)
} }
func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { // storeFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point
// number, and stores it through p. It's a helper function for float32 and complex64.
func storeFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) {
v := floatFromBits(state.decodeUint()) v := floatFromBits(state.decodeUint())
av := v av := v
if av < 0 { if av < 0 {
...@@ -308,7 +323,9 @@ func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -308,7 +323,9 @@ func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
} }
} }
func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { // decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point
// number, and stores it through p.
func decFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32))
...@@ -318,7 +335,9 @@ func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -318,7 +335,9 @@ func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
storeFloat32(i, state, p) storeFloat32(i, state, p)
} }
func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { // decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point
// number, and stores it through p.
func decFloat64(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(float64)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float64))
...@@ -328,8 +347,10 @@ func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -328,8 +347,10 @@ func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*float64)(p) = floatFromBits(uint64(state.decodeUint())) *(*float64)(p) = floatFromBits(uint64(state.decodeUint()))
} }
// Complex numbers are just a pair of floating-point numbers, real part first. // decComplex64 decodes a pair of unsigned integers, treats them as a
func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) { // pair of floating point numbers, and stores them as a complex64 through p.
// The real part comes first.
func decComplex64(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex64)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex64))
...@@ -340,7 +361,10 @@ func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -340,7 +361,10 @@ func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) {
storeFloat32(i, state, unsafe.Pointer(uintptr(p)+uintptr(unsafe.Sizeof(float32(0))))) storeFloat32(i, state, unsafe.Pointer(uintptr(p)+uintptr(unsafe.Sizeof(float32(0)))))
} }
func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) { // decComplex128 decodes a pair of unsigned integers, treats them as a
// pair of floating point numbers, and stores them as a complex128 through p.
// The real part comes first.
func decComplex128(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex128)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex128))
...@@ -352,8 +376,10 @@ func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -352,8 +376,10 @@ func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*complex128)(p) = complex(real, imag) *(*complex128)(p) = complex(real, imag)
} }
// decUint8Array decodes byte array and stores through p a slice header
// describing the data.
// uint8 arrays are encoded as an unsigned count followed by the raw bytes. // uint8 arrays are encoded as an unsigned count followed by the raw bytes.
func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { func decUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new([]uint8)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]uint8))
...@@ -365,8 +391,10 @@ func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -365,8 +391,10 @@ func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*[]uint8)(p) = b *(*[]uint8)(p) = b
} }
// decString decodes byte array and stores through p a string header
// describing the data.
// Strings are encoded as an unsigned count followed by the raw bytes. // Strings are encoded as an unsigned count followed by the raw bytes.
func decString(i *decInstr, state *decodeState, p unsafe.Pointer) { func decString(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte)) *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte))
...@@ -378,7 +406,8 @@ func decString(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -378,7 +406,8 @@ func decString(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*string)(p) = string(b) *(*string)(p) = string(b)
} }
func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { // ignoreUint8Array skips over the data for a byte slice value with no destination.
func ignoreUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) {
b := make([]byte, state.decodeUint()) b := make([]byte, state.decodeUint())
state.b.Read(b) state.b.Read(b)
} }
...@@ -409,8 +438,15 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { ...@@ -409,8 +438,15 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
return *(*uintptr)(up) return *(*uintptr)(up)
} }
// decodeSingle decodes a top-level value that is not a struct and stores it through p.
// Such values are preceded by a zero, making them have the memory layout of a
// struct field (although with an illegal field number).
func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) { func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) {
p = allocate(ut.base, p, ut.indir) indir := ut.indir
if ut.isGobDecoder {
indir = int(ut.decIndir)
}
p = allocate(ut.base, p, indir)
state := newDecodeState(dec, &dec.buf) state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField state.fieldnum = singletonField
basep := p basep := p
...@@ -427,6 +463,7 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) ...@@ -427,6 +463,7 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr)
return nil return nil
} }
// decodeSingle decodes a top-level struct and stores it through p.
// Indir is for the value, not the type. At the time of the call it may // Indir is for the value, not the type. At the time of the call it may
// differ from ut.indir, which was computed when the engine was built. // differ from ut.indir, which was computed when the engine was built.
// This state cannot arise for decodeSingle, which is called directly // This state cannot arise for decodeSingle, which is called directly
...@@ -460,6 +497,7 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, ...@@ -460,6 +497,7 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr,
return nil return nil
} }
// ignoreStruct discards the data for a struct with no destination.
func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
state := newDecodeState(dec, &dec.buf) state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1 state.fieldnum = -1
...@@ -482,6 +520,8 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { ...@@ -482,6 +520,8 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
return nil return nil
} }
// ignoreSingle discards the data for a top-level non-struct value with no
// destination. It's used when calling Decode with a nil value.
func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
state := newDecodeState(dec, &dec.buf) state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField state.fieldnum = singletonField
...@@ -494,7 +534,8 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { ...@@ -494,7 +534,8 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
return nil return nil
} }
func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) { // decodeArrayHelper does the work for decoding arrays and slices.
func (dec *Decoder) decodeArrayHelper(state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) {
instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl} instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl}
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
up := unsafe.Pointer(p) up := unsafe.Pointer(p)
...@@ -506,7 +547,10 @@ func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decO ...@@ -506,7 +547,10 @@ func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decO
} }
} }
func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) { // decodeArray decodes an array and stores it through p, that is, p points to the zeroth element.
// The length is an unsigned integer preceding the elements. Even though the length is redundant
// (it's part of the type), it's a useful check and is included in the encoding.
func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) {
if indir > 0 { if indir > 0 {
p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect
} }
...@@ -516,7 +560,9 @@ func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p u ...@@ -516,7 +560,9 @@ func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p u
dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl)
} }
func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { // decodeIntoValue is a helper for map decoding. Since maps are decoded using reflection,
// unlike the other items we can't use a pointer directly.
func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
instr := &decInstr{op, 0, indir, 0, ovfl} instr := &decInstr{op, 0, indir, 0, ovfl}
up := unsafe.Pointer(v.UnsafeAddr()) up := unsafe.Pointer(v.UnsafeAddr())
if indir > 1 { if indir > 1 {
...@@ -526,7 +572,11 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o ...@@ -526,7 +572,11 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o
return v return v
} }
func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) { // decodeMap decodes a map and stores its header through p.
// Maps are encoded as a length followed by key:value pairs.
// Because the internals of maps are not visible to us, we must
// use reflection rather than pointer magic.
func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decoderState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) {
if indir > 0 { if indir > 0 {
p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect
} }
...@@ -538,7 +588,7 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp ...@@ -538,7 +588,7 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp
// Maps cannot be accessed by moving addresses around the way // Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for // that slices etc. can. We must recover a full reflection value for
// the iteration. // the iteration.
v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue) v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p))).(*reflect.MapValue)
n := int(state.decodeUint()) n := int(state.decodeUint())
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl) key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl)
...@@ -547,21 +597,24 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp ...@@ -547,21 +597,24 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp
} }
} }
func (dec *Decoder) ignoreArrayHelper(state *decodeState, elemOp decOp, length int) { // ignoreArrayHelper does the work for discarding arrays and slices.
func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length int) {
instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
elemOp(instr, state, nil) elemOp(instr, state, nil)
} }
} }
func (dec *Decoder) ignoreArray(state *decodeState, elemOp decOp, length int) { // ignoreArray discards the data for an array value with no destination.
func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) {
if n := state.decodeUint(); n != uint64(length) { if n := state.decodeUint(); n != uint64(length) {
errorf("gob: length mismatch in ignoreArray") errorf("gob: length mismatch in ignoreArray")
} }
dec.ignoreArrayHelper(state, elemOp, length) dec.ignoreArrayHelper(state, elemOp, length)
} }
func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) { // ignoreMap discards the data for a map value with no destination.
func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) {
n := int(state.decodeUint()) n := int(state.decodeUint())
keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")} keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")}
elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
...@@ -571,7 +624,9 @@ func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) { ...@@ -571,7 +624,9 @@ func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) {
} }
} }
func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) { // decodeSlice decodes a slice and stores the slice header through p.
// Slices are encoded as an unsigned length followed by the elements.
func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) {
n := int(uintptr(state.decodeUint())) n := int(uintptr(state.decodeUint()))
if indir > 0 { if indir > 0 {
up := unsafe.Pointer(p) up := unsafe.Pointer(p)
...@@ -590,7 +645,8 @@ func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p u ...@@ -590,7 +645,8 @@ func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p u
dec.decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl) dec.decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl)
} }
func (dec *Decoder) ignoreSlice(state *decodeState, elemOp decOp) { // ignoreSlice skips over the data for a slice value with no destination.
func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) {
dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint())) dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint()))
} }
...@@ -609,9 +665,10 @@ func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) { ...@@ -609,9 +665,10 @@ func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) {
ivalue.Set(value) ivalue.Set(value)
} }
// decodeInterface receives the name of a concrete type followed by its value. // decodeInterface decodes an interface value and stores it through p.
// Interfaces are encoded as the name of a concrete type followed by a value.
// If the name is empty, the value is nil and no value is sent. // If the name is empty, the value is nil and no value is sent.
func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) { func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decoderState, p uintptr, indir int) {
// Create an interface reflect.Value. We need one even for the nil case. // Create an interface reflect.Value. We need one even for the nil case.
ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue) ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue)
// Read the name of the concrete type. // Read the name of the concrete type.
...@@ -655,7 +712,8 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt ...@@ -655,7 +712,8 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt
*(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get() *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get()
} }
func (dec *Decoder) ignoreInterface(state *decodeState) { // ignoreInterface discards the data for an interface value with no destination.
func (dec *Decoder) ignoreInterface(state *decoderState) {
// Read the name of the concrete type. // Read the name of the concrete type.
b := make([]byte, state.decodeUint()) b := make([]byte, state.decodeUint())
_, err := state.b.Read(b) _, err := state.b.Read(b)
...@@ -670,6 +728,32 @@ func (dec *Decoder) ignoreInterface(state *decodeState) { ...@@ -670,6 +728,32 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
state.b.Next(int(state.decodeUint())) state.b.Next(int(state.decodeUint()))
} }
// decodeGobDecoder decodes something implementing the GobDecoder interface.
// The data is encoded as a byte slice.
func (dec *Decoder) decodeGobDecoder(state *decoderState, v reflect.Value, index int) {
// Read the bytes for the value.
b := make([]byte, state.decodeUint())
_, err := state.b.Read(b)
if err != nil {
error(err)
}
// We know it's a GobDecoder, so just call the method directly.
err = v.Interface().(_GobDecoder)._GobDecode(b)
if err != nil {
error(err)
}
}
// ignoreGobDecoder discards the data for a GobDecoder value with no destination.
func (dec *Decoder) ignoreGobDecoder(state *decoderState) {
// Read the bytes for the value.
b := make([]byte, state.decodeUint())
_, err := state.b.Read(b)
if err != nil {
error(err)
}
}
// Index by Go types. // Index by Go types.
var decOpTable = [...]decOp{ var decOpTable = [...]decOp{
reflect.Bool: decBool, reflect.Bool: decBool,
...@@ -699,10 +783,14 @@ var decIgnoreOpMap = map[typeId]decOp{ ...@@ -699,10 +783,14 @@ var decIgnoreOpMap = map[typeId]decOp{
tComplex: ignoreTwoUints, tComplex: ignoreTwoUints,
} }
// Return the decoding op for the base type under rt and // decOpFor returns the decoding op for the base type under rt and
// the indirection count to reach it. // the indirection count to reach it.
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) { func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) {
ut := userType(rt) ut := userType(rt)
// If the type implements GobEncoder, we handle it without further processing.
if ut.isGobDecoder {
return dec.gobDecodeOpFor(ut)
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T). // If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building. // Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil { if opPtr := inProgress[rt]; opPtr != nil {
...@@ -724,7 +812,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg ...@@ -724,7 +812,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
elemId := dec.wireType[wireId].ArrayT.Elem elemId := dec.wireType[wireId].ArrayT.Elem
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
} }
...@@ -735,7 +823,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg ...@@ -735,7 +823,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress) keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
up := unsafe.Pointer(p) up := unsafe.Pointer(p)
state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl) state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl)
} }
...@@ -754,17 +842,17 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg ...@@ -754,17 +842,17 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
} }
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
} }
case *reflect.StructType: case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
enginePtr, err := dec.getDecEnginePtr(wireId, typ) enginePtr, err := dec.getDecEnginePtr(wireId, userType(typ))
if err != nil { if err != nil {
error(err) error(err)
} }
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs. // indirect through enginePtr to delay evaluation for recursive structs.
err = dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir) err = dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir)
if err != nil { if err != nil {
...@@ -772,8 +860,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg ...@@ -772,8 +860,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
} }
} }
case *reflect.InterfaceType: case *reflect.InterfaceType:
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
dec.decodeInterface(t, state, uintptr(p), i.indir) state.dec.decodeInterface(t, state, uintptr(p), i.indir)
} }
} }
} }
...@@ -783,15 +871,15 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg ...@@ -783,15 +871,15 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
return &op, indir return &op, indir
} }
// Return the decoding op for a field that has no destination. // decIgnoreOpFor returns the decoding op for a field that has no destination.
func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
op, ok := decIgnoreOpMap[wireId] op, ok := decIgnoreOpMap[wireId]
if !ok { if !ok {
if wireId == tInterface { if wireId == tInterface {
// Special case because it's a method: the ignored item might // Special case because it's a method: the ignored item might
// define types and we need to record their state in the decoder. // define types and we need to record their state in the decoder.
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
dec.ignoreInterface(state) state.dec.ignoreInterface(state)
} }
return op return op
} }
...@@ -803,7 +891,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { ...@@ -803,7 +891,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
case wire.ArrayT != nil: case wire.ArrayT != nil:
elemId := wire.ArrayT.Elem elemId := wire.ArrayT.Elem
elemOp := dec.decIgnoreOpFor(elemId) elemOp := dec.decIgnoreOpFor(elemId)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.ignoreArray(state, elemOp, wire.ArrayT.Len) state.dec.ignoreArray(state, elemOp, wire.ArrayT.Len)
} }
...@@ -812,14 +900,14 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { ...@@ -812,14 +900,14 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
elemId := dec.wireType[wireId].MapT.Elem elemId := dec.wireType[wireId].MapT.Elem
keyOp := dec.decIgnoreOpFor(keyId) keyOp := dec.decIgnoreOpFor(keyId)
elemOp := dec.decIgnoreOpFor(elemId) elemOp := dec.decIgnoreOpFor(elemId)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.ignoreMap(state, keyOp, elemOp) state.dec.ignoreMap(state, keyOp, elemOp)
} }
case wire.SliceT != nil: case wire.SliceT != nil:
elemId := wire.SliceT.Elem elemId := wire.SliceT.Elem
elemOp := dec.decIgnoreOpFor(elemId) elemOp := dec.decIgnoreOpFor(elemId)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.ignoreSlice(state, elemOp) state.dec.ignoreSlice(state, elemOp)
} }
...@@ -829,10 +917,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { ...@@ -829,10 +917,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
if err != nil { if err != nil {
error(err) error(err)
} }
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs // indirect through enginePtr to delay evaluation for recursive structs
state.dec.ignoreStruct(*enginePtr) state.dec.ignoreStruct(*enginePtr)
} }
case wire.GobEncoderT != nil:
op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.ignoreGobDecoder(state)
}
} }
} }
if op == nil { if op == nil {
...@@ -841,16 +934,56 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { ...@@ -841,16 +934,56 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
return op return op
} }
// Are these two gob Types compatible? // gobDecodeOpFor returns the op for a type that is known to implement
// Answers the question for basic types, arrays, and slices. // GobDecoder.
func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) {
rt := ut.user
if ut.decIndir != 0 {
errorf("gob: TODO: can't handle indirection to reach GobDecoder")
}
index := -1
for i := 0; i < rt.NumMethod(); i++ {
if rt.Method(i).Name == gobDecodeMethodName {
index = i
break
}
}
if index < 0 {
panic("can't find GobDecode method")
}
var op decOp
op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
// Allocate the underlying data, but hold on to the address we have,
// since it's known to be the receiver's address.
// TODO: fix this up when decIndir can be non-zero.
allocate(ut.base, uintptr(p), ut.indir)
v := reflect.NewValue(unsafe.Unreflect(rt, p))
state.dec.decodeGobDecoder(state, v, index)
}
return &op, int(ut.decIndir)
}
// compatibleType asks: Are these two gob Types compatible?
// Answers the question for basic types, arrays, maps and slices, plus
// GobEncoder/Decoder pairs.
// Structs are considered ok; fields will be checked later. // Structs are considered ok; fields will be checked later.
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool { func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
if rhs, ok := inProgress[fr]; ok { if rhs, ok := inProgress[fr]; ok {
return rhs == fw return rhs == fw
} }
inProgress[fr] = fw inProgress[fr] = fw
fr = userType(fr).base ut := userType(fr)
switch t := fr.(type) { wire, ok := dec.wireType[fw]
// If fr is a GobDecoder, the wire type must be GobEncoder.
// And if fr is not a GobDecoder, the wire type must not be either.
if ut.isGobDecoder != (ok && wire.GobEncoderT != nil) { // the parentheses look odd but are correct.
return false
}
if ut.isGobDecoder { // This test trumps all others.
return true
}
switch t := ut.base.(type) {
default: default:
// chan, etc: cannot handle. // chan, etc: cannot handle.
return false return false
...@@ -869,14 +1002,12 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re ...@@ -869,14 +1002,12 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re
case *reflect.InterfaceType: case *reflect.InterfaceType:
return fw == tInterface return fw == tInterface
case *reflect.ArrayType: case *reflect.ArrayType:
wire, ok := dec.wireType[fw]
if !ok || wire.ArrayT == nil { if !ok || wire.ArrayT == nil {
return false return false
} }
array := wire.ArrayT array := wire.ArrayT
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress) return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
case *reflect.MapType: case *reflect.MapType:
wire, ok := dec.wireType[fw]
if !ok || wire.MapT == nil { if !ok || wire.MapT == nil {
return false return false
} }
...@@ -911,8 +1042,13 @@ func (dec *Decoder) typeString(remoteId typeId) string { ...@@ -911,8 +1042,13 @@ func (dec *Decoder) typeString(remoteId typeId) string {
return dec.wireType[remoteId].string() return dec.wireType[remoteId].string()
} }
// compileSingle compiles the decoder engine for a non-struct top-level value, including
func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { // GobDecoders.
func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) {
rt := ut.base
if ut.isGobDecoder {
rt = ut.user
}
engine = new(decEngine) engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item engine.instr = make([]decInstr, 1) // one item
name := rt.String() // best we can do name := rt.String() // best we can do
...@@ -926,6 +1062,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec ...@@ -926,6 +1062,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
return return
} }
// compileIgnoreSingle compiles the decoder engine for a non-struct top-level value that will be discarded.
func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err os.Error) { func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err os.Error) {
engine = new(decEngine) engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item engine.instr = make([]decInstr, 1) // one item
...@@ -936,16 +1073,19 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err ...@@ -936,16 +1073,19 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err
return return
} }
// Is this an exported - upper case - name? // isExported reports whether this is an exported - upper case - name.
func isExported(name string) bool { func isExported(name string) bool {
rune, _ := utf8.DecodeRuneInString(name) rune, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(rune) return unicode.IsUpper(rune)
} }
func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { // compileDec compiles the decoder engine for a value. If the value is not a struct,
// it calls out to compileSingle.
func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) {
rt := ut.base
srt, ok := rt.(*reflect.StructType) srt, ok := rt.(*reflect.StructType)
if !ok { if !ok || ut.isGobDecoder {
return dec.compileSingle(remoteId, rt) return dec.compileSingle(remoteId, ut)
} }
var wireStruct *structType var wireStruct *structType
// Builtin types can come from global pool; the rest must be defined by the decoder. // Builtin types can come from global pool; the rest must be defined by the decoder.
...@@ -990,7 +1130,9 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng ...@@ -990,7 +1130,9 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
return return
} }
func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) { // getDecEnginePtr returns the engine for the specified type.
func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePtr **decEngine, err os.Error) {
rt := ut.base
decoderMap, ok := dec.decoderCache[rt] decoderMap, ok := dec.decoderCache[rt]
if !ok { if !ok {
decoderMap = make(map[typeId]**decEngine) decoderMap = make(map[typeId]**decEngine)
...@@ -1000,7 +1142,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr ...@@ -1000,7 +1142,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr
// To handle recursive types, mark this engine as underway before compiling. // To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine) enginePtr = new(*decEngine)
decoderMap[remoteId] = enginePtr decoderMap[remoteId] = enginePtr
*enginePtr, err = dec.compileDec(remoteId, rt) *enginePtr, err = dec.compileDec(remoteId, ut)
if err != nil { if err != nil {
decoderMap[remoteId] = nil, false decoderMap[remoteId] = nil, false
} }
...@@ -1008,11 +1150,12 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr ...@@ -1008,11 +1150,12 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr
return return
} }
// When ignoring struct data, in effect we compile it into this type // emptyStruct is the type we compile into when ignoring a struct value.
type emptyStruct struct{} type emptyStruct struct{}
var emptyStructType = reflect.Typeof(emptyStruct{}) var emptyStructType = reflect.Typeof(emptyStruct{})
// getDecEnginePtr returns the engine for the specified type when the value is to be discarded.
func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) { func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
var ok bool var ok bool
if enginePtr, ok = dec.ignorerCache[wireId]; !ok { if enginePtr, ok = dec.ignorerCache[wireId]; !ok {
...@@ -1021,7 +1164,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er ...@@ -1021,7 +1164,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
dec.ignorerCache[wireId] = enginePtr dec.ignorerCache[wireId] = enginePtr
wire := dec.wireType[wireId] wire := dec.wireType[wireId]
if wire != nil && wire.StructT != nil { if wire != nil && wire.StructT != nil {
*enginePtr, err = dec.compileDec(wireId, emptyStructType) *enginePtr, err = dec.compileDec(wireId, userType(emptyStructType))
} else { } else {
*enginePtr, err = dec.compileIgnoreSingle(wireId) *enginePtr, err = dec.compileIgnoreSingle(wireId)
} }
...@@ -1032,6 +1175,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er ...@@ -1032,6 +1175,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
return return
} }
// decodeValue decodes the data stream representing a value and stores it in val.
func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) { func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) {
defer catchError(&err) defer catchError(&err)
// If the value is nil, it means we should just ignore this item. // If the value is nil, it means we should just ignore this item.
...@@ -1042,12 +1186,18 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) ...@@ -1042,12 +1186,18 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error)
ut := userType(val.Type()) ut := userType(val.Type())
base := ut.base base := ut.base
indir := ut.indir indir := ut.indir
enginePtr, err := dec.getDecEnginePtr(wireId, base) if ut.isGobDecoder {
indir = int(ut.decIndir)
if indir != 0 {
errorf("TODO: can't handle indirection in GobDecoder value")
}
}
enginePtr, err := dec.getDecEnginePtr(wireId, ut)
if err != nil { if err != nil {
return err return err
} }
engine := *enginePtr engine := *enginePtr
if st, ok := base.(*reflect.StructType); ok { if st, ok := base.(*reflect.StructType); ok && !ut.isGobDecoder {
if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 {
name := base.Name() name := base.Name()
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
...@@ -1057,6 +1207,7 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) ...@@ -1057,6 +1207,7 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error)
return dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr())) return dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr()))
} }
// decodeIgnoredValue decodes the data stream representing a value of the specified type and discards it.
func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error { func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error {
enginePtr, err := dec.getIgnoreEnginePtr(wireId) enginePtr, err := dec.getIgnoreEnginePtr(wireId)
if err != nil { if err != nil {
......
...@@ -21,7 +21,7 @@ type Decoder struct { ...@@ -21,7 +21,7 @@ type Decoder struct {
wireType map[typeId]*wireType // map from remote ID to local description wireType map[typeId]*wireType // map from remote ID to local description
decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
ignorerCache map[typeId]**decEngine // ditto for ignored objects ignorerCache map[typeId]**decEngine // ditto for ignored objects
countState *decodeState // reads counts from wire countState *decoderState // reads counts from wire
countBuf []byte // used for decoding integers while parsing messages countBuf []byte // used for decoding integers while parsing messages
tmp []byte // temporary storage for i/o; saves reallocating tmp []byte // temporary storage for i/o; saves reallocating
err os.Error err os.Error
......
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
const uint64Size = unsafe.Sizeof(uint64(0)) const uint64Size = unsafe.Sizeof(uint64(0))
// The global execution state of an instance of the encoder. // encoderState is the global execution state of an instance of the encoder.
// Field numbers are delta encoded and always increase. The field // Field numbers are delta encoded and always increase. The field
// number is initialized to -1 so 0 comes out as delta(1). A delta of // number is initialized to -1 so 0 comes out as delta(1). A delta of
// 0 terminates the structure. // 0 terminates the structure.
...@@ -72,6 +72,7 @@ func (state *encoderState) encodeInt(i int64) { ...@@ -72,6 +72,7 @@ func (state *encoderState) encodeInt(i int64) {
state.encodeUint(uint64(x)) state.encodeUint(uint64(x))
} }
// encOp is the signature of an encoding operator for a given type.
type encOp func(i *encInstr, state *encoderState, p unsafe.Pointer) type encOp func(i *encInstr, state *encoderState, p unsafe.Pointer)
// The 'instructions' of the encoding machine // The 'instructions' of the encoding machine
...@@ -82,8 +83,8 @@ type encInstr struct { ...@@ -82,8 +83,8 @@ type encInstr struct {
offset uintptr // offset in the structure of the field to encode offset uintptr // offset in the structure of the field to encode
} }
// Emit a field number and update the state to record its value for delta encoding. // update emits a field number and updates the state to record its value for delta encoding.
// If the instruction pointer is nil, do nothing // If the instruction pointer is nil, it does nothing
func (state *encoderState) update(instr *encInstr) { func (state *encoderState) update(instr *encInstr) {
if instr != nil { if instr != nil {
state.encodeUint(uint64(instr.field - state.fieldnum)) state.encodeUint(uint64(instr.field - state.fieldnum))
...@@ -97,6 +98,7 @@ func (state *encoderState) update(instr *encInstr) { ...@@ -97,6 +98,7 @@ func (state *encoderState) update(instr *encInstr) {
// Otherwise, the output (for a scalar) is the field number, as an encoded integer, // Otherwise, the output (for a scalar) is the field number, as an encoded integer,
// followed by the field data in its appropriate format. // followed by the field data in its appropriate format.
// encIndirect dereferences p indir times and returns the result.
func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
for ; indir > 0; indir-- { for ; indir > 0; indir-- {
p = *(*unsafe.Pointer)(p) p = *(*unsafe.Pointer)(p)
...@@ -107,6 +109,7 @@ func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { ...@@ -107,6 +109,7 @@ func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
return p return p
} }
// encBool encodes the bool with address p as an unsigned 0 or 1.
func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) { func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) {
b := *(*bool)(p) b := *(*bool)(p)
if b || state.sendZero { if b || state.sendZero {
...@@ -119,6 +122,7 @@ func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -119,6 +122,7 @@ func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encInt encodes the int with address p.
func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := int64(*(*int)(p)) v := int64(*(*int)(p))
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -127,6 +131,7 @@ func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -127,6 +131,7 @@ func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encUint encodes the uint with address p.
func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uint)(p)) v := uint64(*(*uint)(p))
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -135,6 +140,7 @@ func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -135,6 +140,7 @@ func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encInt8 encodes the int8 with address p.
func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := int64(*(*int8)(p)) v := int64(*(*int8)(p))
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -143,6 +149,7 @@ func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -143,6 +149,7 @@ func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encUint8 encodes the uint8 with address p.
func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uint8)(p)) v := uint64(*(*uint8)(p))
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -151,6 +158,7 @@ func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -151,6 +158,7 @@ func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encInt16 encodes the int16 with address p.
func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := int64(*(*int16)(p)) v := int64(*(*int16)(p))
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -159,6 +167,7 @@ func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -159,6 +167,7 @@ func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encUint16 encodes the uint16 with address p.
func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uint16)(p)) v := uint64(*(*uint16)(p))
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -167,6 +176,7 @@ func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -167,6 +176,7 @@ func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encInt32 encodes the int32 with address p.
func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := int64(*(*int32)(p)) v := int64(*(*int32)(p))
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -175,6 +185,7 @@ func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -175,6 +185,7 @@ func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encUint encodes the uint32 with address p.
func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uint32)(p)) v := uint64(*(*uint32)(p))
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -183,6 +194,7 @@ func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -183,6 +194,7 @@ func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encInt64 encodes the int64 with address p.
func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := *(*int64)(p) v := *(*int64)(p)
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -191,6 +203,7 @@ func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -191,6 +203,7 @@ func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encInt64 encodes the uint64 with address p.
func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := *(*uint64)(p) v := *(*uint64)(p)
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -199,6 +212,7 @@ func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -199,6 +212,7 @@ func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encUintptr encodes the uintptr with address p.
func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uintptr)(p)) v := uint64(*(*uintptr)(p))
if v != 0 || state.sendZero { if v != 0 || state.sendZero {
...@@ -207,6 +221,7 @@ func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -207,6 +221,7 @@ func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// floatBits returns a uint64 holding the bits of a floating-point number.
// Floating-point numbers are transmitted as uint64s holding the bits // Floating-point numbers are transmitted as uint64s holding the bits
// of the underlying representation. They are sent byte-reversed, with // of the underlying representation. They are sent byte-reversed, with
// the exponent end coming out first, so integer floating point numbers // the exponent end coming out first, so integer floating point numbers
...@@ -223,6 +238,7 @@ func floatBits(f float64) uint64 { ...@@ -223,6 +238,7 @@ func floatBits(f float64) uint64 {
return v return v
} }
// encFloat32 encodes the float32 with address p.
func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) {
f := *(*float32)(p) f := *(*float32)(p)
if f != 0 || state.sendZero { if f != 0 || state.sendZero {
...@@ -232,6 +248,7 @@ func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -232,6 +248,7 @@ func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encFloat64 encodes the float64 with address p.
func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) {
f := *(*float64)(p) f := *(*float64)(p)
if f != 0 || state.sendZero { if f != 0 || state.sendZero {
...@@ -241,6 +258,7 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -241,6 +258,7 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encComplex64 encodes the complex64 with address p.
// Complex numbers are just a pair of floating-point numbers, real part first. // Complex numbers are just a pair of floating-point numbers, real part first.
func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) { func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) {
c := *(*complex64)(p) c := *(*complex64)(p)
...@@ -253,6 +271,7 @@ func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -253,6 +271,7 @@ func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encComplex128 encodes the complex128 with address p.
func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) {
c := *(*complex128)(p) c := *(*complex128)(p)
if c != 0+0i || state.sendZero { if c != 0+0i || state.sendZero {
...@@ -264,6 +283,7 @@ func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -264,6 +283,7 @@ func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encUint8Array encodes the byte slice whose header has address p.
// Byte arrays are encoded as an unsigned count followed by the raw bytes. // Byte arrays are encoded as an unsigned count followed by the raw bytes.
func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) {
b := *(*[]byte)(p) b := *(*[]byte)(p)
...@@ -274,6 +294,7 @@ func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -274,6 +294,7 @@ func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// encString encodes the string whose header has address p.
// Strings are encoded as an unsigned count followed by the raw bytes. // Strings are encoded as an unsigned count followed by the raw bytes.
func encString(i *encInstr, state *encoderState, p unsafe.Pointer) { func encString(i *encInstr, state *encoderState, p unsafe.Pointer) {
s := *(*string)(p) s := *(*string)(p)
...@@ -284,14 +305,15 @@ func encString(i *encInstr, state *encoderState, p unsafe.Pointer) { ...@@ -284,14 +305,15 @@ func encString(i *encInstr, state *encoderState, p unsafe.Pointer) {
} }
} }
// The end of a struct is marked by a delta field number of 0. // encStructTerminator encodes the end of an encoded struct
// as delta field number of 0.
func encStructTerminator(i *encInstr, state *encoderState, p unsafe.Pointer) { func encStructTerminator(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.encodeUint(0) state.encodeUint(0)
} }
// Execution engine // Execution engine
// The encoder engine is an array of instructions indexed by field number of the encoding // encEngine an array of instructions indexed by field number of the encoding
// data, typically a struct. It is executed top to bottom, walking the struct. // data, typically a struct. It is executed top to bottom, walking the struct.
type encEngine struct { type encEngine struct {
instr []encInstr instr []encInstr
...@@ -299,6 +321,7 @@ type encEngine struct { ...@@ -299,6 +321,7 @@ type encEngine struct {
const singletonField = 0 const singletonField = 0
// encodeSingle encodes a single top-level non-struct value.
func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintptr) { func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintptr) {
state := newEncoderState(enc, b) state := newEncoderState(enc, b)
state.fieldnum = singletonField state.fieldnum = singletonField
...@@ -315,6 +338,7 @@ func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintp ...@@ -315,6 +338,7 @@ func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintp
instr.op(instr, state, p) instr.op(instr, state, p)
} }
// encodeStruct encodes a single struct value.
func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintptr) { func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintptr) {
state := newEncoderState(enc, b) state := newEncoderState(enc, b)
state.fieldnum = -1 state.fieldnum = -1
...@@ -330,6 +354,7 @@ func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintp ...@@ -330,6 +354,7 @@ func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintp
} }
} }
// encodeArray encodes the array whose 0th element is at p.
func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) { func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) {
state := newEncoderState(enc, b) state := newEncoderState(enc, b)
state.fieldnum = -1 state.fieldnum = -1
...@@ -349,6 +374,7 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui ...@@ -349,6 +374,7 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui
} }
} }
// encodeReflectValue is a helper for maps. It encodes the value v.
func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) { func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) {
for i := 0; i < indir && v != nil; i++ { for i := 0; i < indir && v != nil; i++ {
v = reflect.Indirect(v) v = reflect.Indirect(v)
...@@ -359,6 +385,9 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in ...@@ -359,6 +385,9 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in
op(nil, state, unsafe.Pointer(v.UnsafeAddr())) op(nil, state, unsafe.Pointer(v.UnsafeAddr()))
} }
// encodeMap encodes a map as unsigned count followed by key:value pairs.
// Because map internals are not exposed, we must use reflection rather than
// addresses.
func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) { func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) {
state := newEncoderState(enc, b) state := newEncoderState(enc, b)
state.fieldnum = -1 state.fieldnum = -1
...@@ -371,6 +400,7 @@ func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elem ...@@ -371,6 +400,7 @@ func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elem
} }
} }
// encodeInterface encodes the interface value iv.
// To send an interface, we send a string identifying the concrete type, followed // To send an interface, we send a string identifying the concrete type, followed
// by the type identifier (which might require defining that type right now), followed // by the type identifier (which might require defining that type right now), followed
// by the concrete value. A nil value gets sent as the empty string for the name, // by the concrete value. A nil value gets sent as the empty string for the name,
...@@ -414,6 +444,21 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) ...@@ -414,6 +444,21 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
} }
} }
// encGobEncoder encodes a value that implements the GobEncoder interface.
// The data is sent as a byte array.
func (enc *Encoder) encodeGobEncoder(b *bytes.Buffer, v reflect.Value, index int) {
// TODO: should we catch panics from the called method?
// We know it's a GobEncoder, so just call the method directly.
data, err := v.Interface().(_GobEncoder)._GobEncode()
if err != nil {
error(err)
}
state := newEncoderState(enc, b)
state.fieldnum = -1
state.encodeUint(uint64(len(data)))
state.b.Write(data)
}
var encOpTable = [...]encOp{ var encOpTable = [...]encOp{
reflect.Bool: encBool, reflect.Bool: encBool,
reflect.Int: encInt, reflect.Int: encInt,
...@@ -434,10 +479,14 @@ var encOpTable = [...]encOp{ ...@@ -434,10 +479,14 @@ var encOpTable = [...]encOp{
reflect.String: encString, reflect.String: encString,
} }
// Return (a pointer to) the encoding op for the base type under rt and // encOpFor returns (a pointer to) the encoding op for the base type under rt and
// the indirection count to reach it. // the indirection count to reach it.
func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) { func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) {
ut := userType(rt) ut := userType(rt)
// If the type implements GobEncoder, we handle it without further processing.
if ut.isGobEncoder {
return enc.gobEncodeOpFor(ut)
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T). // If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building. // Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil { if opPtr := inProgress[rt]; opPtr != nil {
...@@ -483,7 +532,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp ...@@ -483,7 +532,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
// Maps cannot be accessed by moving addresses around the way // Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for // that slices etc. can. We must recover a full reflection value for
// the iteration. // the iteration.
v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p)))) v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p)))
mv := reflect.Indirect(v).(*reflect.MapValue) mv := reflect.Indirect(v).(*reflect.MapValue)
if !state.sendZero && mv.Len() == 0 { if !state.sendZero && mv.Len() == 0 {
return return
...@@ -493,7 +542,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp ...@@ -493,7 +542,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
} }
case *reflect.StructType: case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
enc.getEncEngine(typ) enc.getEncEngine(userType(typ))
info := mustGetTypeInfo(typ) info := mustGetTypeInfo(typ)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i) state.update(i)
...@@ -504,7 +553,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp ...@@ -504,7 +553,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
// Interfaces transmit the name and contents of the concrete // Interfaces transmit the name and contents of the concrete
// value they contain. // value they contain.
v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p)))) v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p)))
iv := reflect.Indirect(v).(*reflect.InterfaceValue) iv := reflect.Indirect(v).(*reflect.InterfaceValue)
if !state.sendZero && (iv == nil || iv.IsNil()) { if !state.sendZero && (iv == nil || iv.IsNil()) {
return return
...@@ -520,12 +569,43 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp ...@@ -520,12 +569,43 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
return &op, indir return &op, indir
} }
// The local Type was compiled from the actual value, so we know it's compatible. // gobEncodeOpFor returns the op for a type that is known to implement
func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { // GobEncoder.
srt, isStruct := rt.(*reflect.StructType) func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) {
rt := ut.user
if ut.encIndir != 0 {
errorf("gob: TODO: can't handle indirection to reach GobEncoder")
}
index := -1
for i := 0; i < rt.NumMethod(); i++ {
if rt.Method(i).Name == gobEncodeMethodName {
index = i
break
}
}
if index < 0 {
panic("can't find GobEncode method")
}
var op encOp
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
// TODO: this will need fixing when ut.encIndr != 0.
v := reflect.NewValue(unsafe.Unreflect(rt, p))
state.update(i)
state.enc.encodeGobEncoder(state.b, v, index)
}
return &op, int(ut.encIndir)
}
// compileEnc returns the engine to compile the type.
func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine {
srt, isStruct := ut.base.(*reflect.StructType)
engine := new(encEngine) engine := new(encEngine)
seen := make(map[reflect.Type]*encOp) seen := make(map[reflect.Type]*encOp)
if isStruct { rt := ut.base
if ut.isGobEncoder {
rt = ut.user
}
if !ut.isGobEncoder && isStruct {
for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ { for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ {
f := srt.Field(fieldNum) f := srt.Field(fieldNum)
if !isExported(f.Name) { if !isExported(f.Name) {
...@@ -546,35 +626,43 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { ...@@ -546,35 +626,43 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
return engine return engine
} }
// getEncEngine returns the engine to compile the type.
// typeLock must be held (or we're in initialization and guaranteed single-threaded). // typeLock must be held (or we're in initialization and guaranteed single-threaded).
// The reflection type must have all its indirections processed out. func (enc *Encoder) getEncEngine(ut *userTypeInfo) *encEngine {
func (enc *Encoder) getEncEngine(rt reflect.Type) *encEngine { info, err1 := getTypeInfo(ut)
info, err1 := getTypeInfo(rt)
if err1 != nil { if err1 != nil {
error(err1) error(err1)
} }
if info.encoder == nil { if info.encoder == nil {
// mark this engine as underway before compiling to handle recursive types. // mark this engine as underway before compiling to handle recursive types.
info.encoder = new(encEngine) info.encoder = new(encEngine)
info.encoder = enc.compileEnc(rt) info.encoder = enc.compileEnc(ut)
} }
return info.encoder return info.encoder
} }
// Put this in a function so we can hold the lock only while compiling, not when encoding. // lockAndGetEncEngine is a function that locks and compiles.
func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine { // This lets us hold the lock only while compiling, not when encoding.
func (enc *Encoder) lockAndGetEncEngine(ut *userTypeInfo) *encEngine {
typeLock.Lock() typeLock.Lock()
defer typeLock.Unlock() defer typeLock.Unlock()
return enc.getEncEngine(rt) return enc.getEncEngine(ut)
} }
func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) (err os.Error) { func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) (err os.Error) {
defer catchError(&err) defer catchError(&err)
for i := 0; i < ut.indir; i++ { engine := enc.lockAndGetEncEngine(ut)
indir := ut.indir
if ut.isGobEncoder {
indir = int(ut.encIndir)
if indir != 0 {
errorf("TODO: can't handle indirection in GobEncoder value")
}
}
for i := 0; i < indir; i++ {
value = reflect.Indirect(value) value = reflect.Indirect(value)
} }
engine := enc.lockAndGetEncEngine(ut.base) if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct {
if value.Type().Kind() == reflect.Struct {
enc.encodeStruct(b, engine, value.UnsafeAddr()) enc.encodeStruct(b, engine, value.UnsafeAddr())
} else { } else {
enc.encodeSingle(b, engine, value.UnsafeAddr()) enc.encodeSingle(b, engine, value.UnsafeAddr())
......
...@@ -78,12 +78,57 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) { ...@@ -78,12 +78,57 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
} }
} }
// sendActualType sends the requested type, without further investigation, unless
// it's been sent before.
func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
if _, alreadySent := enc.sent[actual]; alreadySent {
return false
}
typeLock.Lock()
info, err := getTypeInfo(ut)
typeLock.Unlock()
if err != nil {
enc.setError(err)
return
}
// Send the pair (-id, type)
// Id:
state.encodeInt(-int64(info.id))
// Type:
enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
enc.writeMessage(w, state.b)
if enc.err != nil {
return
}
// Remember we've sent this type, both what the user gave us and the base type.
enc.sent[ut.base] = info.id
if ut.user != ut.base {
enc.sent[ut.user] = info.id
}
// Now send the inner types
switch st := actual.(type) {
case *reflect.StructType:
for i := 0; i < st.NumField(); i++ {
enc.sendType(w, state, st.Field(i).Type)
}
case reflect.ArrayOrSliceType:
enc.sendType(w, state, st.Elem())
}
return true
}
// sendType sends the type info to the other side, if necessary.
func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
// Drill down to the base type.
ut := userType(origt) ut := userType(origt)
rt := ut.base if ut.isGobEncoder {
// The rules are different: regardless of the underlying type's representation,
// we need to tell the other side that this exact type is a GobEncoder.
return enc.sendActualType(w, state, ut, ut.user)
}
switch rt := rt.(type) { // It's a concrete value, so drill down to the base type.
switch rt := ut.base.(type) {
default: default:
// Basic types and interfaces do not need to be described. // Basic types and interfaces do not need to be described.
return return
...@@ -109,43 +154,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ ...@@ -109,43 +154,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ
return return
} }
// Have we already sent this type? This time we ask about the base type. return enc.sendActualType(w, state, ut, ut.base)
if _, alreadySent := enc.sent[rt]; alreadySent {
return
}
// Need to send it.
typeLock.Lock()
info, err := getTypeInfo(rt)
typeLock.Unlock()
if err != nil {
enc.setError(err)
return
}
// Send the pair (-id, type)
// Id:
state.encodeInt(-int64(info.id))
// Type:
enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
enc.writeMessage(w, state.b)
if enc.err != nil {
return
}
// Remember we've sent this type.
enc.sent[rt] = info.id
// Remember we've sent the top-level, possibly indirect type too.
enc.sent[origt] = info.id
// Now send the inner types
switch st := rt.(type) {
case *reflect.StructType:
for i := 0; i < st.NumField(); i++ {
enc.sendType(w, state, st.Field(i).Type)
}
case reflect.ArrayOrSliceType:
enc.sendType(w, state, st.Elem())
}
return true
} }
// Encode transmits the data item represented by the empty interface value, // Encode transmits the data item represented by the empty interface value,
...@@ -159,11 +168,17 @@ func (enc *Encoder) Encode(e interface{}) os.Error { ...@@ -159,11 +168,17 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
// sent. // sent.
func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) { func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
// Make sure the type is known to the other side. // Make sure the type is known to the other side.
// First, have we already sent this (base) type? // First, have we already sent this type?
base := ut.base rt := ut.base
if _, alreadySent := enc.sent[base]; !alreadySent { if ut.isGobEncoder {
rt = ut.user
if ut.encIndir != 0 {
panic("TODO: can't handle non-zero encIndir")
}
}
if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it. // No, so send it.
sent := enc.sendType(w, state, base) sent := enc.sendType(w, state, rt)
if enc.err != nil { if enc.err != nil {
return return
} }
...@@ -172,13 +187,13 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use ...@@ -172,13 +187,13 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use
// need to send the type info but we do need to update enc.sent. // need to send the type info but we do need to update enc.sent.
if !sent { if !sent {
typeLock.Lock() typeLock.Lock()
info, err := getTypeInfo(base) info, err := getTypeInfo(ut)
typeLock.Unlock() typeLock.Unlock()
if err != nil { if err != nil {
enc.setError(err) enc.setError(err)
return return
} }
enc.sent[base] = info.id enc.sent[rt] = info.id
} }
} }
} }
......
// Copyright 20011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains tests of the GobEncoder/GobDecoder support.
package gob
import (
"bytes"
"fmt"
"os"
"strings"
"testing"
)
// Types that implement the GobEncoder/Decoder interfaces.
type ByteStruct struct {
a byte // not an exported field
}
type StringStruct struct {
s string // not an exported field
}
type Gobber int
type ValueGobber string // encodes with a value, decodes with a pointer.
// The relevant methods
func (g *ByteStruct) _GobEncode() ([]byte, os.Error) {
b := make([]byte, 3)
b[0] = g.a
b[1] = g.a + 1
b[2] = g.a + 2
return b, nil
}
func (g *ByteStruct) _GobDecode(data []byte) os.Error {
if g == nil {
return os.ErrorString("NIL RECEIVER")
}
// Expect N sequential-valued bytes.
if len(data) == 0 {
return os.EOF
}
g.a = data[0]
for i, c := range data {
if c != g.a+byte(i) {
return os.ErrorString("invalid data sequence")
}
}
return nil
}
func (g *StringStruct) _GobEncode() ([]byte, os.Error) {
return []byte(g.s), nil
}
func (g *StringStruct) _GobDecode(data []byte) os.Error {
// Expect N sequential-valued bytes.
if len(data) == 0 {
return os.EOF
}
a := data[0]
for i, c := range data {
if c != a+byte(i) {
return os.ErrorString("invalid data sequence")
}
}
g.s = string(data)
return nil
}
func (g *Gobber) _GobEncode() ([]byte, os.Error) {
return []byte(fmt.Sprintf("VALUE=%d", *g)), nil
}
func (g *Gobber) _GobDecode(data []byte) os.Error {
_, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g))
return err
}
func (v ValueGobber) _GobEncode() ([]byte, os.Error) {
return []byte(fmt.Sprintf("VALUE=%s", v)), nil
}
func (v *ValueGobber) _GobDecode(data []byte) os.Error {
_, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v))
return err
}
// Structs that include GobEncodable fields.
type GobTest0 struct {
X int // guarantee we have something in common with GobTest*
G *ByteStruct
}
type GobTest1 struct {
X int // guarantee we have something in common with GobTest*
G *StringStruct
}
type GobTest2 struct {
X int // guarantee we have something in common with GobTest*
G string // not a GobEncoder - should give us errors
}
type GobTest3 struct {
X int // guarantee we have something in common with GobTest*
G *Gobber // TODO: should be able to satisfy interface without a pointer
}
type GobTest4 struct {
X int // guarantee we have something in common with GobTest*
V ValueGobber
}
type GobTest5 struct {
X int // guarantee we have something in common with GobTest*
V *ValueGobber
}
type GobTestIgnoreEncoder struct {
X int // guarantee we have something in common with GobTest*
}
func TestGobEncoderField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTest0)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.G.a != 'A' {
t.Errorf("expected 'A' got %c", x.G.a)
}
// Now a field that's not a structure.
b.Reset()
gobber := Gobber(23)
err = enc.Encode(GobTest3{17, &gobber})
if err != nil {
t.Fatal("encode error:", err)
}
y := new(GobTest3)
err = dec.Decode(y)
if err != nil {
t.Fatal("decode error:", err)
}
if *y.G != 23 {
t.Errorf("expected '23 got %d", *y.G)
}
}
// As long as the fields have the same name and implement the
// interface, we can cross-connect them. Not sure it's useful
// and may even be bad but it works and it's hard to prevent
// without exposing the contents of the object, which would
// defeat the purpose.
func TestGobEncoderFieldsOfDifferentType(t *testing.T) {
// first, string in field to byte in field
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTest0)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.G.a != 'A' {
t.Errorf("expected 'A' got %c", x.G.a)
}
// now the other direction, byte in field to string in field
b.Reset()
err = enc.Encode(GobTest0{17, &ByteStruct{'X'}})
if err != nil {
t.Fatal("encode error:", err)
}
y := new(GobTest1)
err = dec.Decode(y)
if err != nil {
t.Fatal("decode error:", err)
}
if y.G.s != "XYZ" {
t.Fatalf("expected `XYZ` got %c", y.G.s)
}
}
// Test that we can encode a value and decode into a pointer.
func TestGobEncoderValueEncoder(t *testing.T) {
// first, string in field to byte in field
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(GobTest4{17, ValueGobber("hello")})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTest5)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if *x.V != "hello" {
t.Errorf("expected `hello` got %s", x.V)
}
}
func TestGobEncoderFieldTypeError(t *testing.T) {
// GobEncoder to non-decoder: error
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := &GobTest2{}
err = dec.Decode(x)
if err == nil {
t.Fatal("expected decode error for mistmatched fields (encoder to non-decoder)")
}
if strings.Index(err.String(), "type") < 0 {
t.Fatal("expected type error; got", err)
}
// Non-encoder to GobDecoder: error
b.Reset()
err = enc.Encode(GobTest2{17, "ABC"})
if err != nil {
t.Fatal("encode error:", err)
}
y := &GobTest1{}
err = dec.Decode(y)
if err == nil {
t.Fatal("expected decode error for mistmatched fields (non-encoder to decoder)")
}
if strings.Index(err.String(), "type") < 0 {
t.Fatal("expected type error; got", err)
}
}
// Even though ByteStruct is a struct, it's treated as a singleton at the top level.
func TestGobEncoderStructSingleton(t *testing.T) {
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(&ByteStruct{'A'})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(ByteStruct)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.a != 'A' {
t.Errorf("expected 'A' got %c", x.a)
}
}
func TestGobEncoderNonStructSingleton(t *testing.T) {
b := new(bytes.Buffer)
enc := NewEncoder(b)
g := Gobber(1234) // TODO: shouldn't need to take the address here.
err := enc.Encode(&g)
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
var x Gobber
err = dec.Decode(&x)
if err != nil {
t.Fatal("decode error:", err)
}
if x != 1234 {
t.Errorf("expected 1234 got %c", x)
}
}
func TestGobEncoderIgnoreStructField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestIgnoreEncoder)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.X != 17 {
t.Errorf("expected 17 got %c", x.X)
}
}
func TestGobEncoderIgnoreNonStructField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
gobber := Gobber(23)
err := enc.Encode(GobTest3{17, &gobber})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestIgnoreEncoder)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.X != 17 {
t.Errorf("expected 17 got %c", x.X)
}
}
...@@ -18,6 +18,10 @@ type userTypeInfo struct { ...@@ -18,6 +18,10 @@ type userTypeInfo struct {
user reflect.Type // the type the user handed us user reflect.Type // the type the user handed us
base reflect.Type // the base type after all indirections base reflect.Type // the base type after all indirections
indir int // number of indirections to reach the base type indir int // number of indirections to reach the base type
isGobEncoder bool // does the type implement _GobEncoder?
isGobDecoder bool // does the type implement _GobDecoder?
encIndir int8 // number of indirections to reach the receiver type; may be negative
decIndir int8 // number of indirections to reach the receiver type; may be negative
} }
var ( var (
...@@ -68,8 +72,81 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { ...@@ -68,8 +72,81 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
} }
ut.indir++ ut.indir++
} }
ut.isGobEncoder, ut.encIndir = implementsGobEncoder(ut.user)
ut.isGobDecoder, ut.decIndir = implementsGobDecoder(ut.user)
userTypeCache[rt] = ut userTypeCache[rt] = ut
if ut.encIndir != 0 || ut.decIndir != 0 {
// There are checks in lots of other places, but putting this here means we won't even
// attempt to encode/decode this type.
// TODO: make it possible to handle types that are indirect to the implementation,
// such as a structure field of type T when *T implements GobDecoder.
return nil, os.ErrorString("TODO: gob can't handle indirections to GobEncoder/Decoder")
}
return
}
const (
gobEncodeMethodName = "_GobEncode"
gobDecodeMethodName = "_GobDecode"
)
// implementsGobEncoder reports whether the type implements the interface. It also
// returns the number of indirections required to get to the implementation.
// TODO: when reflection makes it possible, should also be prepared to climb up
// one level if we're not on a pointer (implementation could be on *T for our T).
// That will mean that indir could be < 0, which is sure to cause problems, but
// we ignore them now as indir is always >= 0 now.
func implementsGobEncoder(rt reflect.Type) (implements bool, indir int8) {
if rt == nil {
return
}
// The type might be a pointer, or it might not, and we need to keep
// dereferencing to the base type until we find an implementation.
for {
if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance
if _, ok := reflect.MakeZero(rt).Interface().(_GobEncoder); ok {
return true, indir
}
}
if p, ok := rt.(*reflect.PtrType); ok {
indir++
if indir > 100 { // insane number of indirections
return false, 0
}
rt = p.Elem()
continue
}
break
}
return false, 0
}
// implementsGobDecoder reports whether the type implements the interface. It also
// returns the number of indirections required to get to the implementation.
// TODO: see comment on implementsGobEncoder.
func implementsGobDecoder(rt reflect.Type) (implements bool, indir int8) {
if rt == nil {
return return
}
// The type might be a pointer, or it might not, and we need to keep
// dereferencing to the base type until we find an implementation.
for {
if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance
if _, ok := reflect.MakeZero(rt).Interface().(_GobDecoder); ok {
return true, indir
}
}
if p, ok := rt.(*reflect.PtrType); ok {
indir++
if indir > 100 { // insane number of indirections
return false, 0
}
rt = p.Elem()
continue
}
break
}
return false, 0
} }
// userType returns, and saves, the information associated with user-provided type rt. // userType returns, and saves, the information associated with user-provided type rt.
...@@ -229,6 +306,23 @@ func (a *arrayType) safeString(seen map[typeId]bool) string { ...@@ -229,6 +306,23 @@ func (a *arrayType) safeString(seen map[typeId]bool) string {
func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) } func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
// GobEncoder type (something that implements the _GobEncoder interface)
type gobEncoderType struct {
CommonType
}
func newGobEncoderType(name string) *gobEncoderType {
g := &gobEncoderType{CommonType{Name: name}}
setTypeId(g)
return g
}
func (g *gobEncoderType) safeString(seen map[typeId]bool) string {
return g.Name
}
func (g *gobEncoderType) string() string { return g.Name }
// Map type // Map type
type mapType struct { type mapType struct {
CommonType CommonType
...@@ -328,7 +422,16 @@ func (s *structType) init(field []*fieldType) { ...@@ -328,7 +422,16 @@ func (s *structType) init(field []*fieldType) {
s.Field = field s.Field = field
} }
func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { // newTypeObject allocates a gobType for the reflection type rt.
// Unless ut represents a GobEncoder, rt should be the base type
// of ut.
// This is only called from the encoding side. The decoding side
// works through typeIds and userTypeInfos alone.
func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
// Does this type implement GobEncoder?
if ut.isGobEncoder {
return newGobEncoderType(name), nil
}
var err os.Error var err os.Error
var type0, type1 gobType var type0, type1 gobType
defer func() { defer func() {
...@@ -364,7 +467,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { ...@@ -364,7 +467,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
case *reflect.ArrayType: case *reflect.ArrayType:
at := newArrayType(name) at := newArrayType(name)
types[rt] = at types[rt] = at
type0, err = getType("", t.Elem()) type0, err = getBaseType("", t.Elem())
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -382,11 +485,11 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { ...@@ -382,11 +485,11 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
case *reflect.MapType: case *reflect.MapType:
mt := newMapType(name) mt := newMapType(name)
types[rt] = mt types[rt] = mt
type0, err = getType("", t.Key()) type0, err = getBaseType("", t.Key())
if err != nil { if err != nil {
return nil, err return nil, err
} }
type1, err = getType("", t.Elem()) type1, err = getBaseType("", t.Elem())
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -400,7 +503,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { ...@@ -400,7 +503,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
} }
st := newSliceType(name) st := newSliceType(name)
types[rt] = st types[rt] = st
type0, err = getType(t.Elem().Name(), t.Elem()) type0, err = getBaseType(t.Elem().Name(), t.Elem())
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -413,6 +516,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { ...@@ -413,6 +516,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
idToType[st.id()] = st idToType[st.id()] = st
field := make([]*fieldType, t.NumField()) field := make([]*fieldType, t.NumField())
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
// TODO: don't send unexported fields.
f := t.Field(i) f := t.Field(i)
typ := userType(f.Type).base typ := userType(f.Type).base
tname := typ.Name() tname := typ.Name()
...@@ -420,7 +524,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { ...@@ -420,7 +524,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
t := userType(f.Type).base t := userType(f.Type).base
tname = t.String() tname = t.String()
} }
gt, err := getType(tname, f.Type) gt, err := getBaseType(tname, f.Type)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -435,15 +539,24 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { ...@@ -435,15 +539,24 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
return nil, nil return nil, nil
} }
// getBaseType returns the Gob type describing the given reflect.Type's base type.
// typeLock must be held.
func getBaseType(name string, rt reflect.Type) (gobType, os.Error) {
ut := userType(rt)
return getType(name, ut, ut.base)
}
// getType returns the Gob type describing the given reflect.Type. // getType returns the Gob type describing the given reflect.Type.
// Should be called only when handling GobEncoders/Decoders,
// which may be pointers. All other types are handled through the
// base type, never a pointer.
// typeLock must be held. // typeLock must be held.
func getType(name string, rt reflect.Type) (gobType, os.Error) { func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
rt = userType(rt).base
typ, present := types[rt] typ, present := types[rt]
if present { if present {
return typ, nil return typ, nil
} }
typ, err := newTypeObject(name, rt) typ, err := newTypeObject(name, ut, rt)
if err == nil { if err == nil {
types[rt] = typ types[rt] = typ
} }
...@@ -488,6 +601,7 @@ type wireType struct { ...@@ -488,6 +601,7 @@ type wireType struct {
SliceT *sliceType SliceT *sliceType
StructT *structType StructT *structType
MapT *mapType MapT *mapType
GobEncoderT *gobEncoderType
} }
func (w *wireType) string() string { func (w *wireType) string() string {
...@@ -504,6 +618,8 @@ func (w *wireType) string() string { ...@@ -504,6 +618,8 @@ func (w *wireType) string() string {
return w.StructT.Name return w.StructT.Name
case w.MapT != nil: case w.MapT != nil:
return w.MapT.Name return w.MapT.Name
case w.GobEncoderT != nil:
return w.GobEncoderT.Name
} }
return unknown return unknown
} }
...@@ -516,23 +632,43 @@ type typeInfo struct { ...@@ -516,23 +632,43 @@ type typeInfo struct {
var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock
// The reflection type must have all its indirections processed out.
// typeLock must be held. // typeLock must be held.
func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) { func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) {
if rt.Kind() == reflect.Ptr {
panic("pointer type in getTypeInfo: " + rt.String()) if ut.isGobEncoder {
// TODO: clean up this code - too much duplication.
info, ok := typeInfoMap[ut.user]
if ok {
return info, nil
} }
info, ok := typeInfoMap[rt] // We want the user type, not the base type.
userType, err := getType(ut.user.Name(), ut, ut.user)
if err != nil {
return nil, err
}
info = new(typeInfo)
gt, err := getBaseType(ut.base.Name(), ut.base)
if err != nil {
return nil, err
}
info.id = gt.id()
info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)}
typeInfoMap[ut.user] = info
return info, nil
}
base := ut.base
info, ok := typeInfoMap[base]
if !ok { if !ok {
info = new(typeInfo) info = new(typeInfo)
name := rt.Name() name := base.Name()
gt, err := getType(name, rt) gt, err := getBaseType(name, base)
if err != nil { if err != nil {
return nil, err return nil, err
} }
info.id = gt.id() info.id = gt.id()
t := info.id.gobType() t := info.id.gobType()
switch typ := rt.(type) { switch typ := base.(type) {
case *reflect.ArrayType: case *reflect.ArrayType:
info.wire = &wireType{ArrayT: t.(*arrayType)} info.wire = &wireType{ArrayT: t.(*arrayType)}
case *reflect.MapType: case *reflect.MapType:
...@@ -545,20 +681,27 @@ func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) { ...@@ -545,20 +681,27 @@ func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) {
case *reflect.StructType: case *reflect.StructType:
info.wire = &wireType{StructT: t.(*structType)} info.wire = &wireType{StructT: t.(*structType)}
} }
typeInfoMap[rt] = info typeInfoMap[base] = info
} }
return info, nil return info, nil
} }
// Called only when a panic is acceptable and unexpected. // Called only when a panic is acceptable and unexpected.
func mustGetTypeInfo(rt reflect.Type) *typeInfo { func mustGetTypeInfo(rt reflect.Type) *typeInfo {
t, err := getTypeInfo(rt) t, err := getTypeInfo(userType(rt))
if err != nil { if err != nil {
panic("getTypeInfo: " + err.String()) panic("getTypeInfo: " + err.String())
} }
return t return t
} }
type _GobEncoder interface {
_GobEncode() ([]byte, os.Error)
} // use _ prefix until we get it working properly
type _GobDecoder interface {
_GobDecode([]byte) os.Error
} // use _ prefix until we get it working properly
var ( var (
nameToConcreteType = make(map[string]reflect.Type) nameToConcreteType = make(map[string]reflect.Type)
concreteTypeToName = make(map[reflect.Type]string) concreteTypeToName = make(map[reflect.Type]string)
......
...@@ -26,7 +26,7 @@ var basicTypes = []typeT{ ...@@ -26,7 +26,7 @@ var basicTypes = []typeT{
func getTypeUnlocked(name string, rt reflect.Type) gobType { func getTypeUnlocked(name string, rt reflect.Type) gobType {
typeLock.Lock() typeLock.Lock()
defer typeLock.Unlock() defer typeLock.Unlock()
t, err := getType(name, rt) t, err := getBaseType(name, rt)
if err != nil { if err != nil {
panic("getTypeUnlocked: " + err.String()) panic("getTypeUnlocked: " + err.String())
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment