diff --git a/go/neo/internal/msgpack/msgpack.go b/go/neo/internal/msgpack/msgpack.go index 190dc43a7b450667258de7d195f1517d2b1efca2..216cf45fc0b894122546566ddb6d746f82670c07 100644 --- a/go/neo/internal/msgpack/msgpack.go +++ b/go/neo/internal/msgpack/msgpack.go @@ -33,6 +33,8 @@ const ( FixMap_4 Op = 0b1000_0000 // 1000_XXXX FixArray_4 Op = 0b1001_0000 // 1001_XXXX + Nil Op = 0xc0 + False Op = 0xc2 True Op = 0xc3 @@ -285,3 +287,17 @@ func PutMapHead(data []byte, l int) (n int) { default: panic("len overflows uint32") } } + +// OptionalFloat64Size returns guessed size of encoded +// optional float64. It is a guess, because the function +// doesn't flag corrupt or empty data. +func OptionalFloat64Size(data []byte) uint64 { + if len(data) > 0 && data[0] == byte(Nil) { + return 1 + // valid float must be encoded in 9 byte + // 1: msgpack.Float64 header + // 2-9: <value64> + } else { + return 9 + } +} diff --git a/go/neo/internal/msgpack/msgpack_test.go b/go/neo/internal/msgpack/msgpack_test.go index 65b8cf9271abb8809807be8b9c9dcb8d1e125a4b..0dfdd47d414cdaaba4d3902b1bc3bacae5663ee9 100644 --- a/go/neo/internal/msgpack/msgpack_test.go +++ b/go/neo/internal/msgpack/msgpack_test.go @@ -185,3 +185,21 @@ func TestMap(t *testing.T) { test1(t, &tEncMapHead{}, tt.l, tt.encoded) } } + +func TestOptionalFloat64Size(t *testing.T) { + ts := func (wanted int, value []byte) { + v := OptionalFloat64Size(value) + if v != uint64(wanted) { + t.Errorf("want: %v; got: %v", wanted, v) + } + } + + // Overflow + ts(9, []byte {}) + // Optional value + ts(1, []byte {byte(Nil)}) + // valid float + ts(9, []byte {0xcb, 0x3f, 0xda, 0x63, 0x1f, 0x8a, 0x09, 0x02, 0xde}) + // anything else => error is raised later at overflow check + ts(9, []byte {1, 2}) +} diff --git a/go/neo/proto/proto_test.go b/go/neo/proto/proto_test.go index 53c0b51f74cbe5fdc6bc5cc6d7d99df9f5896dc1..8338fe6e6050d048ea18b1773bd0e5ba53c52204 100644 --- a/go/neo/proto/proto_test.go +++ b/go/neo/proto/proto_test.go @@ -352,7 +352,7 @@ func TestMsgMarshal(t *testing.T) { hex("ffffffffffffffff") + hex("00000000"), // M hex("92") + - hex("cb" + "fff0000000000000") + // XXX nan/-inf not handled yet + hex("c0") + hex("90"), }, diff --git a/go/neo/proto/protogen.go b/go/neo/proto/protogen.go index 071e3c29a543b1e0f318c796c357bde11f1be572..e39289355f4d896147e0a51d5060900ad0b6c4f0 100644 --- a/go/neo/proto/protogen.go +++ b/go/neo/proto/protogen.go @@ -972,6 +972,16 @@ func (s *sizerM) genBasic(path string, typ *types.Basic, userType types.Type) { return } + if typeName(userType) == "IdTime" { + // Unset IdTime must be NIL on the wire + s.emit("if %s == IdTimeNone {", path) + s.emit("%v += 1 // mnil", s.var_("size")) + s.emit("} else {") + s.emit("%v += 1+8 // mfloat64 + <value64>", s.var_("size")) + s.emit("}") + return + } + switch typ.Kind() { case types.Bool: s.size.Add(1) // mfalse|mtrue case types.Int8: s.size.AddExpr("msgpack.Int8Size(%s)", upath) @@ -1029,6 +1039,27 @@ func (e *encoderM) genBasic(path string, typ *types.Basic, userType types.Type) e.n = 0 } + // mputfloat64 emits float64 + mputfloat64 := func() { + // mfloat64 f64 + e.emit("data[%v] = byte(msgpack.Float64)", e.n); e.n++ + e.emit("float64_neoEncode(data[%v:], %s)", e.n, upath); e.n += 8 + } + + // IdTime can be nil ('None' in py), in this case we use + // infinite -1, see + // https://lab.nexedi.com/kirr/neo/-/blob/1ad088c8/go/neo/proto/proto.go#L352-357 + if typeName(userType) == "IdTime" { + e.emit("if %s == float64(IdTimeNone) {", upath) + e.emit("data[%v] = byte(msgpack.Nil)", e.n) // mnil + e.emit("data = data[%v:]", e.n + 1) + e.emit("} else {") + mputfloat64() + e.resetPos() + e.emit("}") + return + } + switch typ.Kind() { case types.Bool: e.emit("data[%v] = byte(msgpack.Bool(%s))", e.n, path) @@ -1045,9 +1076,7 @@ func (e *encoderM) genBasic(path string, typ *types.Basic, userType types.Type) case types.Uint64: mputint("u", 64) case types.Float64: - // mfloat64 f64 - e.emit("data[%v] = byte(msgpack.Float64)", e.n); e.n++ - e.emit("float64_neoEncode(data[%v:], %s)", e.n, upath); e.n += 8 + mputfloat64() } } @@ -1131,23 +1160,47 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty } // mgetfloat emits mgetfloat<size> - mgetfloat := func(size int) { + mgetfloat := func(size int, optionalValue string) { // delving into msgp - flush/prepare next site for overflow check d.overflowCheck() d.resetPos() defer d.overflowCheck() + _mgetfloat := func () { + d.emit("_v, tail, err := msgp.ReadFloat%dBytes(data)", size) + d.emit("if err != nil {") + d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) + d.emit("}") + d.emit("v = _v") + d.emit("data = tail") + } + d.emit("{") - d.emit("v, tail, err := msgp.ReadFloat%dBytes(data)", size) - d.emit("if err != nil {") - d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) - d.emit("}") - d.emit("%s= %s", assignto, v) - d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) - d.emit("data = tail") + d.emit("var v float%d", size) + if optionalValue == "" { + _mgetfloat() + } else { + d.overflow.AddExpr("msgpack.OptionalFloat64Size(data)") + d.emit("if data[%v] == byte(msgpack.Nil) {", d.n) + d.emit("v = %v", optionalValue) + d.n += 1 + d.resetPos() + d.emit("} else {") + _mgetfloat() + d.emit("}") + } + d.emit("%s = %s", assignto, v) d.emit("}") } + // IdTime can be nil ('None' in py), in this case we use + // infinite -1, see + // https://lab.nexedi.com/kirr/neo/-/blob/1ad088c8/go/neo/proto/proto.go#L352-357 + if typeName(userType) == "IdTime" { + mgetfloat(64, "float64(IdTimeNone)") + return + } + switch typ.Kind() { case types.Bool: d.emit("switch op := msgpack.Op(data[%v]); op {", d.n) @@ -1169,7 +1222,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty case types.Uint32: mgetint("u", 32) case types.Uint64: mgetint("u", 64) - case types.Float64: mgetfloat(64) + case types.Float64: mgetfloat(64, "") } } @@ -1415,6 +1468,17 @@ func (s *sizerCommon) genSliceCommon(xs CodeGenCustomize, path string, typ *type s.size = curSize } +// data = data[n:] +// n = 0 +// +// XXX duplication wrt decoderCommon.resetPost +func (e *encoderCommon) resetPos() { + if e.n != 0 { + e.emit("data = data[%v:]", e.n) + e.n = 0 + } +} + func (e *encoderN) genSliceHead(path string, typ *types.Slice, obj types.Object) { e.emit("l := len(%s)", path) e.genBasic("l", types.Typ[types.Uint32], types.Typ[types.Int]) diff --git a/go/neo/proto/zproto-marshal.go b/go/neo/proto/zproto-marshal.go index 8f63edb40d54210ce7de28e59776fd4e2fc69dbb..d4450d8355c8941dfe08d3fafe0909adf972b0f3 100644 --- a/go/neo/proto/zproto-marshal.go +++ b/go/neo/proto/zproto-marshal.go @@ -266,6 +266,11 @@ overflow: func (p *RequestIdentification) neoMsgEncodedLenM() int { var size int + if p.IdTime == IdTimeNone { + size += 1 // mnil + } else { + size += 1 + 8 // mfloat64 + <value64> + } for i := 0; i < len(p.DevPath); i++ { a := &p.DevPath[i] size += msgpack.BinHeadSize(len((*a))) + len((*a)) @@ -274,7 +279,7 @@ func (p *RequestIdentification) neoMsgEncodedLenM() int { a := &p.NewNID[i] size += msgpack.Uint32Size((*a)) } - return 14 + msgpack.Int32Size(int32(p.NID)) + msgpack.BinHeadSize(len(p.Address.Host)) + len(p.Address.Host) + msgpack.Uint16Size(p.Address.Port) + msgpack.BinHeadSize(len(p.ClusterName)) + len(p.ClusterName) + msgpack.ArrayHeadSize(len(p.DevPath)) + msgpack.ArrayHeadSize(len(p.NewNID)) + size + return 5 + msgpack.Int32Size(int32(p.NID)) + msgpack.BinHeadSize(len(p.Address.Host)) + len(p.Address.Host) + msgpack.Uint16Size(p.Address.Port) + msgpack.BinHeadSize(len(p.ClusterName)) + len(p.ClusterName) + msgpack.ArrayHeadSize(len(p.DevPath)) + msgpack.ArrayHeadSize(len(p.NewNID)) + size } func (p *RequestIdentification) neoMsgEncodeM(data []byte) { @@ -308,12 +313,18 @@ func (p *RequestIdentification) neoMsgEncodeM(data []byte) { copy(data, p.ClusterName) data = data[l:] } - data[0] = byte(msgpack.Float64) - float64_neoEncode(data[1:], float64(p.IdTime)) + if float64(p.IdTime) == float64(IdTimeNone) { + data[0] = byte(msgpack.Nil) + data = data[1:] + } else { + data[0] = byte(msgpack.Float64) + float64_neoEncode(data[1:], float64(p.IdTime)) + data = data[9:] + } { l := len(p.DevPath) - n := msgpack.PutArrayHead(data[9:], l) - data = data[9+n:] + n := msgpack.PutArrayHead(data[0:], l) + data = data[0+n:] for i := 0; i < l; i++ { a := &p.DevPath[i] { @@ -404,14 +415,24 @@ func (p *RequestIdentification) neoMsgDecodeM(data []byte) (int, error) { nread += uint64(len(data) - len(tail)) data = tail } + if uint64(len(data)) < msgpack.OptionalFloat64Size(data) { + goto overflow + } + nread += msgpack.OptionalFloat64Size(data) { - v, tail, err := msgp.ReadFloat64Bytes(data) - if err != nil { - return 0, mdecodeErr("RequestIdentification.IdTime", err) + var v float64 + if data[0] == byte(msgpack.Nil) { + v = float64(IdTimeNone) + data = data[1:] + } else { + _v, tail, err := msgp.ReadFloat64Bytes(data) + if err != nil { + return 0, mdecodeErr("RequestIdentification.IdTime", err) + } + v = _v + data = tail } p.IdTime = IdTime(v) - nread += uint64(len(data) - len(tail)) - data = tail } { l, tail, err := msgp.ReadArrayHeaderBytes(data) @@ -1027,21 +1048,37 @@ overflow: func (p *NotifyNodeInformation) neoMsgEncodedLenM() int { var size int + if p.IdTime == IdTimeNone { + size += 1 // mnil + } else { + size += 1 + 8 // mfloat64 + <value64> + } for i := 0; i < len(p.NodeList); i++ { a := &p.NodeList[i] + if (*a).IdTime == IdTimeNone { + size += 1 // mnil + } else { + size += 1 + 8 // mfloat64 + <value64> + } size += msgpack.BinHeadSize(len((*a).Addr.Host)) + len((*a).Addr.Host) + msgpack.Uint16Size((*a).Addr.Port) + msgpack.Int32Size(int32((*a).NID)) } - return 10 + msgpack.ArrayHeadSize(len(p.NodeList)) + len(p.NodeList)*17 + size + return 1 + msgpack.ArrayHeadSize(len(p.NodeList)) + len(p.NodeList)*8 + size } func (p *NotifyNodeInformation) neoMsgEncodeM(data []byte) { data[0] = byte(msgpack.FixArray_4 | 2) - data[1] = byte(msgpack.Float64) - float64_neoEncode(data[2:], float64(p.IdTime)) + if float64(p.IdTime) == float64(IdTimeNone) { + data[1] = byte(msgpack.Nil) + data = data[2:] + } else { + data[1] = byte(msgpack.Float64) + float64_neoEncode(data[2:], float64(p.IdTime)) + data = data[10:] + } { l := len(p.NodeList) - n := msgpack.PutArrayHead(data[10:], l) - data = data[10+n:] + n := msgpack.PutArrayHead(data[0:], l) + data = data[0+n:] for i := 0; i < l; i++ { a := &p.NodeList[i] data[0] = byte(msgpack.FixArray_4 | 5) @@ -1073,9 +1110,14 @@ func (p *NotifyNodeInformation) neoMsgEncodeM(data []byte) { panic("(*a).State: invalid NodeState enum value)") } data[2] = byte((*a).State) - data[3] = byte(msgpack.Float64) - float64_neoEncode(data[4:], float64((*a).IdTime)) - data = data[12:] + if float64((*a).IdTime) == float64(IdTimeNone) { + data[3] = byte(msgpack.Nil) + data = data[4:] + } else { + data[3] = byte(msgpack.Float64) + float64_neoEncode(data[4:], float64((*a).IdTime)) + data = data[12:] + } } } } @@ -1088,15 +1130,25 @@ func (p *NotifyNodeInformation) neoMsgDecodeM(data []byte) (int, error) { if op, opOk := msgpack.Op(data[0]), msgpack.FixArray_4|2; op != opOk { return 0, &mstructDecodeError{"NotifyNodeInformation", op, opOk} } + if uint64(len(data)) < msgpack.OptionalFloat64Size(data) { + goto overflow + } + nread += msgpack.OptionalFloat64Size(data) data = data[1:] { - v, tail, err := msgp.ReadFloat64Bytes(data) - if err != nil { - return 0, mdecodeErr("NotifyNodeInformation.IdTime", err) + var v float64 + if data[0] == byte(msgpack.Nil) { + v = float64(IdTimeNone) + data = data[1:] + } else { + _v, tail, err := msgp.ReadFloat64Bytes(data) + if err != nil { + return 0, mdecodeErr("NotifyNodeInformation.IdTime", err) + } + v = _v + data = tail } p.IdTime = IdTime(v) - nread += uint64(len(data) - len(tail)) - data = tail } { l, tail, err := msgp.ReadArrayHeaderBytes(data) @@ -1174,15 +1226,25 @@ func (p *NotifyNodeInformation) neoMsgDecodeM(data []byte) (int, error) { } (*a).State = NodeState(v) } + if uint64(len(data)) < msgpack.OptionalFloat64Size(data) { + goto overflow + } + nread += msgpack.OptionalFloat64Size(data) data = data[3:] { - v, tail, err := msgp.ReadFloat64Bytes(data) - if err != nil { - return 0, mdecodeErr("NotifyNodeInformation.IdTime", err) + var v float64 + if data[0] == byte(msgpack.Nil) { + v = float64(IdTimeNone) + data = data[1:] + } else { + _v, tail, err := msgp.ReadFloat64Bytes(data) + if err != nil { + return 0, mdecodeErr("NotifyNodeInformation.IdTime", err) + } + v = _v + data = tail } (*a).IdTime = IdTime(v) - nread += uint64(len(data) - len(tail)) - data = tail } } nread += uint64(l) * 8 @@ -6313,9 +6375,14 @@ func (p *AnswerNodeList) neoMsgEncodedLenM() int { var size int for i := 0; i < len(p.NodeList); i++ { a := &p.NodeList[i] + if (*a).IdTime == IdTimeNone { + size += 1 // mnil + } else { + size += 1 + 8 // mfloat64 + <value64> + } size += msgpack.BinHeadSize(len((*a).Addr.Host)) + len((*a).Addr.Host) + msgpack.Uint16Size((*a).Addr.Port) + msgpack.Int32Size(int32((*a).NID)) } - return 1 + msgpack.ArrayHeadSize(len(p.NodeList)) + len(p.NodeList)*17 + size + return 1 + msgpack.ArrayHeadSize(len(p.NodeList)) + len(p.NodeList)*8 + size } func (p *AnswerNodeList) neoMsgEncodeM(data []byte) { @@ -6355,9 +6422,14 @@ func (p *AnswerNodeList) neoMsgEncodeM(data []byte) { panic("(*a).State: invalid NodeState enum value)") } data[2] = byte((*a).State) - data[3] = byte(msgpack.Float64) - float64_neoEncode(data[4:], float64((*a).IdTime)) - data = data[12:] + if float64((*a).IdTime) == float64(IdTimeNone) { + data[3] = byte(msgpack.Nil) + data = data[4:] + } else { + data[3] = byte(msgpack.Float64) + float64_neoEncode(data[4:], float64((*a).IdTime)) + data = data[12:] + } } } } @@ -6447,15 +6519,25 @@ func (p *AnswerNodeList) neoMsgDecodeM(data []byte) (int, error) { } (*a).State = NodeState(v) } + if uint64(len(data)) < msgpack.OptionalFloat64Size(data) { + goto overflow + } + nread += msgpack.OptionalFloat64Size(data) data = data[3:] { - v, tail, err := msgp.ReadFloat64Bytes(data) - if err != nil { - return 0, mdecodeErr("AnswerNodeList.IdTime", err) + var v float64 + if data[0] == byte(msgpack.Nil) { + v = float64(IdTimeNone) + data = data[1:] + } else { + _v, tail, err := msgp.ReadFloat64Bytes(data) + if err != nil { + return 0, mdecodeErr("AnswerNodeList.IdTime", err) + } + v = _v + data = tail } (*a).IdTime = IdTime(v) - nread += uint64(len(data) - len(tail)) - data = tail } } nread += uint64(l) * 8