Commit b59f0b04 authored by Levin Zimmermann's avatar Levin Zimmermann

go/neo/proto/msgpack: Fix case where Compression == None

'Compression' is py type 'Optional[int]' [1]. Before this patch only 'int' was
supported. Now NEO/go also understands 'Compression' with value 'Nil'.

Without this patch, NEO/go client tests fail with

```
have: neos://127.0.0.1:19847,127.0.0.1:28658/1: load 7fffffffffffffff:0000000000000006: 127.0.0.1:39230 - 127.0.0.1:46143 .291: decode: decode: M: AnswerObject.Compression: msgp: attempted to decode type "nil" with method for "uint"
```

[1] See https://lab.nexedi.com/nexedi/neoppod/-/blob/e3cd5c5bf/neo/tests/protocol#L21
    The fourth argument is 'compression':
    https://lab.nexedi.com/nexedi/neoppod/-/blob/e3cd5c5bf/neo/storage/handlers/client.py#L77-78
parent 6c431a55
......@@ -1148,8 +1148,28 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
v = fmt.Sprintf("%s(v)", typeName(userType))
}
// optional emits optional value of int/float
optional := func(optionalValue string) {
if optionalValue != "" {
// Read%dBytes returns 'ErrShortBytes' in case prefix is
// correct float, but data is too short - catch this to return
// 'ErrDecodeOverflow' instead of type error.
d.emit(" err = mdecodeErr(%q, err)", d.pathName(assignto))
d.emit(" if err == ErrDecodeOverflow {")
d.emit(" return 0, err")
d.emit(" }")
d.emit(" tail, err = msgp.ReadNilBytes(data)")
d.emit(" if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit(" }")
d.emit(" v = %v", optionalValue)
} else {
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
}
}
// mgetint emits assignto = mget<kind>int<size>()
mgetint := func(kind string, size int) {
mgetint := func(kind string, size int, optionalValue string) {
// we are going to go into msgp - flush previously queued
// overflow checks; put place for next overflow check after
// msgp is done.
......@@ -1164,7 +1184,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
d.emit("{")
d.emit("v, tail, err := msgp.Read%snt%dBytes(data)", KI, size)
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
optional(optionalValue)
d.emit("}")
d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
......@@ -1182,22 +1202,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
d.emit("{")
d.emit("v, tail, err := msgp.ReadFloat%dBytes(data)", size)
d.emit("if err != nil {")
if optionalValue != "" {
// ReadFloat%dBytes returns 'ErrShortBytes' in case prefix is
// correct float, but data is too short - catch this to return
// 'ErrDecodeOverflow' instead of type error.
d.emit(" err = mdecodeErr(%q, err)", d.pathName(assignto))
d.emit(" if err == ErrDecodeOverflow {")
d.emit(" return 0, err")
d.emit(" }")
d.emit(" tail, err = msgp.ReadNilBytes(data)")
d.emit(" if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit(" }")
d.emit(" v = %v", optionalValue)
} else {
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
}
optional(optionalValue)
d.emit("}")
d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
......@@ -1213,6 +1218,13 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
return
}
// Compression can be nil ('None'), this means the same as
// no compression ('py/NoneType.__bool__' is 'False').
if typeName(userType) == "Compression" {
mgetint("u", 64, "0")
return
}
switch typ.Kind() {
case types.Bool:
d.emit("switch op := msgpack.Op(data[%v]); op {", d.n)
......@@ -1224,15 +1236,15 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
d.n++
d.overflow.Add(1)
case types.Int8: mgetint("", 8)
case types.Int16: mgetint("", 16)
case types.Int32: mgetint("", 32)
case types.Int64: mgetint("", 64)
case types.Int8: mgetint("", 8, "")
case types.Int16: mgetint("", 16, "")
case types.Int32: mgetint("", 32, "")
case types.Int64: mgetint("", 64, "")
case types.Uint8: mgetint("u", 8)
case types.Uint16: mgetint("u", 16)
case types.Uint32: mgetint("u", 32)
case types.Uint64: mgetint("u", 64)
case types.Uint8: mgetint("u", 8, "")
case types.Uint16: mgetint("u", 16, "")
case types.Uint32: mgetint("u", 32, "")
case types.Uint64: mgetint("u", 64, "")
case types.Float64: mgetfloat(64, "")
}
......
......@@ -4967,9 +4967,17 @@ func (p *AnswerRebaseObject) neoMsgDecodeM(data []byte) (int, error) {
}
{
v, tail, err := msgp.ReadUint64Bytes(data)
if err != nil {
err = mdecodeErr("AnswerRebaseObject.Compression", err)
if err == ErrDecodeOverflow {
return 0, err
}
tail, err = msgp.ReadNilBytes(data)
if err != nil {
return 0, mdecodeErr("AnswerRebaseObject.Compression", err)
}
v = 0
}
p.Compression = Compression(v)
nread += uint64(len(data) - len(tail))
data = tail
......@@ -5180,9 +5188,17 @@ func (p *StoreObject) neoMsgDecodeM(data []byte) (int, error) {
}
{
v, tail, err := msgp.ReadUint64Bytes(data)
if err != nil {
err = mdecodeErr("StoreObject.Compression", err)
if err == ErrDecodeOverflow {
return 0, err
}
tail, err = msgp.ReadNilBytes(data)
if err != nil {
return 0, mdecodeErr("StoreObject.Compression", err)
}
v = 0
}
p.Compression = Compression(v)
nread += uint64(len(data) - len(tail))
data = tail
......@@ -6223,9 +6239,17 @@ func (p *AnswerObject) neoMsgDecodeM(data []byte) (int, error) {
}
{
v, tail, err := msgp.ReadUint64Bytes(data)
if err != nil {
err = mdecodeErr("AnswerObject.Compression", err)
if err == ErrDecodeOverflow {
return 0, err
}
tail, err = msgp.ReadNilBytes(data)
if err != nil {
return 0, mdecodeErr("AnswerObject.Compression", err)
}
v = 0
}
p.Compression = Compression(v)
nread += uint64(len(data) - len(tail))
data = tail
......@@ -13168,9 +13192,17 @@ func (p *AddObject) neoMsgDecodeM(data []byte) (int, error) {
}
{
v, tail, err := msgp.ReadUint64Bytes(data)
if err != nil {
err = mdecodeErr("AddObject.Compression", err)
if err == ErrDecodeOverflow {
return 0, err
}
tail, err = msgp.ReadNilBytes(data)
if err != nil {
return 0, mdecodeErr("AddObject.Compression", err)
}
v = 0
}
p.Compression = Compression(v)
nread += uint64(len(data) - len(tail))
data = tail
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment