Commit c28fa513 authored by Rob Pike's avatar Rob Pike

gob: error cleanup 2

Simplify error handling during the compilation phase.

R=rsc
CC=golang-dev
https://golang.org/cl/2652042
parent f593b37f
...@@ -27,8 +27,7 @@ var ( ...@@ -27,8 +27,7 @@ var (
type decodeState struct { type decodeState struct {
// The buffer is stored with an extra indirection because it may be replaced // The buffer is stored with an extra indirection because it may be replaced
// if we load a type during decode (when reading an interface value). // if we load a type during decode (when reading an interface value).
b **bytes.Buffer b **bytes.Buffer
// err os.Error
fieldnum int // the last field number read. fieldnum int // the last field number read.
buf []byte buf []byte
} }
...@@ -77,14 +76,13 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, err os.Error) { ...@@ -77,14 +76,13 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, err os.Error) {
} }
// decodeUint reads an encoded unsigned integer from state.r. // decodeUint reads an encoded unsigned integer from state.r.
// Sets state.err. If state.err is already non-nil, it does nothing.
// Does not check for overflow. // Does not check for overflow.
func decodeUint(state *decodeState) (x uint64) { func decodeUint(state *decodeState) (x uint64) {
b, err := state.b.ReadByte() b, err := state.b.ReadByte()
if err != nil { if err != nil {
error(err) error(err)
} }
if b <= 0x7f { // includes state.err != nil if b <= 0x7f {
return uint64(b) return uint64(b)
} }
nb := -int(int8(b)) nb := -int(int8(b))
...@@ -105,7 +103,6 @@ func decodeUint(state *decodeState) (x uint64) { ...@@ -105,7 +103,6 @@ func decodeUint(state *decodeState) (x uint64) {
} }
// decodeInt reads an encoded signed integer from state.r. // decodeInt reads an encoded signed integer from state.r.
// Sets state.err. If state.err is already non-nil, it does nothing.
// Does not check for overflow. // Does not check for overflow.
func decodeInt(state *decodeState) int64 { func decodeInt(state *decodeState) int64 {
x := decodeUint(state) x := decodeUint(state)
...@@ -672,7 +669,7 @@ var decIgnoreOpMap = map[typeId]decOp{ ...@@ -672,7 +669,7 @@ var decIgnoreOpMap = map[typeId]decOp{
// Return the decoding op for the base type under rt and // Return the decoding op for the base type under rt and
// the indirection count to reach it. // the indirection count to reach it.
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) { func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) {
typ, indir := indirect(rt) typ, indir := indirect(rt)
var op decOp var op decOp
k := typ.Kind() k := typ.Kind()
...@@ -685,10 +682,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ...@@ -685,10 +682,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
case *reflect.ArrayType: case *reflect.ArrayType:
name = "element of " + name name = "element of " + name
elemId := dec.wireType[wireId].arrayT.Elem elemId := dec.wireType[wireId].arrayT.Elem
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
if err != nil {
return nil, 0, err
}
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
...@@ -698,14 +692,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ...@@ -698,14 +692,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
name = "element of " + name name = "element of " + name
keyId := dec.wireType[wireId].mapT.Key keyId := dec.wireType[wireId].mapT.Key
elemId := dec.wireType[wireId].mapT.Elem elemId := dec.wireType[wireId].mapT.Elem
keyOp, keyIndir, err := dec.decOpFor(keyId, t.Key(), name) keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name)
if err != nil { elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
return nil, 0, err
}
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
if err != nil {
return nil, 0, err
}
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
up := unsafe.Pointer(p) up := unsafe.Pointer(p)
...@@ -724,10 +712,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ...@@ -724,10 +712,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
} else { } else {
elemId = dec.wireType[wireId].sliceT.Elem elemId = dec.wireType[wireId].sliceT.Elem
} }
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
if err != nil {
return nil, 0, err
}
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
...@@ -737,7 +722,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ...@@ -737,7 +722,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
enginePtr, err := dec.getDecEnginePtr(wireId, typ) enginePtr, err := dec.getDecEnginePtr(wireId, typ)
if err != nil { if err != nil {
return nil, 0, err error(err)
} }
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs // indirect through enginePtr to delay evaluation for recursive structs
...@@ -753,13 +738,13 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ...@@ -753,13 +738,13 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
} }
} }
if op == nil { if op == nil {
return nil, 0, os.ErrorString("gob: decode can't handle type " + rt.String()) errorf("gob: decode can't handle type %s", rt.String())
} }
return op, indir, nil return op, indir
} }
// Return the decoding op for a field that has no destination. // Return the decoding op for a field that has no destination.
func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
op, ok := decIgnoreOpMap[wireId] op, ok := decIgnoreOpMap[wireId]
if !ok { if !ok {
if wireId == tInterface { if wireId == tInterface {
...@@ -768,7 +753,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { ...@@ -768,7 +753,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
dec.ignoreInterface(state) dec.ignoreInterface(state)
} }
return op, nil return op
} }
// Special cases // Special cases
wire := dec.wireType[wireId] wire := dec.wireType[wireId]
...@@ -777,10 +762,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { ...@@ -777,10 +762,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
panic("internal error: can't find ignore op for type " + wireId.string()) panic("internal error: can't find ignore op for type " + wireId.string())
case wire.arrayT != nil: case wire.arrayT != nil:
elemId := wire.arrayT.Elem elemId := wire.arrayT.Elem
elemOp, err := dec.decIgnoreOpFor(elemId) elemOp := dec.decIgnoreOpFor(elemId)
if err != nil {
return nil, err
}
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
ignoreArray(state, elemOp, wire.arrayT.Len) ignoreArray(state, elemOp, wire.arrayT.Len)
} }
...@@ -788,24 +770,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { ...@@ -788,24 +770,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
case wire.mapT != nil: case wire.mapT != nil:
keyId := dec.wireType[wireId].mapT.Key keyId := dec.wireType[wireId].mapT.Key
elemId := dec.wireType[wireId].mapT.Elem elemId := dec.wireType[wireId].mapT.Elem
keyOp, err := dec.decIgnoreOpFor(keyId) keyOp := dec.decIgnoreOpFor(keyId)
if err != nil { elemOp := dec.decIgnoreOpFor(elemId)
return nil, err
}
elemOp, err := dec.decIgnoreOpFor(elemId)
if err != nil {
return nil, err
}
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
ignoreMap(state, keyOp, elemOp) ignoreMap(state, keyOp, elemOp)
} }
case wire.sliceT != nil: case wire.sliceT != nil:
elemId := wire.sliceT.Elem elemId := wire.sliceT.Elem
elemOp, err := dec.decIgnoreOpFor(elemId) elemOp := dec.decIgnoreOpFor(elemId)
if err != nil {
return nil, err
}
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
ignoreSlice(state, elemOp) ignoreSlice(state, elemOp)
} }
...@@ -814,7 +787,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { ...@@ -814,7 +787,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
enginePtr, err := dec.getIgnoreEnginePtr(wireId) enginePtr, err := dec.getIgnoreEnginePtr(wireId)
if err != nil { if err != nil {
return nil, err error(err)
} }
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs // indirect through enginePtr to delay evaluation for recursive structs
...@@ -823,9 +796,9 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { ...@@ -823,9 +796,9 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
} }
} }
if op == nil { if op == nil {
return nil, os.ErrorString("ignore can't handle type " + wireId.string()) errorf("ignore can't handle type %s", wireId.string())
} }
return op, nil return op
} }
// Are these two gob Types compatible? // Are these two gob Types compatible?
...@@ -892,10 +865,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec ...@@ -892,10 +865,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
if !dec.compatibleType(rt, remoteId) { if !dec.compatibleType(rt, remoteId) {
return nil, os.ErrorString("gob: wrong type received for local value " + name) return nil, os.ErrorString("gob: wrong type received for local value " + name)
} }
op, indir, err := dec.decOpFor(remoteId, rt, name) op, indir := dec.decOpFor(remoteId, rt, name)
if err != nil {
return nil, err
}
ovfl := os.ErrorString(`value for "` + name + `" out of range`) ovfl := os.ErrorString(`value for "` + name + `" out of range`)
engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl} engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl}
engine.numInstr = 1 engine.numInstr = 1
...@@ -903,6 +873,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec ...@@ -903,6 +873,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
} }
func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
defer catchError(&err)
srt, ok := rt.(*reflect.StructType) srt, ok := rt.(*reflect.StructType)
if !ok { if !ok {
return dec.compileSingle(remoteId, rt) return dec.compileSingle(remoteId, rt)
...@@ -916,8 +887,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng ...@@ -916,8 +887,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
wireStruct = dec.wireType[remoteId].structT wireStruct = dec.wireType[remoteId].structT
} }
if wireStruct == nil { if wireStruct == nil {
return nil, os.ErrorString("gob: type mismatch in decoder: want struct type " + errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String())
rt.String() + "; got non-struct")
} }
engine = new(decEngine) engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.field)) engine.instr = make([]decInstr, len(wireStruct.field))
...@@ -929,22 +899,14 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng ...@@ -929,22 +899,14 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
ovfl := overflow(wireField.name) ovfl := overflow(wireField.name)
// TODO(r): anonymous names // TODO(r): anonymous names
if !present { if !present {
op, err := dec.decIgnoreOpFor(wireField.id) op := dec.decIgnoreOpFor(wireField.id)
if err != nil {
return nil, err
}
engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl} engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl}
continue continue
} }
if !dec.compatibleType(localField.Type, wireField.id) { if !dec.compatibleType(localField.Type, wireField.id) {
return nil, os.ErrorString("gob: wrong type (" + errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.name, wireField.name)
localField.Type.String() + ") for received field " +
wireStruct.name + "." + wireField.name)
}
op, indir, err := dec.decOpFor(wireField.id, localField.Type, localField.Name)
if err != nil {
return nil, err
} }
op, indir := dec.decOpFor(wireField.id, localField.Type, localField.Name)
engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl} engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl}
engine.numInstr++ engine.numInstr++
} }
......
...@@ -35,8 +35,7 @@ func newEncoderState(b *bytes.Buffer) *encoderState { ...@@ -35,8 +35,7 @@ func newEncoderState(b *bytes.Buffer) *encoderState {
// Otherwise the value is written in big-endian byte order preceded // Otherwise the value is written in big-endian byte order preceded
// by the byte length, negated. // by the byte length, negated.
// encodeUint writes an encoded unsigned integer to state.b. Sets state.err. // encodeUint writes an encoded unsigned integer to state.b.
// If state.err is already non-nil, it does nothing.
func encodeUint(state *encoderState, x uint64) { func encodeUint(state *encoderState, x uint64) {
if x <= 0x7F { if x <= 0x7F {
err := state.b.WriteByte(uint8(x)) err := state.b.WriteByte(uint8(x))
...@@ -60,8 +59,8 @@ func encodeUint(state *encoderState, x uint64) { ...@@ -60,8 +59,8 @@ func encodeUint(state *encoderState, x uint64) {
} }
// encodeInt writes an encoded signed integer to state.w. // encodeInt writes an encoded signed integer to state.w.
// The low bit of the encoding says whether to bit complement the (other bits of the) uint to recover the int. // The low bit of the encoding says whether to bit complement the (other bits of the)
// Sets state.err. If state.err is already non-nil, it does nothing. // uint to recover the int.
func encodeInt(state *encoderState, i int64) { func encodeInt(state *encoderState, i int64) {
var x uint64 var x uint64
if i < 0 { if i < 0 {
...@@ -319,8 +318,7 @@ type encEngine struct { ...@@ -319,8 +318,7 @@ type encEngine struct {
const singletonField = 0 const singletonField = 0
func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Error) { func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) {
defer catchError(&err)
state := newEncoderState(b) state := newEncoderState(b)
state.fieldnum = singletonField state.fieldnum = singletonField
// There is no surrounding struct to frame the transmission, so we must // There is no surrounding struct to frame the transmission, so we must
...@@ -330,15 +328,13 @@ func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Err ...@@ -330,15 +328,13 @@ func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Err
p := unsafe.Pointer(basep) // offset will be zero p := unsafe.Pointer(basep) // offset will be zero
if instr.indir > 0 { if instr.indir > 0 {
if p = encIndirect(p, instr.indir); p == nil { if p = encIndirect(p, instr.indir); p == nil {
return nil return
} }
} }
instr.op(instr, state, p) instr.op(instr, state, p)
return
} }
func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Error) { func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) {
defer catchError(&err)
state := newEncoderState(b) state := newEncoderState(b)
state.fieldnum = -1 state.fieldnum = -1
for i := 0; i < len(engine.instr); i++ { for i := 0; i < len(engine.instr); i++ {
...@@ -351,7 +347,6 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Err ...@@ -351,7 +347,6 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Err
} }
instr.op(instr, state, p) instr.op(instr, state, p)
} }
return nil
} }
func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) { func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) {
...@@ -452,7 +447,7 @@ var encOpMap = []encOp{ ...@@ -452,7 +447,7 @@ var encOpMap = []encOp{
// Return the encoding op for the base type under rt and // Return the encoding op for the base type under rt and
// the indirection count to reach it. // the indirection count to reach it.
func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
typ, indir := indirect(rt) typ, indir := indirect(rt)
var op encOp var op encOp
k := typ.Kind() k := typ.Kind()
...@@ -468,10 +463,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { ...@@ -468,10 +463,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
break break
} }
// Slices have a header; we decode it to find the underlying array. // Slices have a header; we decode it to find the underlying array.
elemOp, indir, err := enc.encOpFor(t.Elem()) elemOp, indir := enc.encOpFor(t.Elem())
if err != nil {
return nil, 0, err
}
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
slice := (*reflect.SliceHeader)(p) slice := (*reflect.SliceHeader)(p)
if slice.Len == 0 { if slice.Len == 0 {
...@@ -482,23 +474,14 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { ...@@ -482,23 +474,14 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
} }
case *reflect.ArrayType: case *reflect.ArrayType:
// True arrays have size in the type. // True arrays have size in the type.
elemOp, indir, err := enc.encOpFor(t.Elem()) elemOp, indir := enc.encOpFor(t.Elem())
if err != nil {
return nil, 0, err
}
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i) state.update(i)
encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len()) encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
} }
case *reflect.MapType: case *reflect.MapType:
keyOp, keyIndir, err := enc.encOpFor(t.Key()) keyOp, keyIndir := enc.encOpFor(t.Key())
if err != nil { elemOp, elemIndir := enc.encOpFor(t.Elem())
return nil, 0, err
}
elemOp, elemIndir, err := enc.encOpFor(t.Elem())
if err != nil {
return nil, 0, err
}
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
// Maps cannot be accessed by moving addresses around the way // Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for // that slices etc. can. We must recover a full reflection value for
...@@ -513,10 +496,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { ...@@ -513,10 +496,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
} }
case *reflect.StructType: case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
_, err := enc.getEncEngine(typ) enc.getEncEngine(typ)
if err != nil {
return nil, 0, err
}
info := mustGetTypeInfo(typ) info := mustGetTypeInfo(typ)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i) state.update(i)
...@@ -538,66 +518,65 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { ...@@ -538,66 +518,65 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
} }
} }
if op == nil { if op == nil {
return op, indir, os.ErrorString("gob enc: can't happen: encode type " + rt.String()) errorf("gob enc: can't happen: encode type %s", rt.String())
} }
return op, indir, nil return op, indir
} }
// The local Type was compiled from the actual value, so we know it's compatible. // The local Type was compiled from the actual value, so we know it's compatible.
func (enc *Encoder) compileEnc(rt reflect.Type) (*encEngine, os.Error) { func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
srt, isStruct := rt.(*reflect.StructType) srt, isStruct := rt.(*reflect.StructType)
engine := new(encEngine) engine := new(encEngine)
if isStruct { if isStruct {
engine.instr = make([]encInstr, srt.NumField()+1) // +1 for terminator engine.instr = make([]encInstr, srt.NumField()+1) // +1 for terminator
for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ { for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ {
f := srt.Field(fieldnum) f := srt.Field(fieldnum)
op, indir, err := enc.encOpFor(f.Type) op, indir := enc.encOpFor(f.Type)
if err != nil {
return nil, err
}
engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)} engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)}
} }
engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0} engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0}
} else { } else {
engine.instr = make([]encInstr, 1) engine.instr = make([]encInstr, 1)
op, indir, err := enc.encOpFor(rt) op, indir := enc.encOpFor(rt)
if err != nil {
return nil, err
}
engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero
} }
return engine, nil return engine
} }
// typeLock must be held (or we're in initialization and guaranteed single-threaded). // typeLock must be held (or we're in initialization and guaranteed single-threaded).
// The reflection type must have all its indirections processed out. // The reflection type must have all its indirections processed out.
func (enc *Encoder) getEncEngine(rt reflect.Type) (*encEngine, os.Error) { func (enc *Encoder) getEncEngine(rt reflect.Type) *encEngine {
info, err := getTypeInfo(rt) info, err1 := getTypeInfo(rt)
if err != nil { if err1 != nil {
return nil, err error(err1)
} }
if info.encoder == nil { if info.encoder == nil {
// mark this engine as underway before compiling to handle recursive types. // mark this engine as underway before compiling to handle recursive types.
info.encoder = new(encEngine) info.encoder = new(encEngine)
info.encoder, err = enc.compileEnc(rt) info.encoder = enc.compileEnc(rt)
} }
return info.encoder, err return info.encoder
} }
func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value) os.Error { // Put this in a function so we can hold the lock only while compiling, not when encoding.
func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine {
typeLock.Lock()
defer typeLock.Unlock()
return enc.getEncEngine(rt)
}
func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value) (err os.Error) {
defer catchError(&err)
// Dereference down to the underlying object. // Dereference down to the underlying object.
rt, indir := indirect(value.Type()) rt, indir := indirect(value.Type())
for i := 0; i < indir; i++ { for i := 0; i < indir; i++ {
value = reflect.Indirect(value) value = reflect.Indirect(value)
} }
typeLock.Lock() engine := enc.lockAndGetEncEngine(rt)
engine, err := enc.getEncEngine(rt)
typeLock.Unlock()
if err != nil {
return err
}
if value.Type().Kind() == reflect.Struct { if value.Type().Kind() == reflect.Struct {
return encodeStruct(engine, b, value.Addr()) encodeStruct(engine, b, value.Addr())
} else {
encodeSingle(engine, b, value.Addr())
} }
return encodeSingle(engine, b, value.Addr()) return nil
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment