Commit 9b82481a authored by Rob Pike's avatar Rob Pike

gob: make nested interfaces work.

Also clean up the code, make it more regular.

Fixes #1416.

R=rsc
CC=golang-dev
https://golang.org/cl/3985047
parent 04a89054
...@@ -58,7 +58,7 @@ func TestUintCodec(t *testing.T) { ...@@ -58,7 +58,7 @@ func TestUintCodec(t *testing.T) {
t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes()) t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes())
} }
} }
decState := newDecodeState(nil, &b) decState := newDecodeState(nil, b)
for u := uint64(0); ; u = (u + 1) * 7 { for u := uint64(0); ; u = (u + 1) * 7 {
b.Reset() b.Reset()
encState.encodeUint(u) encState.encodeUint(u)
...@@ -77,7 +77,7 @@ func verifyInt(i int64, t *testing.T) { ...@@ -77,7 +77,7 @@ func verifyInt(i int64, t *testing.T) {
var b = new(bytes.Buffer) var b = new(bytes.Buffer)
encState := newEncoderState(nil, b) encState := newEncoderState(nil, b)
encState.encodeInt(i) encState.encodeInt(i)
decState := newDecodeState(nil, &b) decState := newDecodeState(nil, b)
decState.buf = make([]byte, 8) decState.buf = make([]byte, 8)
j := decState.decodeInt() j := decState.decodeInt()
if i != j { if i != j {
...@@ -315,7 +315,7 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un ...@@ -315,7 +315,7 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un
func newDecodeStateFromData(data []byte) *decodeState { func newDecodeStateFromData(data []byte) *decodeState {
b := bytes.NewBuffer(data) b := bytes.NewBuffer(data)
state := newDecodeState(nil, &b) state := newDecodeState(nil, b)
state.fieldnum = -1 state.fieldnum = -1
return state return state
} }
...@@ -1162,7 +1162,6 @@ func TestInterface(t *testing.T) { ...@@ -1162,7 +1162,6 @@ func TestInterface(t *testing.T) {
} }
} }
} }
} }
// A struct with all basic types, stored in interfaces. // A struct with all basic types, stored in interfaces.
...@@ -1182,7 +1181,7 @@ func TestInterfaceBasic(t *testing.T) { ...@@ -1182,7 +1181,7 @@ func TestInterfaceBasic(t *testing.T) {
int(1), int8(1), int16(1), int32(1), int64(1), int(1), int8(1), int16(1), int32(1), int64(1),
uint(1), uint8(1), uint16(1), uint32(1), uint64(1), uint(1), uint8(1), uint16(1), uint32(1), uint64(1),
float32(1), 1.0, float32(1), 1.0,
complex64(0i), complex128(0i), complex64(1i), complex128(1i),
true, true,
"hello", "hello",
[]byte("sailor"), []byte("sailor"),
......
...@@ -30,15 +30,17 @@ type decodeState struct { ...@@ -30,15 +30,17 @@ type decodeState struct {
dec *Decoder dec *Decoder
// The buffer is stored with an extra indirection because it may be replaced // The buffer is stored with an extra indirection because it may be replaced
// if we load a type during decode (when reading an interface value). // if we load a type during decode (when reading an interface value).
b **bytes.Buffer b *bytes.Buffer
fieldnum int // the last field number read. fieldnum int // the last field number read.
buf []byte buf []byte
} }
func newDecodeState(dec *Decoder, b **bytes.Buffer) *decodeState { // We pass the bytes.Buffer separately for easier testing of the infrastructure
// without requiring a full Decoder.
func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decodeState {
d := new(decodeState) d := new(decodeState)
d.dec = dec d.dec = dec
d.b = b d.b = buf
d.buf = make([]byte, uint64Size) d.buf = make([]byte, uint64Size)
return d return d
} }
...@@ -407,10 +409,10 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { ...@@ -407,10 +409,10 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
return *(*uintptr)(up) return *(*uintptr)(up)
} }
func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) (err os.Error) { func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr, indir int) (err os.Error) {
defer catchError(&err) defer catchError(&err)
p = allocate(rtyp, p, indir) p = allocate(rtyp, p, indir)
state := newDecodeState(dec, b) state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField state.fieldnum = singletonField
basep := p basep := p
delta := int(state.decodeUint()) delta := int(state.decodeUint())
...@@ -426,10 +428,10 @@ func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes ...@@ -426,10 +428,10 @@ func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes
return nil return nil
} }
func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) (err os.Error) { func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, p uintptr, indir int) (err os.Error) {
defer catchError(&err) defer catchError(&err)
p = allocate(rtyp, p, indir) p = allocate(rtyp, p, indir)
state := newDecodeState(dec, b) state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1 state.fieldnum = -1
basep := p basep := p
for state.b.Len() > 0 { for state.b.Len() > 0 {
...@@ -456,9 +458,9 @@ func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b ...@@ -456,9 +458,9 @@ func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b
return nil return nil
} }
func (dec *Decoder) ignoreStruct(engine *decEngine, b **bytes.Buffer) (err os.Error) { func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
defer catchError(&err) defer catchError(&err)
state := newDecodeState(dec, b) state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1 state.fieldnum = -1
for state.b.Len() > 0 { for state.b.Len() > 0 {
delta := int(state.decodeUint()) delta := int(state.decodeUint())
...@@ -614,9 +616,17 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt ...@@ -614,9 +616,17 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt
if !ok { if !ok {
errorf("gob: name not registered for interface: %q", name) errorf("gob: name not registered for interface: %q", name)
} }
// Read the type id of the concrete value.
concreteId := dec.decodeTypeSequence(true)
if concreteId < 0 {
error(dec.err)
}
// Byte count of value is next; we don't care what it is (it's there
// in case we want to ignore the value by skipping it completely).
state.decodeUint()
// Read the concrete value. // Read the concrete value.
value := reflect.MakeZero(typ) value := reflect.MakeZero(typ)
dec.decodeValueFromBuffer(value, false, true) dec.decodeValue(concreteId, value)
if dec.err != nil { if dec.err != nil {
error(dec.err) error(dec.err)
} }
...@@ -639,10 +649,12 @@ func (dec *Decoder) ignoreInterface(state *decodeState) { ...@@ -639,10 +649,12 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
if err != nil { if err != nil {
error(err) error(err)
} }
dec.decodeValueFromBuffer(nil, true, true) id := dec.decodeTypeSequence(true)
if dec.err != nil { if id < 0 {
error(err) error(dec.err)
} }
// At this point, the decoder buffer contains the value. Just toss it.
state.b.Reset()
} }
// Index by Go types. // Index by Go types.
...@@ -733,7 +745,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ...@@ -733,7 +745,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
} }
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs // indirect through enginePtr to delay evaluation for recursive structs
err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir) err = dec.decodeStruct(*enginePtr, t, uintptr(p), i.indir)
if err != nil { if err != nil {
error(err) error(err)
} }
...@@ -798,7 +810,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { ...@@ -798,7 +810,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
} }
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs // indirect through enginePtr to delay evaluation for recursive structs
state.dec.ignoreStruct(*enginePtr, state.b) state.dec.ignoreStruct(*enginePtr)
} }
} }
} }
...@@ -907,7 +919,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng ...@@ -907,7 +919,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
if t, ok := builtinIdToType[remoteId]; ok { if t, ok := builtinIdToType[remoteId]; ok {
wireStruct, _ = t.(*structType) wireStruct, _ = t.(*structType)
} else { } else {
wireStruct = dec.wireType[remoteId].StructT wire := dec.wireType[remoteId]
if wire == nil {
error(errBadType)
}
wireStruct = wire.StructT
} }
if wireStruct == nil { if wireStruct == nil {
errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String()) errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String())
...@@ -976,7 +992,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er ...@@ -976,7 +992,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
return return
} }
func (dec *Decoder) decode(wireId typeId, val reflect.Value) os.Error { func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error {
// Dereference down to the underlying struct type. // Dereference down to the underlying struct type.
rt, indir := indirect(val.Type()) rt, indir := indirect(val.Type())
enginePtr, err := dec.getDecEnginePtr(wireId, rt) enginePtr, err := dec.getDecEnginePtr(wireId, rt)
...@@ -989,9 +1005,9 @@ func (dec *Decoder) decode(wireId typeId, val reflect.Value) os.Error { ...@@ -989,9 +1005,9 @@ func (dec *Decoder) decode(wireId typeId, val reflect.Value) os.Error {
name := rt.Name() name := rt.Name()
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
} }
return dec.decodeStruct(engine, st, dec.state.b, uintptr(val.Addr()), indir) return dec.decodeStruct(engine, st, uintptr(val.Addr()), indir)
} }
return dec.decodeSingle(engine, rt, dec.state.b, uintptr(val.Addr()), indir) return dec.decodeSingle(engine, rt, uintptr(val.Addr()), indir)
} }
func init() { func init() {
......
...@@ -17,14 +17,13 @@ import ( ...@@ -17,14 +17,13 @@ import (
type Decoder struct { type Decoder struct {
mutex sync.Mutex // each item must be received atomically mutex sync.Mutex // each item must be received atomically
r io.Reader // source of the data r io.Reader // source of the data
buf bytes.Buffer // buffer for more efficient i/o from r
wireType map[typeId]*wireType // map from remote ID to local description wireType map[typeId]*wireType // map from remote ID to local description
decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
ignorerCache map[typeId]**decEngine // ditto for ignored objects ignorerCache map[typeId]**decEngine // ditto for ignored objects
state *decodeState // reads data from in-memory buffer
countState *decodeState // reads counts from wire countState *decodeState // reads counts from wire
buf []byte countBuf []byte // used for decoding integers while parsing messages
countBuf [9]byte // counts may be uint64s (unlikely!), require 9 bytes tmp []byte // temporary storage for i/o; saves reallocating
byteBuffer *bytes.Buffer
err os.Error err os.Error
} }
...@@ -33,116 +32,138 @@ func NewDecoder(r io.Reader) *Decoder { ...@@ -33,116 +32,138 @@ func NewDecoder(r io.Reader) *Decoder {
dec := new(Decoder) dec := new(Decoder)
dec.r = r dec.r = r
dec.wireType = make(map[typeId]*wireType) dec.wireType = make(map[typeId]*wireType)
dec.state = newDecodeState(dec, &dec.byteBuffer) // buffer set in Decode()
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine) dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
dec.ignorerCache = make(map[typeId]**decEngine) dec.ignorerCache = make(map[typeId]**decEngine)
dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes
return dec return dec
} }
// recvType loads the definition of a type and reloads the Decoder's buffer. // recvType loads the definition of a type.
func (dec *Decoder) recvType(id typeId) { func (dec *Decoder) recvType(id typeId) {
// Have we already seen this type? That's an error // Have we already seen this type? That's an error
if dec.wireType[id] != nil { if id < firstUserId || dec.wireType[id] != nil {
dec.err = os.ErrorString("gob: duplicate type received") dec.err = os.ErrorString("gob: duplicate type received")
return return
} }
// Type: // Type:
wire := new(wireType) wire := new(wireType)
dec.err = dec.decode(tWireType, reflect.NewValue(wire)) dec.err = dec.decodeValue(tWireType, reflect.NewValue(wire))
if dec.err != nil { if dec.err != nil {
return return
} }
// Remember we've seen this type. // Remember we've seen this type.
dec.wireType[id] = wire dec.wireType[id] = wire
// Load the next parcel.
dec.recvMessage()
}
// Decode reads the next value from the connection and stores
// it in the data represented by the empty interface value.
// The value underlying e must be the correct type for the next
// data item received, and must be a pointer.
func (dec *Decoder) Decode(e interface{}) os.Error {
value := reflect.NewValue(e)
// If e represents a value as opposed to a pointer, the answer won't
// get back to the caller. Make sure it's a pointer.
if value.Type().Kind() != reflect.Ptr {
dec.err = os.ErrorString("gob: attempt to decode into a non-pointer")
return dec.err
}
return dec.DecodeValue(value)
} }
// recvMessage reads the next count-delimited item from the input. It is the converse // recvMessage reads the next count-delimited item from the input. It is the converse
// of Encoder.writeMessage. // of Encoder.writeMessage. It returns false on EOF or other error reading the message.
func (dec *Decoder) recvMessage() { func (dec *Decoder) recvMessage() bool {
// Read a count. // Read a count.
var nbytes uint64 nbytes, _, err := decodeUintReader(dec.r, dec.countBuf)
nbytes, _, dec.err = decodeUintReader(dec.r, dec.countBuf[0:]) if err != nil {
if dec.err != nil { dec.err = err
return return false
} }
dec.readMessage(int(nbytes), dec.r) dec.readMessage(int(nbytes))
return dec.err == nil
} }
// readMessage reads the next nbytes bytes from the input. // readMessage reads the next nbytes bytes from the input.
func (dec *Decoder) readMessage(nbytes int, r io.Reader) { func (dec *Decoder) readMessage(nbytes int) {
// Allocate the buffer. // Allocate the buffer.
if nbytes > len(dec.buf) { if cap(dec.tmp) < nbytes {
dec.buf = make([]byte, nbytes+1000) dec.tmp = make([]byte, nbytes+100) // room to grow
} }
dec.byteBuffer = bytes.NewBuffer(dec.buf[0:nbytes]) dec.tmp = dec.tmp[:nbytes]
// Read the data // Read the data
_, dec.err = io.ReadFull(r, dec.buf[0:nbytes]) _, dec.err = io.ReadFull(dec.r, dec.tmp)
if dec.err != nil { if dec.err != nil {
if dec.err == os.EOF { if dec.err == os.EOF {
dec.err = io.ErrUnexpectedEOF dec.err = io.ErrUnexpectedEOF
} }
return return
} }
dec.buf.Write(dec.tmp)
} }
// decodeValueFromBuffer grabs the next value from the input. The Decoder's // toInt turns an encoded uint64 into an int, according to the marshaling rules.
// buffer already contains data. If the next item in the buffer is a type func toInt(x uint64) int64 {
// descriptor, it will be necessary to reload the buffer; recvType does that. i := int64(x >> 1)
func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignoreInterfaceValue, countPresent bool) { if x&1 != 0 {
for dec.state.b.Len() > 0 { i = ^i
// Receive a type id. }
id := typeId(dec.state.decodeInt()) return i
}
func (dec *Decoder) nextInt() int64 {
n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
if err != nil {
dec.err = err
}
return toInt(n)
}
func (dec *Decoder) nextUint() uint64 {
n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
if err != nil {
dec.err = err
}
return n
}
// Is it a new type? // decodeTypeSequence parses:
if id < 0 { // 0 is the error state, handled above // TypeSequence
// If the id is negative, we have a type. // (TypeDefinition DelimitedTypeDefinition*)?
dec.recvType(-id) // and returns the type id of the next value. It returns -1 at
if dec.err != nil { // EOF. Upon return, the remainder of dec.buf is the value to be
// decoded. If this is an interface value, it can be ignored by
// simply resetting that buffer.
func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
for dec.err == nil {
if dec.buf.Len() == 0 {
if !dec.recvMessage() {
break break
} }
continue
} }
// Receive a type id.
// Make sure the type has been defined already or is a builtin type (for id := typeId(dec.nextInt())
// top-level singleton values). if id >= 0 {
if dec.wireType[id] == nil && builtinIdToType[id] == nil { // Value follows.
dec.err = errBadType return id
break
} }
// An interface value is preceded by a byte count. // Type definition for (-id) follows.
if countPresent { dec.recvType(-id)
count := int(dec.state.decodeUint()) // When decoding an interface, after a type there may be a
if ignoreInterfaceValue { // DelimitedValue still in the buffer. Skip its count.
// An interface value is preceded by a byte count. Just skip that many bytes. // (Alternatively, the buffer is empty and the byte count
dec.state.b.Next(int(count)) // will be absorbed by recvMessage.)
if dec.buf.Len() > 0 {
if !isInterface {
dec.err = os.ErrorString("extra data in buffer")
break break
} }
// Otherwise fall through and decode it. dec.nextUint()
} }
dec.err = dec.decode(id, value)
break
} }
return -1
}
// Decode reads the next value from the connection and stores
// it in the data represented by the empty interface value.
// The value underlying e must be the correct type for the next
// data item received, and must be a pointer.
func (dec *Decoder) Decode(e interface{}) os.Error {
value := reflect.NewValue(e)
// If e represents a value as opposed to a pointer, the answer won't
// get back to the caller. Make sure it's a pointer.
if value.Type().Kind() != reflect.Ptr {
dec.err = os.ErrorString("gob: attempt to decode into a non-pointer")
return dec.err
}
return dec.DecodeValue(value)
} }
// DecodeValue reads the next value from the connection and stores // DecodeValue reads the next value from the connection and stores
...@@ -154,12 +175,12 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { ...@@ -154,12 +175,12 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
dec.mutex.Lock() dec.mutex.Lock()
defer dec.mutex.Unlock() defer dec.mutex.Unlock()
dec.buf.Reset() // In case data lingers from previous invocation.
dec.err = nil dec.err = nil
dec.recvMessage() id := dec.decodeTypeSequence(false)
if dec.err != nil { if id >= 0 {
return dec.err dec.err = dec.decodeValue(id, value)
} }
dec.decodeValueFromBuffer(value, false, false)
return dec.err return dec.err
} }
......
...@@ -395,17 +395,21 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) ...@@ -395,17 +395,21 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
if err != nil { if err != nil {
error(err) error(err)
} }
// Send (and maybe first define) the type id. // Define the type id if necessary.
enc.sendTypeDescriptor(typ) enc.sendTypeDescriptor(enc.writer(), state, typ)
// Encode the value into a new buffer. // Send the type id.
enc.sendTypeId(state, typ)
// Encode the value into a new buffer. Any nested type definitions
// should be written to b, before the encoded value.
enc.pushWriter(b)
data := new(bytes.Buffer) data := new(bytes.Buffer)
err = enc.encode(data, iv.Elem()) err = enc.encode(data, iv.Elem())
if err != nil { if err != nil {
error(err) error(err)
} }
state.encodeUint(uint64(data.Len())) enc.popWriter()
_, err = state.b.Write(data.Bytes()) enc.writeMessage(b, data)
if err != nil { if enc.err != nil {
error(err) error(err)
} }
} }
......
...@@ -16,9 +16,8 @@ import ( ...@@ -16,9 +16,8 @@ import (
// other side of a connection. // other side of a connection.
type Encoder struct { type Encoder struct {
mutex sync.Mutex // each item must be sent atomically mutex sync.Mutex // each item must be sent atomically
w io.Writer // where to send the data w []io.Writer // where to send the data
sent map[reflect.Type]typeId // which types we've already sent sent map[reflect.Type]typeId // which types we've already sent
state *encoderState // so we can encode integers, strings directly
countState *encoderState // stage for writing counts countState *encoderState // stage for writing counts
buf []byte // for collecting the output. buf []byte // for collecting the output.
err os.Error err os.Error
...@@ -27,13 +26,27 @@ type Encoder struct { ...@@ -27,13 +26,27 @@ type Encoder struct {
// NewEncoder returns a new encoder that will transmit on the io.Writer. // NewEncoder returns a new encoder that will transmit on the io.Writer.
func NewEncoder(w io.Writer) *Encoder { func NewEncoder(w io.Writer) *Encoder {
enc := new(Encoder) enc := new(Encoder)
enc.w = w enc.w = []io.Writer{w}
enc.sent = make(map[reflect.Type]typeId) enc.sent = make(map[reflect.Type]typeId)
enc.state = newEncoderState(enc, new(bytes.Buffer))
enc.countState = newEncoderState(enc, new(bytes.Buffer)) enc.countState = newEncoderState(enc, new(bytes.Buffer))
return enc return enc
} }
// writer() returns the innermost writer the encoder is using
func (enc *Encoder) writer() io.Writer {
return enc.w[len(enc.w)-1]
}
// pushWriter adds a writer to the encoder.
func (enc *Encoder) pushWriter(w io.Writer) {
enc.w = append(enc.w, w)
}
// popWriter pops the innermost writer.
func (enc *Encoder) popWriter() {
enc.w = enc.w[0 : len(enc.w)-1]
}
func (enc *Encoder) badType(rt reflect.Type) { func (enc *Encoder) badType(rt reflect.Type) {
enc.setError(os.ErrorString("gob: can't encode type " + rt.String())) enc.setError(os.ErrorString("gob: can't encode type " + rt.String()))
} }
...@@ -42,16 +55,14 @@ func (enc *Encoder) setError(err os.Error) { ...@@ -42,16 +55,14 @@ func (enc *Encoder) setError(err os.Error) {
if enc.err == nil { // remember the first. if enc.err == nil { // remember the first.
enc.err = err enc.err = err
} }
enc.state.b.Reset()
} }
// Send the data item preceded by a unsigned count of its length. // writeMessage sends the data item preceded by a unsigned count of its length.
func (enc *Encoder) send() { func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
// Encode the length. enc.countState.encodeUint(uint64(b.Len()))
enc.countState.encodeUint(uint64(enc.state.b.Len()))
// Build the buffer. // Build the buffer.
countLen := enc.countState.b.Len() countLen := enc.countState.b.Len()
total := countLen + enc.state.b.Len() total := countLen + b.Len()
if total > len(enc.buf) { if total > len(enc.buf) {
enc.buf = make([]byte, total+1000) // extra for growth enc.buf = make([]byte, total+1000) // extra for growth
} }
...@@ -59,15 +70,15 @@ func (enc *Encoder) send() { ...@@ -59,15 +70,15 @@ func (enc *Encoder) send() {
// TODO(r): avoid the extra copy here. // TODO(r): avoid the extra copy here.
enc.countState.b.Read(enc.buf[0:countLen]) enc.countState.b.Read(enc.buf[0:countLen])
// Now the data. // Now the data.
enc.state.b.Read(enc.buf[countLen:total]) b.Read(enc.buf[countLen:total])
// Write the data. // Write the data.
_, err := enc.w.Write(enc.buf[0:total]) _, err := w.Write(enc.buf[0:total])
if err != nil { if err != nil {
enc.setError(err) enc.setError(err)
} }
} }
func (enc *Encoder) sendType(origt reflect.Type) (sent bool) { func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
// Drill down to the base type. // Drill down to the base type.
rt, _ := indirect(origt) rt, _ := indirect(origt)
...@@ -112,10 +123,10 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) { ...@@ -112,10 +123,10 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
} }
// Send the pair (-id, type) // Send the pair (-id, type)
// Id: // Id:
enc.state.encodeInt(-int64(info.id)) state.encodeInt(-int64(info.id))
// Type: // Type:
enc.encode(enc.state.b, reflect.NewValue(info.wire)) enc.encode(state.b, reflect.NewValue(info.wire))
enc.send() enc.writeMessage(w, state.b)
if enc.err != nil { if enc.err != nil {
return return
} }
...@@ -128,10 +139,10 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) { ...@@ -128,10 +139,10 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
switch st := rt.(type) { switch st := rt.(type) {
case *reflect.StructType: case *reflect.StructType:
for i := 0; i < st.NumField(); i++ { for i := 0; i < st.NumField(); i++ {
enc.sendType(st.Field(i).Type) enc.sendType(w, state, st.Field(i).Type)
} }
case reflect.ArrayOrSliceType: case reflect.ArrayOrSliceType:
enc.sendType(st.Elem()) enc.sendType(w, state, st.Elem())
} }
return true return true
} }
...@@ -144,13 +155,13 @@ func (enc *Encoder) Encode(e interface{}) os.Error { ...@@ -144,13 +155,13 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
// sendTypeId makes sure the remote side knows about this type. // sendTypeId makes sure the remote side knows about this type.
// It will send a descriptor if this is the first time the type has been // It will send a descriptor if this is the first time the type has been
// sent. Regardless, it sends the id. // sent.
func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) { func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt reflect.Type) {
// Make sure the type is known to the other side. // Make sure the type is known to the other side.
// First, have we already sent this type? // First, have we already sent this type?
if _, alreadySent := enc.sent[rt]; !alreadySent { if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it. // No, so send it.
sent := enc.sendType(rt) sent := enc.sendType(w, state, rt)
if enc.err != nil { if enc.err != nil {
return return
} }
...@@ -168,9 +179,12 @@ func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) { ...@@ -168,9 +179,12 @@ func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) {
enc.sent[rt] = info.id enc.sent[rt] = info.id
} }
} }
}
// sendTypeId sends the id, which must have already been defined.
func (enc *Encoder) sendTypeId(state *encoderState, rt reflect.Type) {
// Identify the type of this top-level value. // Identify the type of this top-level value.
enc.state.encodeInt(int64(enc.sent[rt])) state.encodeInt(int64(enc.sent[rt]))
} }
// EncodeValue transmits the data item represented by the reflection value, // EncodeValue transmits the data item represented by the reflection value,
...@@ -181,26 +195,26 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { ...@@ -181,26 +195,26 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
enc.mutex.Lock() enc.mutex.Lock()
defer enc.mutex.Unlock() defer enc.mutex.Unlock()
// Remove any nested writers remaining due to previous errors.
enc.w = enc.w[0:1]
enc.err = nil enc.err = nil
rt, _ := indirect(value.Type()) rt, _ := indirect(value.Type())
// Sanity check only: encoder should never come in with data present. state := newEncoderState(enc, new(bytes.Buffer))
if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 {
enc.err = os.ErrorString("encoder: buffer not empty")
return enc.err
}
enc.sendTypeDescriptor(rt) enc.sendTypeDescriptor(enc.writer(), state, rt)
enc.sendTypeId(state, rt)
if enc.err != nil { if enc.err != nil {
return enc.err return enc.err
} }
// Encode the object. // Encode the object.
err := enc.encode(enc.state.b, value) err := enc.encode(state.b, value)
if err != nil { if err != nil {
enc.setError(err) enc.setError(err)
} else { } else {
enc.send() enc.writeMessage(enc.writer(), state.b)
} }
return enc.err return enc.err
......
...@@ -383,3 +383,47 @@ func TestInterfaceIndirect(t *testing.T) { ...@@ -383,3 +383,47 @@ func TestInterfaceIndirect(t *testing.T) {
t.Fatal("decode error:", err) t.Fatal("decode error:", err)
} }
} }
// Another bug from golang-nuts, involving nested interfaces.
type Bug0Outer struct {
Bug0Field interface{}
}
type Bug0Inner struct {
A int
}
func TestNestedInterfaces(t *testing.T) {
var buf bytes.Buffer
e := NewEncoder(&buf)
d := NewDecoder(&buf)
Register(new(Bug0Outer))
Register(new(Bug0Inner))
f := &Bug0Outer{&Bug0Outer{&Bug0Inner{7}}}
var v interface{} = f
err := e.Encode(&v)
if err != nil {
t.Fatal("Encode:", err)
}
Debug(bytes.NewBuffer(buf.Bytes()))
err = d.Decode(&v)
if err != nil {
t.Fatal("Decode:", err)
}
// Make sure it decoded correctly.
outer1, ok := v.(*Bug0Outer)
if !ok {
t.Fatalf("v not Bug0Outer: %T", v)
}
outer2, ok := outer1.Bug0Field.(*Bug0Outer)
if !ok {
t.Fatalf("v.Bug0Field not Bug0Outer: %T", outer1.Bug0Field)
}
inner, ok := outer2.Bug0Field.(*Bug0Inner)
if !ok {
t.Fatalf("v.Bug0Field.Bug0Field not Bug0Inner: %T", outer2.Bug0Field)
}
if inner.A != 7 {
t.Fatalf("final value %d; expected %d", inner.A, 7)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment