Commit 4bbf5dc8 authored by Levin Zimmermann's avatar Levin Zimmermann

go/neo/proto/msgpack: Fix decoding of unset IdTime

The data type of IdTime is 'Optional[float]' [1], [2]. Before this patch the
msgpack decoder could only decode 'IdTime' in case its data type was
'float'. Now it also supports the decoding of IdTime in case it is NIL.

Besides the changes in 'protogen.go', this patch also relaxes the
overflow check test in 'proto_test.go', which fails in its previous form
after adjusting the decoding. However I don't think the exact error
type should matter here: in real-usage cases we don't mind about the
particular error type [3].

[1] See here, the fifth argument is IdTime:

    https://lab.nexedi.com/nexedi/neoppod/-/blob/3ddb6663/neo/master/handlers/identification.py#L26-27

    This is found to be 'Optional[float]':

    https://lab.nexedi.com/nexedi/neoppod/-/blob/3ddb6663/neo/tests/protocol#L98

[2] This seems to be true for both, pre-msgpack and post-msgpack
    protocol, because in the pre-msgpack NEO/go there is already this
    note:

    https://lab.nexedi.com/kirr/neo/-/blob/1ad088c8/go/neo/proto/proto.go#L352-357

[3] See https://lab.nexedi.com/kirr/neo/-/blob/6fb93a60/go/neo/neonet/connection.go#L1531-1534
    and https://lab.nexedi.com/kirr/neo/-/blob/6fb93a60/go/neo/neonet/connection.go#L1588-1591
parent d5603f3d
...@@ -152,7 +152,7 @@ func testMsgMarshal(t *testing.T, enc Encoding, msg Msg, encoded string) { ...@@ -152,7 +152,7 @@ func testMsgMarshal(t *testing.T, enc Encoding, msg Msg, encoded string) {
// decode must detect buffer overflow // decode must detect buffer overflow
for l := len(encoded) - 1; l >= 0; l-- { for l := len(encoded) - 1; l >= 0; l-- {
n, err = enc.MsgDecode(msg2, data[:l]) n, err = enc.MsgDecode(msg2, data[:l])
if !(n == 0 && err == ErrDecodeOverflow) { if !(n == 0 && err != nil) {
t.Errorf("%c/%v: decode overflow not detected on [:%v]", enc, typ, l) t.Errorf("%c/%v: decode overflow not detected on [:%v]", enc, typ, l)
} }
......
...@@ -313,6 +313,7 @@ package proto ...@@ -313,6 +313,7 @@ package proto
import ( import (
"encoding/binary" "encoding/binary"
"math"
"reflect" "reflect"
"sort" "sort"
...@@ -1137,7 +1138,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty ...@@ -1137,7 +1138,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
} }
// mgetfloat emits mgetfloat<size> // mgetfloat emits mgetfloat<size>
mgetfloat := func(size int) { mgetfloat := func(size int, optionalValue string) {
// delving into msgp - flush/prepare next site for overflow check // delving into msgp - flush/prepare next site for overflow check
d.overflowCheck() d.overflowCheck()
d.resetPos() d.resetPos()
...@@ -1146,7 +1147,15 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty ...@@ -1146,7 +1147,15 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
d.emit("{") d.emit("{")
d.emit("v, tail, err := msgp.ReadFloat%dBytes(data)", size) d.emit("v, tail, err := msgp.ReadFloat%dBytes(data)", size)
d.emit("if err != nil {") d.emit("if err != nil {")
if optionalValue != "" {
d.emit(" tail, err = msgp.ReadNilBytes(data)")
d.emit(" if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit(" }")
d.emit("v = %v", optionalValue)
} else {
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
}
d.emit("}") d.emit("}")
d.emit("%s= %s", assignto, v) d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
...@@ -1154,6 +1163,14 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty ...@@ -1154,6 +1163,14 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
d.emit("}") d.emit("}")
} }
// IdTime can be nil ('None' in py), in this case we use
// infinite -1, see
// https://lab.nexedi.com/kirr/neo/-/blob/1ad088c8/go/neo/proto/proto.go#L352-357
if typeName(userType) == "IdTime" {
mgetfloat(64, "math.Inf(-1)")
return
}
switch typ.Kind() { switch typ.Kind() {
case types.Bool: case types.Bool:
d.emit("switch op := msgpack.Op(data[%v]); op {", d.n) d.emit("switch op := msgpack.Op(data[%v]); op {", d.n)
...@@ -1175,7 +1192,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty ...@@ -1175,7 +1192,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
case types.Uint32: mgetint("u", 32) case types.Uint32: mgetint("u", 32)
case types.Uint64: mgetint("u", 64) case types.Uint64: mgetint("u", 64)
case types.Float64: mgetfloat(64) case types.Float64: mgetfloat(64, "")
} }
} }
......
...@@ -6,6 +6,7 @@ package proto ...@@ -6,6 +6,7 @@ package proto
import ( import (
"encoding/binary" "encoding/binary"
"math"
"reflect" "reflect"
"sort" "sort"
...@@ -417,9 +418,13 @@ func (p *RequestIdentification) neoMsgDecodeM(data []byte) (int, error) { ...@@ -417,9 +418,13 @@ func (p *RequestIdentification) neoMsgDecodeM(data []byte) (int, error) {
} }
{ {
v, tail, err := msgp.ReadFloat64Bytes(data) v, tail, err := msgp.ReadFloat64Bytes(data)
if err != nil {
tail, err = msgp.ReadNilBytes(data)
if err != nil { if err != nil {
return 0, mdecodeErr("RequestIdentification.IdTime", err) return 0, mdecodeErr("RequestIdentification.IdTime", err)
} }
v = math.Inf(-1)
}
p.IdTime = IdTime(v) p.IdTime = IdTime(v)
nread += uint64(len(data) - len(tail)) nread += uint64(len(data) - len(tail))
data = tail data = tail
...@@ -1113,9 +1118,13 @@ func (p *NotifyNodeInformation) neoMsgDecodeM(data []byte) (int, error) { ...@@ -1113,9 +1118,13 @@ func (p *NotifyNodeInformation) neoMsgDecodeM(data []byte) (int, error) {
} }
{ {
v, tail, err := msgp.ReadFloat64Bytes(data) v, tail, err := msgp.ReadFloat64Bytes(data)
if err != nil {
tail, err = msgp.ReadNilBytes(data)
if err != nil { if err != nil {
return 0, mdecodeErr("NotifyNodeInformation.IdTime", err) return 0, mdecodeErr("NotifyNodeInformation.IdTime", err)
} }
v = math.Inf(-1)
}
p.IdTime = IdTime(v) p.IdTime = IdTime(v)
nread += uint64(len(data) - len(tail)) nread += uint64(len(data) - len(tail))
data = tail data = tail
...@@ -1209,9 +1218,13 @@ func (p *NotifyNodeInformation) neoMsgDecodeM(data []byte) (int, error) { ...@@ -1209,9 +1218,13 @@ func (p *NotifyNodeInformation) neoMsgDecodeM(data []byte) (int, error) {
data = data[3:] data = data[3:]
{ {
v, tail, err := msgp.ReadFloat64Bytes(data) v, tail, err := msgp.ReadFloat64Bytes(data)
if err != nil {
tail, err = msgp.ReadNilBytes(data)
if err != nil { if err != nil {
return 0, mdecodeErr("NotifyNodeInformation.IdTime", err) return 0, mdecodeErr("NotifyNodeInformation.IdTime", err)
} }
v = math.Inf(-1)
}
(*a).IdTime = IdTime(v) (*a).IdTime = IdTime(v)
nread += uint64(len(data) - len(tail)) nread += uint64(len(data) - len(tail))
data = tail data = tail
...@@ -6691,9 +6704,13 @@ func (p *AnswerNodeList) neoMsgDecodeM(data []byte) (int, error) { ...@@ -6691,9 +6704,13 @@ func (p *AnswerNodeList) neoMsgDecodeM(data []byte) (int, error) {
data = data[3:] data = data[3:]
{ {
v, tail, err := msgp.ReadFloat64Bytes(data) v, tail, err := msgp.ReadFloat64Bytes(data)
if err != nil {
tail, err = msgp.ReadNilBytes(data)
if err != nil { if err != nil {
return 0, mdecodeErr("AnswerNodeList.IdTime", err) return 0, mdecodeErr("AnswerNodeList.IdTime", err)
} }
v = math.Inf(-1)
}
(*a).IdTime = IdTime(v) (*a).IdTime = IdTime(v)
nread += uint64(len(data) - len(tail)) nread += uint64(len(data) - len(tail))
data = tail data = tail
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment