Commit 30b1b9a3 authored by Rob Pike's avatar Rob Pike

Rework gobs to fix bad bug related to sharing of id's between encoder and decoder side.

Fix is to move all decoder state into the decoder object.

Fixes #215.

R=rsc
CC=golang-dev
https://golang.org/cl/155077
parent 50c04132
...@@ -37,7 +37,6 @@ var encodeT = []EncodeT{ ...@@ -37,7 +37,6 @@ var encodeT = []EncodeT{
EncodeT{1 << 63, []byte{0xF8, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, EncodeT{1 << 63, []byte{0xF8, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
} }
// Test basic encode/decode routines for unsigned integers // Test basic encode/decode routines for unsigned integers
func TestUintCodec(t *testing.T) { func TestUintCodec(t *testing.T) {
b := new(bytes.Buffer); b := new(bytes.Buffer);
...@@ -592,9 +591,15 @@ func TestEndToEnd(t *testing.T) { ...@@ -592,9 +591,15 @@ func TestEndToEnd(t *testing.T) {
t: &T2{"this is T2"}, t: &T2{"this is T2"},
}; };
b := new(bytes.Buffer); b := new(bytes.Buffer);
encode(b, t1); err := NewEncoder(b).Encode(t1);
if err != nil {
t.Error("encode:", err)
}
var _t1 T1; var _t1 T1;
decode(b, getTypeInfoNoError(reflect.Typeof(_t1)).id, &_t1); err = NewDecoder(b).Decode(&_t1);
if err != nil {
t.Fatal("decode:", err)
}
if !reflect.DeepEqual(t1, &_t1) { if !reflect.DeepEqual(t1, &_t1) {
t.Errorf("encode expected %v got %v", *t1, _t1) t.Errorf("encode expected %v got %v", *t1, _t1)
} }
...@@ -610,8 +615,9 @@ func TestOverflow(t *testing.T) { ...@@ -610,8 +615,9 @@ func TestOverflow(t *testing.T) {
} }
var it inputT; var it inputT;
var err os.Error; var err os.Error;
id := getTypeInfoNoError(reflect.Typeof(it)).id;
b := new(bytes.Buffer); b := new(bytes.Buffer);
enc := NewEncoder(b);
dec := NewDecoder(b);
// int8 // int8
b.Reset(); b.Reset();
...@@ -623,8 +629,8 @@ func TestOverflow(t *testing.T) { ...@@ -623,8 +629,8 @@ func TestOverflow(t *testing.T) {
mini int8; mini int8;
} }
var o1 outi8; var o1 outi8;
encode(b, it); enc.Encode(it);
err = decode(b, id, &o1); err = dec.Decode(&o1);
if err == nil || err.String() != `value for "maxi" out of range` { if err == nil || err.String() != `value for "maxi" out of range` {
t.Error("wrong overflow error for int8:", err) t.Error("wrong overflow error for int8:", err)
} }
...@@ -632,8 +638,8 @@ func TestOverflow(t *testing.T) { ...@@ -632,8 +638,8 @@ func TestOverflow(t *testing.T) {
mini: math.MinInt8 - 1, mini: math.MinInt8 - 1,
}; };
b.Reset(); b.Reset();
encode(b, it); enc.Encode(it);
err = decode(b, id, &o1); err = dec.Decode(&o1);
if err == nil || err.String() != `value for "mini" out of range` { if err == nil || err.String() != `value for "mini" out of range` {
t.Error("wrong underflow error for int8:", err) t.Error("wrong underflow error for int8:", err)
} }
...@@ -648,8 +654,8 @@ func TestOverflow(t *testing.T) { ...@@ -648,8 +654,8 @@ func TestOverflow(t *testing.T) {
mini int16; mini int16;
} }
var o2 outi16; var o2 outi16;
encode(b, it); enc.Encode(it);
err = decode(b, id, &o2); err = dec.Decode(&o2);
if err == nil || err.String() != `value for "maxi" out of range` { if err == nil || err.String() != `value for "maxi" out of range` {
t.Error("wrong overflow error for int16:", err) t.Error("wrong overflow error for int16:", err)
} }
...@@ -657,8 +663,8 @@ func TestOverflow(t *testing.T) { ...@@ -657,8 +663,8 @@ func TestOverflow(t *testing.T) {
mini: math.MinInt16 - 1, mini: math.MinInt16 - 1,
}; };
b.Reset(); b.Reset();
encode(b, it); enc.Encode(it);
err = decode(b, id, &o2); err = dec.Decode(&o2);
if err == nil || err.String() != `value for "mini" out of range` { if err == nil || err.String() != `value for "mini" out of range` {
t.Error("wrong underflow error for int16:", err) t.Error("wrong underflow error for int16:", err)
} }
...@@ -673,8 +679,8 @@ func TestOverflow(t *testing.T) { ...@@ -673,8 +679,8 @@ func TestOverflow(t *testing.T) {
mini int32; mini int32;
} }
var o3 outi32; var o3 outi32;
encode(b, it); enc.Encode(it);
err = decode(b, id, &o3); err = dec.Decode(&o3);
if err == nil || err.String() != `value for "maxi" out of range` { if err == nil || err.String() != `value for "maxi" out of range` {
t.Error("wrong overflow error for int32:", err) t.Error("wrong overflow error for int32:", err)
} }
...@@ -682,8 +688,8 @@ func TestOverflow(t *testing.T) { ...@@ -682,8 +688,8 @@ func TestOverflow(t *testing.T) {
mini: math.MinInt32 - 1, mini: math.MinInt32 - 1,
}; };
b.Reset(); b.Reset();
encode(b, it); enc.Encode(it);
err = decode(b, id, &o3); err = dec.Decode(&o3);
if err == nil || err.String() != `value for "mini" out of range` { if err == nil || err.String() != `value for "mini" out of range` {
t.Error("wrong underflow error for int32:", err) t.Error("wrong underflow error for int32:", err)
} }
...@@ -697,8 +703,8 @@ func TestOverflow(t *testing.T) { ...@@ -697,8 +703,8 @@ func TestOverflow(t *testing.T) {
maxu uint8; maxu uint8;
} }
var o4 outu8; var o4 outu8;
encode(b, it); enc.Encode(it);
err = decode(b, id, &o4); err = dec.Decode(&o4);
if err == nil || err.String() != `value for "maxu" out of range` { if err == nil || err.String() != `value for "maxu" out of range` {
t.Error("wrong overflow error for uint8:", err) t.Error("wrong overflow error for uint8:", err)
} }
...@@ -712,8 +718,8 @@ func TestOverflow(t *testing.T) { ...@@ -712,8 +718,8 @@ func TestOverflow(t *testing.T) {
maxu uint16; maxu uint16;
} }
var o5 outu16; var o5 outu16;
encode(b, it); enc.Encode(it);
err = decode(b, id, &o5); err = dec.Decode(&o5);
if err == nil || err.String() != `value for "maxu" out of range` { if err == nil || err.String() != `value for "maxu" out of range` {
t.Error("wrong overflow error for uint16:", err) t.Error("wrong overflow error for uint16:", err)
} }
...@@ -727,8 +733,8 @@ func TestOverflow(t *testing.T) { ...@@ -727,8 +733,8 @@ func TestOverflow(t *testing.T) {
maxu uint32; maxu uint32;
} }
var o6 outu32; var o6 outu32;
encode(b, it); enc.Encode(it);
err = decode(b, id, &o6); err = dec.Decode(&o6);
if err == nil || err.String() != `value for "maxu" out of range` { if err == nil || err.String() != `value for "maxu" out of range` {
t.Error("wrong overflow error for uint32:", err) t.Error("wrong overflow error for uint32:", err)
} }
...@@ -743,8 +749,8 @@ func TestOverflow(t *testing.T) { ...@@ -743,8 +749,8 @@ func TestOverflow(t *testing.T) {
minf float32; minf float32;
} }
var o7 outf32; var o7 outf32;
encode(b, it); enc.Encode(it);
err = decode(b, id, &o7); err = dec.Decode(&o7);
if err == nil || err.String() != `value for "maxf" out of range` { if err == nil || err.String() != `value for "maxf" out of range` {
t.Error("wrong overflow error for float32:", err) t.Error("wrong overflow error for float32:", err)
} }
...@@ -761,9 +767,13 @@ func TestNesting(t *testing.T) { ...@@ -761,9 +767,13 @@ func TestNesting(t *testing.T) {
rt.next = new(RT); rt.next = new(RT);
rt.next.a = "level2"; rt.next.a = "level2";
b := new(bytes.Buffer); b := new(bytes.Buffer);
encode(b, rt); NewEncoder(b).Encode(rt);
var drt RT; var drt RT;
decode(b, getTypeInfoNoError(reflect.Typeof(drt)).id, &drt); dec := NewDecoder(b);
err := dec.Decode(&drt);
if err != nil {
t.Errorf("decoder error:", err)
}
if drt.a != rt.a { if drt.a != rt.a {
t.Errorf("nesting: encode expected %v got %v", *rt, drt) t.Errorf("nesting: encode expected %v got %v", *rt, drt)
} }
...@@ -809,10 +819,11 @@ func TestAutoIndirection(t *testing.T) { ...@@ -809,10 +819,11 @@ func TestAutoIndirection(t *testing.T) {
**t1.d = new(int); **t1.d = new(int);
***t1.d = 17777; ***t1.d = 17777;
b := new(bytes.Buffer); b := new(bytes.Buffer);
encode(b, t1); enc := NewEncoder(b);
enc.Encode(t1);
dec := NewDecoder(b);
var t0 T0; var t0 T0;
t0Id := getTypeInfoNoError(reflect.Typeof(t0)).id; dec.Decode(&t0);
decode(b, t0Id, &t0);
if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 { if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0) t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0)
} }
...@@ -830,9 +841,9 @@ func TestAutoIndirection(t *testing.T) { ...@@ -830,9 +841,9 @@ func TestAutoIndirection(t *testing.T) {
**t2.a = new(int); **t2.a = new(int);
***t2.a = 17; ***t2.a = 17;
b.Reset(); b.Reset();
encode(b, t2); enc.Encode(t2);
t0 = T0{}; t0 = T0{};
decode(b, t0Id, &t0); dec.Decode(&t0);
if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 { if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
t.Errorf("t2->t0 expected {17 177 1777 17777}; got %v", t0) t.Errorf("t2->t0 expected {17 177 1777 17777}; got %v", t0)
} }
...@@ -840,32 +851,30 @@ func TestAutoIndirection(t *testing.T) { ...@@ -840,32 +851,30 @@ func TestAutoIndirection(t *testing.T) {
// Now transfer t0 into t1 // Now transfer t0 into t1
t0 = T0{17, 177, 1777, 17777}; t0 = T0{17, 177, 1777, 17777};
b.Reset(); b.Reset();
encode(b, t0); enc.Encode(t0);
t1 = T1{}; t1 = T1{};
t1Id := getTypeInfoNoError(reflect.Typeof(t1)).id; dec.Decode(&t1);
decode(b, t1Id, &t1);
if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 { if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 {
t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d) t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d)
} }
// Now transfer t0 into t2 // Now transfer t0 into t2
b.Reset(); b.Reset();
encode(b, t0); enc.Encode(t0);
t2 = T2{}; t2 = T2{};
t2Id := getTypeInfoNoError(reflect.Typeof(t2)).id; dec.Decode(&t2);
decode(b, t2Id, &t2);
if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 { if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d) t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d)
} }
// Now do t2 again but without pre-allocated pointers. // Now do t2 again but without pre-allocated pointers.
b.Reset(); b.Reset();
encode(b, t0); enc.Encode(t0);
***t2.a = 0; ***t2.a = 0;
**t2.b = 0; **t2.b = 0;
*t2.c = 0; *t2.c = 0;
t2.d = 0; t2.d = 0;
decode(b, t2Id, &t2); dec.Decode(&t2);
if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 { if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d) t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d)
} }
...@@ -889,11 +898,14 @@ func TestReorderedFields(t *testing.T) { ...@@ -889,11 +898,14 @@ func TestReorderedFields(t *testing.T) {
rt0.b = "hello"; rt0.b = "hello";
rt0.c = 3.14159; rt0.c = 3.14159;
b := new(bytes.Buffer); b := new(bytes.Buffer);
encode(b, rt0); NewEncoder(b).Encode(rt0);
rt0Id := getTypeInfoNoError(reflect.Typeof(rt0)).id; dec := NewDecoder(b);
var rt1 RT1; var rt1 RT1;
// Wire type is RT0, local type is RT1. // Wire type is RT0, local type is RT1.
decode(b, rt0Id, &rt1); err := dec.Decode(&rt1);
if err != nil {
t.Error("decode error:", err)
}
if rt0.a != rt1.a || rt0.b != rt1.b || rt0.c != rt1.c { if rt0.a != rt1.a || rt0.b != rt1.b || rt0.c != rt1.c {
t.Errorf("rt1->rt0: expected %v; got %v", rt0, rt1) t.Errorf("rt1->rt0: expected %v; got %v", rt0, rt1)
} }
...@@ -927,11 +939,11 @@ func TestIgnoredFields(t *testing.T) { ...@@ -927,11 +939,11 @@ func TestIgnoredFields(t *testing.T) {
it0.ignore_i = &RT1{3.1, "hi", 7, "hello"}; it0.ignore_i = &RT1{3.1, "hi", 7, "hello"};
b := new(bytes.Buffer); b := new(bytes.Buffer);
encode(b, it0); NewEncoder(b).Encode(it0);
rt0Id := getTypeInfoNoError(reflect.Typeof(it0)).id; dec := NewDecoder(b);
var rt1 RT1; var rt1 RT1;
// Wire type is IT0, local type is RT1. // Wire type is IT0, local type is RT1.
err := decode(b, rt0Id, &rt1); err := dec.Decode(&rt1);
if err != nil { if err != nil {
t.Error("error: ", err) t.Error("error: ", err)
} }
......
...@@ -348,7 +348,7 @@ func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { ...@@ -348,7 +348,7 @@ func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
// Execution engine // Execution engine
// The encoder engine is an array of instructions indexed by field number of the incoming // The encoder engine is an array of instructions indexed by field number of the incoming
// data. It is executed with random access according to field number. // decoder. It is executed with random access according to field number.
type decEngine struct { type decEngine struct {
instr []decInstr; instr []decInstr;
numInstr int; // the number of active instructions numInstr int; // the number of active instructions
...@@ -515,7 +515,7 @@ var decIgnoreOpMap = map[typeId]decOp{ ...@@ -515,7 +515,7 @@ var decIgnoreOpMap = map[typeId]decOp{
// Return the decoding op for the base type under rt and // Return the decoding op for the base type under rt and
// the indirection count to reach it. // the indirection count to reach it.
func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) { func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) {
typ, indir := indirect(rt); typ, indir := indirect(rt);
op, ok := decOpMap[reflect.Typeof(typ)]; op, ok := decOpMap[reflect.Typeof(typ)];
if !ok { if !ok {
...@@ -528,7 +528,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error ...@@ -528,7 +528,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
break; break;
} }
elemId := wireId.gobType().(*sliceType).Elem; elemId := wireId.gobType().(*sliceType).Elem;
elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name); elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name);
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
...@@ -540,7 +540,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error ...@@ -540,7 +540,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
case *reflect.ArrayType: case *reflect.ArrayType:
name = "element of " + name; name = "element of " + name;
elemId := wireId.gobType().(*arrayType).Elem; elemId := wireId.gobType().(*arrayType).Elem;
elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name); elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name);
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
...@@ -551,7 +551,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error ...@@ -551,7 +551,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
case *reflect.StructType: case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
enginePtr, err := getDecEnginePtr(wireId, typ); enginePtr, err := dec.getDecEnginePtr(wireId, typ);
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
...@@ -568,14 +568,14 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error ...@@ -568,14 +568,14 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
} }
// Return the decoding op for a field that has no destination. // Return the decoding op for a field that has no destination.
func decIgnoreOpFor(wireId typeId) (decOp, os.Error) { func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
op, ok := decIgnoreOpMap[wireId]; op, ok := decIgnoreOpMap[wireId];
if !ok { if !ok {
// Special cases // Special cases
switch t := wireId.gobType().(type) { switch t := wireId.gobType().(type) {
case *sliceType: case *sliceType:
elemId := wireId.gobType().(*sliceType).Elem; elemId := wireId.gobType().(*sliceType).Elem;
elemOp, err := decIgnoreOpFor(elemId); elemOp, err := dec.decIgnoreOpFor(elemId);
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -585,7 +585,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) { ...@@ -585,7 +585,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
case *arrayType: case *arrayType:
elemId := wireId.gobType().(*arrayType).Elem; elemId := wireId.gobType().(*arrayType).Elem;
elemOp, err := decIgnoreOpFor(elemId); elemOp, err := dec.decIgnoreOpFor(elemId);
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -595,7 +595,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) { ...@@ -595,7 +595,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
case *structType: case *structType:
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
enginePtr, err := getIgnoreEnginePtr(wireId); enginePtr, err := dec.getIgnoreEnginePtr(wireId);
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -676,11 +676,18 @@ func compatibleType(fr reflect.Type, fw typeId) bool { ...@@ -676,11 +676,18 @@ func compatibleType(fr reflect.Type, fw typeId) bool {
return true; return true;
} }
func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
srt, ok1 := rt.(*reflect.StructType); srt, ok1 := rt.(*reflect.StructType);
wireStruct, ok2 := wireId.gobType().(*structType); var wireStruct *structType;
if !ok1 || !ok2 { // Builtin types can come from global pool; the rest must be defined by the decoder
return nil, errNotStruct if t, ok := builtinIdToType[remoteId]; ok {
wireStruct = t.(*structType)
} else {
w, ok2 := dec.wireType[remoteId];
if !ok1 || !ok2 {
return nil, errNotStruct
}
wireStruct = w.s;
} }
engine = new(decEngine); engine = new(decEngine);
engine.instr = make([]decInstr, len(wireStruct.field)); engine.instr = make([]decInstr, len(wireStruct.field));
...@@ -692,7 +699,7 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error ...@@ -692,7 +699,7 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
ovfl := overflow(wireField.name); ovfl := overflow(wireField.name);
// TODO(r): anonymous names // TODO(r): anonymous names
if !present { if !present {
op, err := decIgnoreOpFor(wireField.id); op, err := dec.decIgnoreOpFor(wireField.id);
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -700,10 +707,10 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error ...@@ -700,10 +707,10 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
continue; continue;
} }
if !compatibleType(localField.Type, wireField.id) { if !compatibleType(localField.Type, wireField.id) {
details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + wireId.Name(); details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + remoteId.Name();
return nil, os.ErrorString("gob: wrong type for field " + wireField.name + details); return nil, os.ErrorString("gob: wrong type for field " + wireField.name + details);
} }
op, indir, err := decOpFor(wireField.id, localField.Type, localField.Name); op, indir, err := dec.decOpFor(wireField.id, localField.Type, localField.Name);
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -713,23 +720,19 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error ...@@ -713,23 +720,19 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
return; return;
} }
var decoderCache = make(map[reflect.Type]map[typeId]**decEngine) func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
var ignorerCache = make(map[typeId]**decEngine) decoderMap, ok := dec.decoderCache[rt];
// typeLock must be held.
func getDecEnginePtr(wireId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
decoderMap, ok := decoderCache[rt];
if !ok { if !ok {
decoderMap = make(map[typeId]**decEngine); decoderMap = make(map[typeId]**decEngine);
decoderCache[rt] = decoderMap; dec.decoderCache[rt] = decoderMap;
} }
if enginePtr, ok = decoderMap[wireId]; !ok { if enginePtr, ok = decoderMap[remoteId]; !ok {
// To handle recursive types, mark this engine as underway before compiling. // To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine); enginePtr = new(*decEngine);
decoderMap[wireId] = enginePtr; decoderMap[remoteId] = enginePtr;
*enginePtr, err = compileDec(wireId, rt); *enginePtr, err = dec.compileDec(remoteId, rt);
if err != nil { if err != nil {
decoderMap[wireId] = nil, false decoderMap[remoteId] = nil, false
} }
} }
return; return;
...@@ -740,35 +743,28 @@ type emptyStruct struct{} ...@@ -740,35 +743,28 @@ type emptyStruct struct{}
var emptyStructType = reflect.Typeof(emptyStruct{}) var emptyStructType = reflect.Typeof(emptyStruct{})
// typeLock must be held. func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
func getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
var ok bool; var ok bool;
if enginePtr, ok = ignorerCache[wireId]; !ok { if enginePtr, ok = dec.ignorerCache[wireId]; !ok {
// To handle recursive types, mark this engine as underway before compiling. // To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine); enginePtr = new(*decEngine);
ignorerCache[wireId] = enginePtr; dec.ignorerCache[wireId] = enginePtr;
*enginePtr, err = compileDec(wireId, emptyStructType); *enginePtr, err = dec.compileDec(wireId, emptyStructType);
if err != nil { if err != nil {
ignorerCache[wireId] = nil, false dec.ignorerCache[wireId] = nil, false
} }
} }
return; return;
} }
func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error { func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
// Dereference down to the underlying struct type. // Dereference down to the underlying struct type.
rt, indir := indirect(reflect.Typeof(e)); rt, indir := indirect(reflect.Typeof(e));
st, ok := rt.(*reflect.StructType); st, ok := rt.(*reflect.StructType);
if !ok { if !ok {
return os.ErrorString("gob: decode can't handle " + rt.String()) return os.ErrorString("gob: decode can't handle " + rt.String())
} }
typeLock.Lock(); enginePtr, err := dec.getDecEnginePtr(wireId, rt);
if _, ok := idToType[wireId]; !ok {
typeLock.Unlock();
return errBadType;
}
enginePtr, err := getDecEnginePtr(wireId, rt);
typeLock.Unlock();
if err != nil { if err != nil {
return err return err
} }
...@@ -777,7 +773,7 @@ func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error { ...@@ -777,7 +773,7 @@ func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
name := rt.Name(); name := rt.Name();
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name); return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name);
} }
return decodeStruct(engine, st, b, uintptr(reflect.NewValue(e).Addr()), indir); return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir);
} }
func init() { func init() {
......
...@@ -8,17 +8,20 @@ import ( ...@@ -8,17 +8,20 @@ import (
"bytes"; "bytes";
"io"; "io";
"os"; "os";
"reflect";
"sync"; "sync";
) )
// A Decoder manages the receipt of type and data information read from the // A Decoder manages the receipt of type and data information read from the
// remote side of a connection. // remote side of a connection.
type Decoder struct { type Decoder struct {
mutex sync.Mutex; // each item must be received atomically mutex sync.Mutex; // each item must be received atomically
r io.Reader; // source of the data r io.Reader; // source of the data
seen map[typeId]*wireType; // which types we've already seen described wireType map[typeId]*wireType; // map from remote ID to local description
state *decodeState; // reads data from in-memory buffer decoderCache map[reflect.Type]map[typeId]**decEngine; // cache of compiled engines
countState *decodeState; // reads counts from wire ignorerCache map[typeId]**decEngine; // ditto for ignored objects
state *decodeState; // reads data from in-memory buffer
countState *decodeState; // reads counts from wire
buf []byte; buf []byte;
oneByte []byte; oneByte []byte;
} }
...@@ -27,8 +30,10 @@ type Decoder struct { ...@@ -27,8 +30,10 @@ type Decoder struct {
func NewDecoder(r io.Reader) *Decoder { func NewDecoder(r io.Reader) *Decoder {
dec := new(Decoder); dec := new(Decoder);
dec.r = r; dec.r = r;
dec.seen = make(map[typeId]*wireType); dec.wireType = make(map[typeId]*wireType);
dec.state = newDecodeState(nil); // buffer set in Decode(); rest is unimportant dec.state = newDecodeState(nil); // buffer set in Decode(); rest is unimportant
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine);
dec.ignorerCache = make(map[typeId]**decEngine);
dec.oneByte = make([]byte, 1); dec.oneByte = make([]byte, 1);
return dec; return dec;
...@@ -36,16 +41,16 @@ func NewDecoder(r io.Reader) *Decoder { ...@@ -36,16 +41,16 @@ func NewDecoder(r io.Reader) *Decoder {
func (dec *Decoder) recvType(id typeId) { func (dec *Decoder) recvType(id typeId) {
// Have we already seen this type? That's an error // Have we already seen this type? That's an error
if _, alreadySeen := dec.seen[id]; alreadySeen { if _, alreadySeen := dec.wireType[id]; alreadySeen {
dec.state.err = os.ErrorString("gob: duplicate type received"); dec.state.err = os.ErrorString("gob: duplicate type received");
return; return;
} }
// Type: // Type:
wire := new(wireType); wire := new(wireType);
decode(dec.state.b, tWireType, wire); dec.state.err = dec.decode(tWireType, wire);
// Remember we've seen this type. // Remember we've seen this type.
dec.seen[id] = wire; dec.wireType[id] = wire;
} }
// Decode reads the next value from the connection and stores // Decode reads the next value from the connection and stores
...@@ -97,7 +102,13 @@ func (dec *Decoder) Decode(e interface{}) os.Error { ...@@ -97,7 +102,13 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
} }
// No, it's a value. // No, it's a value.
dec.state.err = decode(dec.state.b, id, e); // Make sure the type has been defined already.
_, ok := dec.wireType[id];
if !ok {
dec.state.err = errBadType;
break;
}
dec.state.err = dec.decode(id, e);
break; break;
} }
return dec.state.err; return dec.state.err;
......
...@@ -235,19 +235,15 @@ func (enc *Encoder) send() { ...@@ -235,19 +235,15 @@ func (enc *Encoder) send() {
enc.w.Write(enc.buf[0:total]); enc.w.Write(enc.buf[0:total]);
} }
func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) { func (enc *Encoder) sendType(origt reflect.Type) {
// Drill down to the base type. // Drill down to the base type.
rt, _ := indirect(origt); rt, _ := indirect(origt);
// We only send structs - everything else is basic or an error // We only send structs - everything else is basic or an error
switch rt.(type) { switch rt := rt.(type) {
default: default:
// Basic types do not need to be described, but if this is a top-level // Basic types do not need to be described.
// type, it's a user error, at least for now. return
if topLevel {
enc.badType(rt)
}
return;
case *reflect.StructType: case *reflect.StructType:
// Structs do need to be described. // Structs do need to be described.
break break
...@@ -255,10 +251,9 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) { ...@@ -255,10 +251,9 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
// Probably a bad field in a struct. // Probably a bad field in a struct.
enc.badType(rt); enc.badType(rt);
return; return;
case *reflect.ArrayType, *reflect.SliceType: // Array and slice types are not sent, only their element types.
// Array and slice types are not sent, only their element types. case reflect.ArrayOrSliceType:
// If we see one here it's user error; probably a bad top-level value. enc.sendType(rt.Elem());
enc.badType(rt);
return; return;
} }
...@@ -289,7 +284,7 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) { ...@@ -289,7 +284,7 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
// Now send the inner types // Now send the inner types
st := rt.(*reflect.StructType); st := rt.(*reflect.StructType);
for i := 0; i < st.NumField(); i++ { for i := 0; i < st.NumField(); i++ {
enc.sendType(st.Field(i).Type, false) enc.sendType(st.Field(i).Type)
} }
return; return;
} }
...@@ -301,6 +296,12 @@ func (enc *Encoder) Encode(e interface{}) os.Error { ...@@ -301,6 +296,12 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
panicln("Encoder: buffer not empty") panicln("Encoder: buffer not empty")
} }
rt, _ := indirect(reflect.Typeof(e)); rt, _ := indirect(reflect.Typeof(e));
// Must be a struct
if _, ok := rt.(*reflect.StructType); !ok {
enc.badType(rt);
return enc.state.err;
}
// Make sure we're single-threaded through here. // Make sure we're single-threaded through here.
enc.mutex.Lock(); enc.mutex.Lock();
...@@ -310,7 +311,7 @@ func (enc *Encoder) Encode(e interface{}) os.Error { ...@@ -310,7 +311,7 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
// First, have we already sent this type? // First, have we already sent this type?
if _, alreadySent := enc.sent[rt]; !alreadySent { if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it. // No, so send it.
enc.sendType(rt, true); enc.sendType(rt);
if enc.state.err != nil { if enc.state.err != nil {
enc.state.b.Reset(); enc.state.b.Reset();
enc.countState.b.Reset(); enc.countState.b.Reset();
......
...@@ -32,122 +32,10 @@ type ET3 struct { ...@@ -32,122 +32,10 @@ type ET3 struct {
// Like ET1 but with a different type for a field // Like ET1 but with a different type for a field
type ET4 struct { type ET4 struct {
a int; a int;
et2 *ET1; et2 float;
next int; next int;
} }
func TestBasicEncoder(t *testing.T) {
b := new(bytes.Buffer);
enc := NewEncoder(b);
et1 := new(ET1);
et1.a = 7;
et1.et2 = new(ET2);
enc.Encode(et1);
if enc.state.err != nil {
t.Error("encoder fail:", enc.state.err)
}
// Decode the result by hand to verify;
state := newDecodeState(b);
// The output should be:
// 0) The length, 38.
length := decodeUint(state);
if length != 38 {
t.Fatal("0. expected length 38; got", length)
}
// 1) -7: the type id of ET1
id1 := decodeInt(state);
if id1 >= 0 {
t.Fatal("expected ET1 negative id; got", id1)
}
// 2) The wireType for ET1
wire1 := new(wireType);
err := decode(b, tWireType, wire1);
if err != nil {
t.Fatal("error decoding ET1 type:", err)
}
info := getTypeInfoNoError(reflect.Typeof(ET1{}));
trueWire1 := &wireType{s: info.id.gobType().(*structType)};
if !reflect.DeepEqual(wire1, trueWire1) {
t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1)
}
// 3) The length, 21.
length = decodeUint(state);
if length != 21 {
t.Fatal("3. expected length 21; got", length)
}
// 4) -8: the type id of ET2
id2 := decodeInt(state);
if id2 >= 0 {
t.Fatal("expected ET2 negative id; got", id2)
}
// 5) The wireType for ET2
wire2 := new(wireType);
err = decode(b, tWireType, wire2);
if err != nil {
t.Fatal("error decoding ET2 type:", err)
}
info = getTypeInfoNoError(reflect.Typeof(ET2{}));
trueWire2 := &wireType{s: info.id.gobType().(*structType)};
if !reflect.DeepEqual(wire2, trueWire2) {
t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2)
}
// 6) The length, 6.
length = decodeUint(state);
if length != 6 {
t.Fatal("6. expected length 6; got", length)
}
// 7) The type id for the et1 value
newId1 := decodeInt(state);
if newId1 != -id1 {
t.Fatal("expected Et1 id", -id1, "got", newId1)
}
// 8) The value of et1
newEt1 := new(ET1);
et1Id := getTypeInfoNoError(reflect.Typeof(*newEt1)).id;
err = decode(b, et1Id, newEt1);
if err != nil {
t.Fatal("error decoding ET1 value:", err)
}
if !reflect.DeepEqual(et1, newEt1) {
t.Fatalf("invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1)
}
// 9) EOF
if b.Len() != 0 {
t.Error("not at eof;", b.Len(), "bytes left")
}
// Now do it again. This time we should see only the type id and value.
b.Reset();
enc.Encode(et1);
if enc.state.err != nil {
t.Error("2nd round: encoder fail:", enc.state.err)
}
// The length.
length = decodeUint(state);
if length != 6 {
t.Fatal("6. expected length 6; got", length)
}
// 5a) The type id for the et1 value
newId1 = decodeInt(state);
if newId1 != -id1 {
t.Fatal("2nd round: expected Et1 id", -id1, "got", newId1)
}
// 6a) The value of et1
newEt1 = new(ET1);
err = decode(b, et1Id, newEt1);
if err != nil {
t.Fatal("2nd round: error decoding ET1 value:", err)
}
if !reflect.DeepEqual(et1, newEt1) {
t.Fatalf("2nd round: invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1)
}
// 7a) EOF
if b.Len() != 0 {
t.Error("2nd round: not at eof;", b.Len(), "bytes left")
}
}
func TestEncoderDecoder(t *testing.T) { func TestEncoderDecoder(t *testing.T) {
b := new(bytes.Buffer); b := new(bytes.Buffer);
enc := NewEncoder(b); enc := NewEncoder(b);
...@@ -215,7 +103,7 @@ func badTypeCheck(e interface{}, shouldFail bool, msg string, t *testing.T) { ...@@ -215,7 +103,7 @@ func badTypeCheck(e interface{}, shouldFail bool, msg string, t *testing.T) {
t.Error("expected error for", msg) t.Error("expected error for", msg)
} }
if !shouldFail && (dec.state.err != nil) { if !shouldFail && (dec.state.err != nil) {
t.Error("unexpected error for", msg) t.Error("unexpected error for", msg, dec.state.err)
} }
} }
......
...@@ -48,6 +48,7 @@ type gobType interface { ...@@ -48,6 +48,7 @@ type gobType interface {
var types = make(map[reflect.Type]gobType) var types = make(map[reflect.Type]gobType)
var idToType = make(map[typeId]gobType) var idToType = make(map[typeId]gobType)
var builtinIdToType map[typeId]gobType // set in init() after builtins are established
func setTypeId(typ gobType) { func setTypeId(typ gobType) {
nextId++; nextId++;
...@@ -104,6 +105,10 @@ func init() { ...@@ -104,6 +105,10 @@ func init() {
checkId(8, getTypeInfoNoError(reflect.Typeof(structType{})).id); checkId(8, getTypeInfoNoError(reflect.Typeof(structType{})).id);
checkId(9, getTypeInfoNoError(reflect.Typeof(commonType{})).id); checkId(9, getTypeInfoNoError(reflect.Typeof(commonType{})).id);
checkId(10, getTypeInfoNoError(reflect.Typeof(fieldType{})).id); checkId(10, getTypeInfoNoError(reflect.Typeof(fieldType{})).id);
builtinIdToType = make(map[typeId]gobType);
for k, v := range idToType {
builtinIdToType[k] = v
}
} }
// Array type // Array type
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment