Commit 1cf2712f authored by Kirill Smelkov's avatar Kirill Smelkov

X on msgpack support

parent d2697535
...@@ -27,6 +27,7 @@ package xcontext ...@@ -27,6 +27,7 @@ package xcontext
import ( import (
"context" "context"
"errors" "errors"
"io"
) )
// Cancelled reports whether an error is due to a canceled context. // Cancelled reports whether an error is due to a canceled context.
...@@ -72,3 +73,41 @@ func WhenDone(ctx context.Context, f func()) func() { ...@@ -72,3 +73,41 @@ func WhenDone(ctx context.Context, f func()) func() {
close(done) close(done)
} }
} }
// WithCloseOnErrCancel closes c on ctx cancel while f is run, or if f returns with an error.
//
// It is usually handy to propagate cancellation to interrupt IO.
// XXX naming?
// XXX don't close on f return?
func WithCloseOnErrCancel(ctx context.Context, c io.Closer, f func() error) (err error) {
closed := false
fdone := make(chan error)
defer func() {
<-fdone // wait for f to complete
if err != nil {
if !closed {
c.Close()
}
}
}()
go func() (err error) {
defer func() {
fdone <- err
close(fdone)
}()
return f()
}()
select {
case <-ctx.Done():
c.Close() // interrupt IO
closed = true
return ctx.Err()
case err := <-fdone:
return err
}
}
...@@ -489,6 +489,7 @@ func withNEO(t *testing.T, f func(t *testing.T, nsrv NEOSrv, ndrv *Client), optv ...@@ -489,6 +489,7 @@ func withNEO(t *testing.T, f func(t *testing.T, nsrv NEOSrv, ndrv *Client), optv
withNEOSrv(t, func(t *testing.T, nsrv NEOSrv) { withNEOSrv(t, func(t *testing.T, nsrv NEOSrv) {
t.Helper() t.Helper()
X := xtesting.FatalIf(t) X := xtesting.FatalIf(t)
// TODO test for enc=(M|N) (XXX M|N only for NEO/go as NEO/py does not support autodetect)
ndrv, _, err := neoOpen(nsrv.URL(), ndrv, _, err := neoOpen(nsrv.URL(),
&zodb.DriverOptions{ReadOnly: true}); X(err) &zodb.DriverOptions{ReadOnly: true}); X(err)
defer func() { defer func() {
......
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
// Package msgpack complements tinylib/msgp in providing runtime support for MessagePack.
//
// https://github.com/msgpack/msgpack/blob/master/spec.md
package msgpack
import (
"encoding/binary"
"math"
)
// Op represents a MessagePack opcode.
type Op byte
const (
FixMap_4 Op = 0b1000_0000 // 1000_XXXX
FixArray_4 Op = 0b1001_0000 // 1001_XXXX
False Op = 0xc2
True Op = 0xc3
Bin8 Op = 0xc4
Bin16 Op = 0xc5
Bin32 Op = 0xc6
Float32 Op = 0xca
Float64 Op = 0xcb
Uint8 Op = 0xcc
Uint16 Op = 0xcd
Uint32 Op = 0xce
Uint64 Op = 0xcf
Int8 Op = 0xd0
Int16 Op = 0xd1
Int32 Op = 0xd2
Int64 Op = 0xd3
FixExt1 Op = 0xd4
FixExt2 Op = 0xd5
FixExt4 Op = 0xd6
Array16 Op = 0xdc
Array32 Op = 0xdd
Map16 Op = 0xde
Map32 Op = 0xdf
)
// op converts Op into byte.
// it is used internally to make sure that only Op is put into encoded data.
func op(x Op) byte {
return byte(x)
}
// Bool returns op corresponding to bool value v.
func Bool(v bool) Op {
if v {
return True
} else {
return False
}
}
// u?intXSize(i) returns size needed to encode i.
func Int8Size (i int8) int { return Int64Size(int64(i)) }
func Int16Size(i int16) int { return Int64Size(int64(i)) }
func Int32Size(i int32) int { return Int64Size(int64(i)) }
func Uint8Size (i uint8) int { return Uint64Size(uint64(i)) }
func Uint16Size(i uint16) int { return Uint64Size(uint64(i)) }
func Uint32Size(i uint32) int { return Uint64Size(uint64(i)) }
// Putu?intX(data, i X) encodes i into data and returns encoded size.
func PutInt8 (data []byte, i int8) int { return PutInt64(data, int64(i)) }
func PutInt16(data []byte, i int16) int { return PutInt64(data, int64(i)) }
func PutInt32(data []byte, i int32) int { return PutInt64(data, int64(i)) }
func PutUint8 (data []byte, i uint8) int { return PutUint64(data, uint64(i)) }
func PutUint16(data []byte, i uint16) int { return PutUint64(data, uint64(i)) }
func PutUint32(data []byte, i uint32) int { return PutUint64(data, uint64(i)) }
func Int64Size(i int64) int {
switch {
case -32 <= i && i <= 0b0_1111111: return 1 // posfixint | negfixint
case int64(int8(i)) == i: return 1+1 // int8 + i8
case int64(int16(i)) == i: return 1+2 // int16 + i16
case int64(int32(i)) == i: return 1+4 // int32 + i32
default: return 1+8 // int64 + u64
}
}
func PutInt64(data []byte, i int64) (n int) {
switch {
// posfixint | negfixint
case -32 <= i && i <= 0b0_1111111:
data[0] = uint8(i)
return 1
// int8 + s8
case int64(int8(i)) == i:
data[0] = op(Int8)
data[1] = uint8(i)
return 1+1
// int16 + s16
case int64(int16(i)) == i:
data[0] = op(Int16)
binary.BigEndian.PutUint16(data[1:], uint16(i))
return 1+2
// int32 + s32
case int64(int32(i)) == i:
data[0] = op(Int32)
binary.BigEndian.PutUint32(data[1:], uint32(i))
return 1+4
// int64 + s64
default:
data[0] = op(Int64)
binary.BigEndian.PutUint64(data[1:], uint64(i))
return 1+8
}
}
func Uint64Size(i uint64) int {
switch {
case i <= 0x7f: return 1 // posfixint
case i <= 0xff: return 1+1 // uint8 + u8
case i <= 0xffff: return 1+2 // uint16 + u16
case i <= 0xffffffff: return 1+4 // uint32 + u32
default: return 1+8 // uint64 + u64
}
}
func PutUint64(data []byte, i uint64) (n int) {
switch {
// posfixint
case i <= 0x7f:
data[0] = uint8(i)
return 1
// uint8 + u8
case i <= math.MaxUint8:
data[0] = op(Uint8)
data[1] = uint8(i)
return 1+1
// uint16 + be16
case i <= math.MaxUint16:
data[0] = op(Uint16)
binary.BigEndian.PutUint16(data[1:], uint16(i))
return 1+2
// uint32 + be32
case i <= math.MaxUint32:
data[0] = op(Uint32)
binary.BigEndian.PutUint32(data[1:], uint32(i))
return 1+4
// uint64 + be64
default:
data[0] = op(Uint64)
binary.BigEndian.PutUint64(data[1:], i)
return 1+8
}
}
// BinHeadSize return number of bytes needed to encode header for [l]bin.
func BinHeadSize(l int) int {
switch {
case l < 0: panic("len < 0")
case l <= math.MaxUint8: return 1+1 // bin8 + len8
case l <= math.MaxUint16: return 1+2 // bin16 + len16
case l <= math.MaxUint32: return 1+4 // bin32 + len32
default: panic("len overflows uint32")
}
}
// PutBinHead puts binary header for [size]bin.
func PutBinHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// bin8 + len8
case l <= 0xff:
data[0] = op(Bin8)
data[1] = uint8(l)
return 1+1
// bin16 + len16
case l <= math.MaxUint16:
data[0] = op(Bin16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// bin32 + len32
case l <= math.MaxUint32:
data[0] = op(Bin32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// ArrayHeadSize returns size for array header for [size]array.
func ArrayHeadSize(l int) int {
switch {
case l < 0: panic("len < 0")
case l <= 0x0f: return 1 // fixarray
case l <= math.MaxUint16: return 1+2 // array16 + len16
case l <= math.MaxUint32: return 1+4 // array32 + len32
default: panic("len overflows uint32")
}
}
// PutArrayHead puts array header for [size]array.
func PutArrayHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// fixarray
case l <= 0x0f:
data[0] = op(FixArray_4 | Op(l))
return 1
// array16 + len16
case l <= math.MaxUint16:
data[0] = op(Array16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// array32 + len32
case l <= math.MaxUint32:
data[0] = op(Array32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// MapHeadSize returns size for map header for [size]map.
func MapHeadSize(l int) int {
return ArrayHeadSize(l) // the same 0x0f/len16/len32 scheme
}
// PutMapHead puts map header for [size]map.
func PutMapHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// fixmap
case l <= 0x0f:
data[0] = op(FixMap_4 | Op(l))
return 1
// map16 + len16
case l <= math.MaxUint16:
data[0] = op(Map16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// map32 + len32
case l <= math.MaxUint32:
data[0] = op(Map32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
package msgpack
import (
hexpkg "encoding/hex"
"testing"
)
// hex decodes string as hex; panics on error.
func hex(s string) string {
b, err := hexpkg.DecodeString(s)
if err != nil {
panic(err)
}
return string(b)
}
// tGetPutSize is interface with Get/Put/Size methods, e.g. with
// GetBinHead/PutBinHead/BinHeadSize.
type tGetPutSize interface {
// XXX Get(data []byte) (n int, ret interface{})
Size(arg interface{}) int
Put(data []byte, arg interface{}) int
}
// test1 verifies enc functions on one argument.
func test1(t *testing.T, enc tGetPutSize, arg interface{}, encoded string) {
t.Helper()
data := make([]byte, 16)
n := enc.Put(data, arg)
got := string(data[:n])
if got != encoded {
t.Errorf("%v -> %x ; want %x", arg, got, encoded)
}
if sz := enc.Size(arg); sz != n {
t.Errorf("size(%v) -> %d ; len(data)=%d", arg, sz, n)
}
// XXX decode == arg, n
// XXX decode([:n-1]) -> overflow
}
type tEncUint64 struct{}
func (_ *tEncUint64) Size(xi interface{}) int { return Uint64Size(xi.(uint64)) }
func (_ *tEncUint64) Put(data []byte, xi interface{}) int { return PutUint64(data, xi.(uint64)) }
func TestUint(t *testing.T) {
h := hex
testv := []struct{i uint64; encoded string}{
{0, h("00")}, // posfixint
{1, h("01")},
{0x7f, h("7f")},
{0x80, h("cc80")}, // uint8
{0xff, h("ccff")},
{0x100, h("cd0100")}, // uint16
{0xffff, h("cdffff")},
{0x10000, h("ce00010000")}, // uint32
{0xffffffff, h("ceffffffff")},
{0x100000000, h("cf0000000100000000")}, // uint64
{0xffffffffffffffff, h("cfffffffffffffffff")},
}
for _, tt := range testv {
test1(t, &tEncUint64{}, tt.i, tt.encoded)
}
}
type tEncInt64 struct{}
func (_ *tEncInt64) Size(xi interface{}) int { return Int64Size(xi.(int64)) }
func (_ *tEncInt64) Put(data []byte, xi interface{}) int { return PutInt64(data, xi.(int64)) }
func TestInt(t *testing.T) {
h := hex
testv := []struct{i int64; encoded string}{
{0, h("00")}, // posfixint
{1, h("01")},
{0x7f, h("7f")},
{-1, h("ff")}, // negfixint
{-2, h("fe")},
{-31, h("e1")},
{-32, h("e0")},
{-33, h("d0df")}, // int8
{-0x7f, h("d081")},
{-0x80, h("d080")},
{0x80, h("d10080")}, // int16
{0x7fff, h("d17fff")},
{-0x7fff, h("d18001")},
{-0x8000, h("d18000")},
{0x8000, h("d200008000")}, // int32
{0x7fffffff, h("d27fffffff")},
{-0x8001, h("d2ffff7fff")},
{-0x7fffffff, h("d280000001")},
{-0x80000000, h("d280000000")},
{0x80000000, h("d30000000080000000")}, // int64
{0x7fffffffffffffff, h("d37fffffffffffffff")},
{-0x80000001, h("d3ffffffff7fffffff")},
{-0x7fffffffffffffff, h("d38000000000000001")},
{-0x8000000000000000, h("d38000000000000000")},
}
for _, tt := range testv {
test1(t, &tEncInt64{}, tt.i, tt.encoded)
}
}
type tEncBinHead struct{}
func (_ *tEncBinHead) Size(xl interface{}) int { return BinHeadSize(xl.(int)) }
func (_ *tEncBinHead) Put(data []byte, xl interface{}) int { return PutBinHead(data, xl.(int)) }
func TestBin(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("c400")}, // bin8
{1, h("c401")},
{0xff, h("c4ff")},
{0x100, h("c50100")}, // bin16
{0xffff, h("c5ffff")},
{0x10000, h("c600010000")}, // bin32
{0xffffffff, h("c6ffffffff")},
}
for _, tt := range testv {
test1(t, &tEncBinHead{}, tt.l, tt.encoded)
}
}
type tEncArrayHead struct{}
func (_ *tEncArrayHead) Size(xl interface{}) int { return ArrayHeadSize(xl.(int)) }
func (_ *tEncArrayHead) Put(data []byte, xl interface{}) int { return PutArrayHead(data, xl.(int)) }
func TestArray(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("90")}, // fixarray
{1, h("91")},
{14, h("9e")},
{15, h("9f")},
{0x10, h("dc0010")}, // array16
{0x11, h("dc0011")},
{0x100, h("dc0100")},
{0xffff, h("dcffff")},
{0x10000, h("dd00010000")}, // array32
{0xffffffff, h("ddffffffff")},
}
for _, tt := range testv {
test1(t, &tEncArrayHead{}, tt.l, tt.encoded)
}
}
type tEncMapHead struct{}
func (_ *tEncMapHead) Size(xl interface{}) int { return MapHeadSize(xl.(int)) }
func (_ *tEncMapHead) Put(data []byte, xl interface{}) int { return PutMapHead(data, xl.(int)) }
func TestMap(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("80")}, // fixmap
{1, h("81")},
{14, h("8e")},
{15, h("8f")},
{0x10, h("de0010")}, // map16
{0x11, h("de0011")},
{0x100, h("de0100")},
{0xffff, h("deffff")},
{0x10000, h("df00010000")}, // map32
{0xffffffff, h("dfffffffff")},
}
for _, tt := range testv {
test1(t, &tEncMapHead{}, tt.l, tt.encoded)
}
}
...@@ -102,9 +102,12 @@ import ( ...@@ -102,9 +102,12 @@ import (
"lab.nexedi.com/kirr/neo/go/internal/packed" "lab.nexedi.com/kirr/neo/go/internal/packed"
"lab.nexedi.com/kirr/neo/go/internal/xio" "lab.nexedi.com/kirr/neo/go/internal/xio"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
"lab.nexedi.com/kirr/neo/go/neo/proto" "lab.nexedi.com/kirr/neo/go/neo/proto"
"github.com/philhofer/fwd"
"github.com/someonegg/gocontainer/rbuf" "github.com/someonegg/gocontainer/rbuf"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/go123/xbytes" "lab.nexedi.com/kirr/go123/xbytes"
) )
...@@ -124,7 +127,8 @@ import ( ...@@ -124,7 +127,8 @@ import (
// //
// It is safe to use NodeLink from multiple goroutines simultaneously. // It is safe to use NodeLink from multiple goroutines simultaneously.
type NodeLink struct { type NodeLink struct {
peerLink net.Conn // raw conn to peer peerLink net.Conn // raw conn to peer
enc proto.Encoding // protocol encoding in use ('N' or 'M')
connMu sync.Mutex connMu sync.Mutex
connTab map[uint32]*Conn // connId -> Conn associated with connId connTab map[uint32]*Conn // connId -> Conn associated with connId
...@@ -153,7 +157,8 @@ type NodeLink struct { ...@@ -153,7 +157,8 @@ type NodeLink struct {
axclosed atomic32 // whether CloseAccept was called axclosed atomic32 // whether CloseAccept was called
closed atomic32 // whether Close was called closed atomic32 // whether Close was called
rxbuf rbuf.RingBuf // buffer for reading from peerLink rxbufN rbuf.RingBuf // buffer for reading from peerLink (N encoding)
rxbufM *msgp.Reader // ----//---- (M encoding)
// scheduling optimization: whenever serveRecv sends to Conn.rxq // scheduling optimization: whenever serveRecv sends to Conn.rxq
// receiving side must ack here to receive G handoff. // receiving side must ack here to receive G handoff.
...@@ -250,6 +255,8 @@ const ( ...@@ -250,6 +255,8 @@ const (
// newNodeLink makes a new NodeLink from already established net.Conn . // newNodeLink makes a new NodeLink from already established net.Conn .
// //
// On the wire messages will be encoded according to enc.
//
// Role specifies how to treat our role on the link - either as client or // Role specifies how to treat our role on the link - either as client or
// server. The difference in between client and server roles is in: // server. The difference in between client and server roles is in:
// //
...@@ -262,7 +269,9 @@ const ( ...@@ -262,7 +269,9 @@ const (
// //
// Though it is possible to wrap just-established raw connection into NodeLink, // Though it is possible to wrap just-established raw connection into NodeLink,
// users should always use Handshake which performs protocol handshaking first. // users should always use Handshake which performs protocol handshaking first.
func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink { //
// rxbuf if != nil indicates what was already read-buffered from conn.
func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole, rxbuf *fwd.Reader) *NodeLink {
var nextConnId uint32 var nextConnId uint32
switch role &^ linkFlagsMask { switch role &^ linkFlagsMask {
case _LinkServer: case _LinkServer:
...@@ -275,6 +284,7 @@ func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink { ...@@ -275,6 +284,7 @@ func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink {
nl := &NodeLink{ nl := &NodeLink{
peerLink: conn, peerLink: conn,
enc: enc,
connTab: map[uint32]*Conn{}, connTab: map[uint32]*Conn{},
nextConnId: nextConnId, nextConnId: nextConnId,
acceptq: make(chan *Conn), // XXX +buf ? acceptq: make(chan *Conn), // XXX +buf ?
...@@ -283,6 +293,25 @@ func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink { ...@@ -283,6 +293,25 @@ func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink {
// axdown: make(chan struct{}), // axdown: make(chan struct{}),
down: make(chan struct{}), down: make(chan struct{}),
} }
if rxbuf == nil {
rxbuf = fwd.NewReader(conn)
}
switch enc {
case 'N':
// rxbufN <- rxbufM (what was preread)
b, err := rxbuf.Next(rxbuf.Buffered())
if err != nil {
panic(err) // must not fail
}
nl.rxbufN.Write(b)
case 'M':
nl.rxbufM = &msgp.Reader{R: rxbuf}
default:
panic("bug")
}
if role&linkNoRecvSend == 0 { if role&linkNoRecvSend == 0 {
nl.serveWg.Add(2) nl.serveWg.Add(2)
go nl.serveRecv() go nl.serveRecv()
...@@ -1038,12 +1067,14 @@ func (c *Conn) sendPkt(pkt *pktBuf) error { ...@@ -1038,12 +1067,14 @@ func (c *Conn) sendPkt(pkt *pktBuf) error {
func (c *Conn) sendPkt2(pkt *pktBuf) error { func (c *Conn) sendPkt2(pkt *pktBuf) error {
// connId must be set to one associated with this connection // connId must be set to one associated with this connection
if pkt.Header().ConnId != packed.Hton32(c.connId) { connID, _, _, err := pktDecodeHead(c.link.enc, pkt)
if err != nil {
panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err))
}
if connID != c.connId {
panic("Conn.sendPkt: connId wrong") panic("Conn.sendPkt: connId wrong")
} }
var err error
select { select {
case <-c.txdown: case <-c.txdown:
return c.errSendShutdown() return c.errSendShutdown()
...@@ -1173,8 +1204,24 @@ var ErrPktTooBig = errors.New("packet too big") ...@@ -1173,8 +1204,24 @@ var ErrPktTooBig = errors.New("packet too big")
// rx error, if any, is returned as is and is analyzed in serveRecv // rx error, if any, is returned as is and is analyzed in serveRecv
// //
// XXX dup in ZEO. // XXX dup in ZEO.
func (nl *NodeLink) recvPkt() (*pktBuf, error) { func (nl *NodeLink) recvPkt() (pkt *pktBuf, err error) {
// FIXME if rxbuf is non-empty - first look there for header and then if switch nl.enc {
case 'N': pkt, err = nl.recvPktN()
case 'M': pkt, err = nl.recvPktM()
default: panic("bug")
}
if dumpio {
// XXX -> log
fmt.Printf("%v < %v: %v\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pkt)
}
return pkt, err
}
func (nl *NodeLink) recvPktN() (*pktBuf, error) {
// FIXME if rxbufN is non-empty - first look there for header and then if
// we know size -> allocate pkt with that size. // we know size -> allocate pkt with that size.
pkt := pktAlloc(4096) pkt := pktAlloc(4096)
// len=4K but cap can be more since pkt is from pool - use all space to buffer reads // len=4K but cap can be more since pkt is from pool - use all space to buffer reads
...@@ -1184,8 +1231,8 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) { ...@@ -1184,8 +1231,8 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n := 0 // number of pkt bytes obtained so far n := 0 // number of pkt bytes obtained so far
// next packet could be already prefetched in part by previous read // next packet could be already prefetched in part by previous read
if nl.rxbuf.Len() > 0 { if nl.rxbufN.Len() > 0 {
δn, _ := nl.rxbuf.Read(data[:proto.PktHeaderLen]) δn, _ := nl.rxbufN.Read(data[:proto.PktHeaderLen])
n += δn n += δn
} }
...@@ -1198,7 +1245,7 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) { ...@@ -1198,7 +1245,7 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n += δn n += δn
} }
pkth := pkt.Header() pkth := pkt.HeaderN()
msgLen := packed.Ntoh32(pkth.MsgLen) msgLen := packed.Ntoh32(pkth.MsgLen)
if msgLen > proto.PktMaxSize - proto.PktHeaderLen { if msgLen > proto.PktMaxSize - proto.PktHeaderLen {
...@@ -1210,9 +1257,9 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) { ...@@ -1210,9 +1257,9 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
data = xbytes.Resize(data, pktLen) data = xbytes.Resize(data, pktLen)
data = data[:cap(data)] data = data[:cap(data)]
// we might have more data already prefetched in rxbuf // we might have more data already prefetched in rxbufN
if nl.rxbuf.Len() > 0 { if nl.rxbufN.Len() > 0 {
δn, _ := nl.rxbuf.Read(data[n:pktLen]) δn, _ := nl.rxbufN.Read(data[n:pktLen])
n += δn n += δn
} }
...@@ -1225,20 +1272,26 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) { ...@@ -1225,20 +1272,26 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n += δn n += δn
} }
// put overread data into rxbuf for next reader // put overread data into rxbufN for next reader
if n > pktLen { if n > pktLen {
nl.rxbuf.Write(data[pktLen:n]) nl.rxbufN.Write(data[pktLen:n])
} }
// fixup data/pkt // fixup data/pkt
data = data[:n] data = data[:n]
pkt.data = data pkt.data = data
if dumpio { return pkt, nil
// XXX -> log }
fmt.Printf("%v < %v: %v\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pkt)
}
func (nl *NodeLink) recvPktM() (*pktBuf, error) {
pkt := pktAlloc(4096)
mraw := msgp.Raw(pkt.data)
err := mraw.DecodeMsg(nl.rxbufM) // XXX limit size of one packet to e.g. 0x4000000 (= UNPACK_BUFFER_SIZE in NEO/py speak)
if err != nil {
return nil, err
}
pkt.data = []byte(mraw)
return pkt, nil return pkt, nil
} }
...@@ -1313,21 +1366,104 @@ func (c *Conn) err(op string, e error) error { ...@@ -1313,21 +1366,104 @@ func (c *Conn) err(op string, e error) error {
//trace:event traceMsgSendPre(l *NodeLink, connId uint32, msg proto.Msg) //trace:event traceMsgSendPre(l *NodeLink, connId uint32, msg proto.Msg)
// XXX do we also need traceConnSend? // XXX do we also need traceConnSend?
// msgPack allocates pktBuf and encodes msg into it.
func msgPack(connId uint32, msg proto.Msg) *pktBuf { // XXX think again; XXX move to proto?
l := msg.NEOMsgEncodedLen() const (
encN = proto.Encoding('N')
encM = proto.Encoding('M')
)
// pktEncode allocates pktBuf and encodes msg into it.
//func (e Encoding) pktEncode(connId uint32, msg proto.Msg) *pktBuf {
// XXX move to proto ? -> YES: Encoding.PktEncode + .PktDecode
func pktEncode(e proto.Encoding, connId uint32, msg proto.Msg) *pktBuf {
switch e {
case 'N': return pktEncodeN(connId, msg)
case 'M': return pktEncodeM(connId, msg)
default: panic("bug")
}
}
// pktDecodeHead decodes header of a packet.
func pktDecodeHead(e proto.Encoding, pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
switch e {
case 'N': return pktDecodeHeadN(pkt)
case 'M': return pktDecodeHeadM(pkt)
default: panic("bug")
}
}
func pktEncodeN(connId uint32, msg proto.Msg) *pktBuf {
l := encN.NEOMsgEncodedLen(msg)
buf := pktAlloc(proto.PktHeaderLen + l) buf := pktAlloc(proto.PktHeaderLen + l)
h := buf.Header() h := buf.HeaderN()
h.ConnId = packed.Hton32(connId) h.ConnId = packed.Hton32(connId)
h.MsgCode = packed.Hton16(msg.NEOMsgCode()) h.MsgCode = packed.Hton16(proto.MsgCode(msg))
h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again
msg.NEOMsgEncode(buf.Payload()) encN.NEOMsgEncode(msg, buf.PayloadN())
return buf return buf
} }
// TODO msgUnpack func pktDecodeHeadN(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
pkth := pkt.HeaderN()
connID = packed.Ntoh32(pkth.ConnId)
msgCode = packed.Ntoh16(pkth.MsgCode)
msgLen := packed.Ntoh32(pkth.MsgLen)
payload = pkt.PayloadN()
if len(payload) != int(msgLen) {
return 0, 0, nil, fmt.Errorf("len(payload) != msgLen")
}
return
}
func pktEncodeM(connId uint32, msg proto.Msg) *pktBuf {
// [3](connID, msgCode, argv)
msgCode := proto.MsgCode(msg)
hroom := msgpack.ArrayHeadSize(3) +
msgpack.Uint32Size(connId) +
msgpack.Uint16Size(msgCode)
l := encM.NEOMsgEncodedLen(msg)
buf := pktAlloc(hroom + l)
b := buf.data
i := 0
i += msgpack.PutArrayHead (b[i:], 3)
i += msgpack.PutUint32 (b[i:], connId)
i += msgpack.PutUint16 (b[i:], msgCode)
if i != hroom {
panic("bug")
}
encM.NEOMsgEncode(msg, b[hroom:])
return buf
}
func pktDecodeHeadM(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
// XXX errctx = "decode header"
b := pkt.data
sz, b, err := msgp.ReadArrayHeaderBytes(b)
if err != nil {
return 0, 0, nil, err
}
if sz != 3 {
return 0, 0, nil, fmt.Errorf("expected [3]tuple, got [%d]tuple", sz)
}
connID, b, err = msgp.ReadUint32Bytes(b)
if err != nil {
return 0, 0, nil, fmt.Errorf("connID: %s", err)
}
msgCode, b, err = msgp.ReadUint16Bytes(b)
if err != nil {
return 0, 0, nil, fmt.Errorf("msgCode: %s", err)
}
return connID, msgCode, b, nil
}
// Recv receives message from the connection. // Recv receives message from the connection.
func (c *Conn) Recv() (proto.Msg, error) { func (c *Conn) Recv() (proto.Msg, error) {
...@@ -1338,8 +1474,11 @@ func (c *Conn) Recv() (proto.Msg, error) { ...@@ -1338,8 +1474,11 @@ func (c *Conn) Recv() (proto.Msg, error) {
defer pkt.Free() defer pkt.Free()
// decode packet // decode packet
pkth := pkt.Header() _, msgCode, payload, err := pktDecodeHead(c.link.enc, pkt)
msgCode := packed.Ntoh16(pkth.MsgCode) if err != nil {
return nil, err
}
msgType := proto.MsgType(msgCode) msgType := proto.MsgType(msgCode)
if msgType == nil { if msgType == nil {
err := fmt.Errorf("invalid msgCode (%d)", msgCode) err := fmt.Errorf("invalid msgCode (%d)", msgCode)
...@@ -1352,7 +1491,7 @@ func (c *Conn) Recv() (proto.Msg, error) { ...@@ -1352,7 +1491,7 @@ func (c *Conn) Recv() (proto.Msg, error) {
// msg := reflect.NewAt(msgType, bufAlloc(msgType.Size()) // msg := reflect.NewAt(msgType, bufAlloc(msgType.Size())
_, err = msg.NEOMsgDecode(pkt.Payload()) _, err = c.link.enc.NEOMsgDecode(msg, payload)
if err != nil { if err != nil {
return nil, c.err("decode", err) // XXX "decode:" is already in ErrDecodeOverflow return nil, c.err("decode", err) // XXX "decode:" is already in ErrDecodeOverflow
} }
...@@ -1369,7 +1508,8 @@ func (c *Conn) Recv() (proto.Msg, error) { ...@@ -1369,7 +1508,8 @@ func (c *Conn) Recv() (proto.Msg, error) {
func (link *NodeLink) sendMsg(connId uint32, msg proto.Msg) error { func (link *NodeLink) sendMsg(connId uint32, msg proto.Msg) error {
traceMsgSendPre(link, connId, msg) traceMsgSendPre(link, connId, msg)
buf := msgPack(connId, msg) //buf := msgPack(connId, msg)
buf := pktEncode(link.enc, connId, msg)
return link.sendPkt(buf) // XXX more context in err? (msg type) return link.sendPkt(buf) // XXX more context in err? (msg type)
// FIXME ^^^ shutdown whole link on error // FIXME ^^^ shutdown whole link on error
} }
...@@ -1378,7 +1518,8 @@ func (link *NodeLink) sendMsg(connId uint32, msg proto.Msg) error { ...@@ -1378,7 +1518,8 @@ func (link *NodeLink) sendMsg(connId uint32, msg proto.Msg) error {
func (c *Conn) Send(msg proto.Msg) error { func (c *Conn) Send(msg proto.Msg) error {
traceMsgSendPre(c.link, c.connId, msg) traceMsgSendPre(c.link, c.connId, msg)
buf := msgPack(c.connId, msg) //buf := msgPack(c.connId, msg)
buf := pktEncode(c.link.enc, c.connId, msg)
return c.sendPkt(buf) // XXX more context in err? (msg type) return c.sendPkt(buf) // XXX more context in err? (msg type)
} }
...@@ -1401,12 +1542,13 @@ func (c *Conn) Expect(msgv ...proto.Msg) (which int, err error) { ...@@ -1401,12 +1542,13 @@ func (c *Conn) Expect(msgv ...proto.Msg) (which int, err error) {
} }
defer pkt.Free() defer pkt.Free()
// XXX encN-specific
pkth := pkt.Header() pkth := pkt.Header()
msgCode := packed.Ntoh16(pkth.MsgCode) msgCode := packed.Ntoh16(pkth.MsgCode)
for i, msg := range msgv { for i, msg := range msgv {
if msg.NEOMsgCode() == msgCode { if proto.MsgCode(msg) == msgCode {
_, err := msg.NEOMsgDecode(pkt.Payload()) _, err := c.link.enc.NEOMsgDecode(msg, pkt.Payload())
if err != nil { if err != nil {
return -1, c.err("decode", err) return -1, c.err("decode", err)
} }
......
...@@ -22,6 +22,7 @@ package neonet ...@@ -22,6 +22,7 @@ package neonet
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"io" "io"
"net" "net"
"reflect" "reflect"
...@@ -38,10 +39,30 @@ import ( ...@@ -38,10 +39,30 @@ import (
"lab.nexedi.com/kirr/neo/go/neo/proto" "lab.nexedi.com/kirr/neo/go/neo/proto"
"lab.nexedi.com/kirr/neo/go/zodb" "lab.nexedi.com/kirr/neo/go/zodb"
"github.com/tinylib/msgp/msgp"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// T is neonet testing environment.
type T struct {
*testing.T
enc proto.Encoding // encoding to use for messages exchange
}
// Verify tests f for all possible environments.
func Verify(t *testing.T, f func(*T)) {
// for each encoding
for _, enc := range []proto.Encoding{'N', 'M'} {
t.Run(fmt.Sprintf("enc=%c", enc), func(t *testing.T) {
f(&T{t, enc})
})
}
}
func xclose(c io.Closer) { func xclose(c io.Closer) {
err := c.Close() err := c.Close()
exc.Raiseif(err) exc.Raiseif(err)
...@@ -102,48 +123,70 @@ func xconnError(err error) error { ...@@ -102,48 +123,70 @@ func xconnError(err error) error {
} }
// Prepare pktBuf with content. // Prepare pktBuf with content.
func _mkpkt(connid uint32, msgcode uint16, payload []byte) *pktBuf { func _mkpkt(enc proto.Encoding, connid uint32, msgcode uint16, payload []byte) *pktBuf {
pkt := &pktBuf{make([]byte, proto.PktHeaderLen+len(payload))} switch enc {
h := pkt.Header() case 'N':
h.ConnId = packed.Hton32(connid) pkt := &pktBuf{make([]byte, proto.PktHeaderLen+len(payload))}
h.MsgCode = packed.Hton16(msgcode) h := pkt.HeaderN()
h.MsgLen = packed.Hton32(uint32(len(payload))) h.ConnId = packed.Hton32(connid)
copy(pkt.Payload(), payload) h.MsgCode = packed.Hton16(msgcode)
return pkt h.MsgLen = packed.Hton32(uint32(len(payload)))
copy(pkt.PayloadN(), payload)
return pkt
case 'M':
var b []byte
b = msgp.AppendArrayHeader (b, 3)
b = msgp.AppendUint32 (b, connid)
b = msgp.AppendUint16 (b, msgcode)
// NOTE payload is appended wrapped into bin object. We need
// this not to break framing, because in M-encoding whole
// packet must be a valid msgpack object.
b = msgp.AppendBytes (b, payload)
return &pktBuf{b}
default:
panic("bug")
}
} }
func (c *Conn) mkpkt(msgcode uint16, payload []byte) *pktBuf { func (c *Conn) mkpkt(msgcode uint16, payload []byte) *pktBuf {
// in Conn exchange connid is automatically set by Conn.sendPkt // in Conn exchange connid is automatically set by Conn.sendPkt
return _mkpkt(c.connId, msgcode, payload) return _mkpkt(c.link.enc, c.connId, msgcode, payload)
} }
// Verify pktBuf is as expected. // Verify pktBuf is as expected.
func xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byte) { func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byte) {
errv := xerr.Errorv{} errv := xerr.Errorv{}
h := pkt.Header()
pktConnID, pktMsgCode, pktPayload, err := pktDecodeHead(t.enc, pkt)
exc.Raiseif(err)
// TODO include caller location // TODO include caller location
if packed.Ntoh32(h.ConnId) != connid { if pktConnID != connid {
errv.Appendf("header: unexpected connid %v (want %v)", packed.Ntoh32(h.ConnId), connid) errv.Appendf("header: unexpected connid %v (want %v)", pktConnID, connid)
} }
if packed.Ntoh16(h.MsgCode) != msgcode { if pktMsgCode != msgcode {
errv.Appendf("header: unexpected msgcode %v (want %v)", packed.Ntoh16(h.MsgCode), msgcode) errv.Appendf("header: unexpected msgcode %v (want %v)", pktMsgCode, msgcode)
} }
if packed.Ntoh32(h.MsgLen) != uint32(len(payload)) {
errv.Appendf("header: unexpected msglen %v (want %v)", packed.Ntoh32(h.MsgLen), len(payload)) // M-encoding -> wrap payloadOK into bin (see _mkpkt ^^^ for why)
if t.enc == 'M' {
payload = msgp.AppendBytes(nil, payload)
} }
if !bytes.Equal(pkt.Payload(), payload) { if !bytes.Equal(pktPayload, payload) {
errv.Appendf("payload differ:\n%s", errv.Appendf("payload differ:\n%s",
pretty.Compare(string(payload), string(pkt.Payload()))) pretty.Compare(string(payload), string(pktPayload)))
} }
exc.Raiseif(errv.Err()) exc.Raiseif(errv.Err())
} }
// Verify pktBuf to match expected message. // Verify pktBuf to match expected message.
func xverifyPktMsg(pkt *pktBuf, connid uint32, msg proto.Msg) { func (t *T) xverifyPktMsg(pkt *pktBuf, connid uint32, msg proto.Msg) {
data := make([]byte, msg.NEOMsgEncodedLen()) data := make([]byte, t.enc.NEOMsgEncodedLen(msg))
msg.NEOMsgEncode(data) t.enc.NEOMsgEncode(msg, data)
xverifyPkt(pkt, connid, msg.NEOMsgCode(), data) t.xverifyPkt(pkt, connid, proto.MsgCode(msg), data)
} }
// delay a bit. // delay a bit.
...@@ -160,24 +203,27 @@ func tdelay() { ...@@ -160,24 +203,27 @@ func tdelay() {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
// create NodeLinks connected via net.Pipe // create NodeLinks connected via net.Pipe; messages are encoded via t.enc.
func _nodeLinkPipe(flags1, flags2 _LinkRole) (nl1, nl2 *NodeLink) { func (t *T) _nodeLinkPipe(flags1, flags2 _LinkRole) (nl1, nl2 *NodeLink) {
node1, node2 := net.Pipe() node1, node2 := net.Pipe()
nl1 = newNodeLink(node1, _LinkClient|flags1) nl1 = newNodeLink(node1, t.enc, _LinkClient|flags1, nil)
nl2 = newNodeLink(node2, _LinkServer|flags2) nl2 = newNodeLink(node2, t.enc, _LinkServer|flags2, nil)
return nl1, nl2 return nl1, nl2
} }
func nodeLinkPipe() (nl1, nl2 *NodeLink) { func (t *T) nodeLinkPipe() (nl1, nl2 *NodeLink) {
return _nodeLinkPipe(0, 0) return t._nodeLinkPipe(0, 0)
} }
func TestNodeLink(t *testing.T) { func TestNodeLink(t *testing.T) {
Verify(t, _TestNodeLink)
}
func _TestNodeLink(t *T) {
// TODO catch exception -> add proper location from it -> t.Fatal (see git-backup) // TODO catch exception -> add proper location from it -> t.Fatal (see git-backup)
bg := context.Background() bg := context.Background()
// Close vs recvPkt // Close vs recvPkt
nl1, nl2 := _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend) nl1, nl2 := t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg := xsync.NewWorkGroup(bg) wg := xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
tdelay() tdelay()
...@@ -191,7 +237,7 @@ func TestNodeLink(t *testing.T) { ...@@ -191,7 +237,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl2) xclose(nl2)
// Close vs sendPkt // Close vs sendPkt
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend) nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
tdelay() tdelay()
...@@ -206,7 +252,7 @@ func TestNodeLink(t *testing.T) { ...@@ -206,7 +252,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl2) xclose(nl2)
// {Close,CloseAccept} vs Accept // {Close,CloseAccept} vs Accept
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend) nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
tdelay() tdelay()
...@@ -234,7 +280,7 @@ func TestNodeLink(t *testing.T) { ...@@ -234,7 +280,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl1) xclose(nl1)
// Close vs recvPkt on another side // Close vs recvPkt on another side
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend) nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
tdelay() tdelay()
...@@ -248,7 +294,7 @@ func TestNodeLink(t *testing.T) { ...@@ -248,7 +294,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl1) xclose(nl1)
// Close vs sendPkt on another side // Close vs sendPkt on another side
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend) nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
tdelay() tdelay()
...@@ -263,23 +309,23 @@ func TestNodeLink(t *testing.T) { ...@@ -263,23 +309,23 @@ func TestNodeLink(t *testing.T) {
xclose(nl1) xclose(nl1)
// raw exchange // raw exchange
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend) nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
okch := make(chan int, 2) okch := make(chan int, 2)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
// send ping; wait for pong // send ping; wait for pong
pkt := _mkpkt(1, 2, []byte("ping")) pkt := _mkpkt(t.enc, 1, 2, []byte("ping"))
xsendPkt(nl1, pkt) xsendPkt(nl1, pkt)
pkt = xrecvPkt(nl1) pkt = xrecvPkt(nl1)
xverifyPkt(pkt, 3, 4, []byte("pong")) t.xverifyPkt(pkt, 3, 4, []byte("pong"))
okch <- 1 okch <- 1
}) })
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
// wait for ping; send pong // wait for ping; send pong
pkt = xrecvPkt(nl2) pkt = xrecvPkt(nl2)
xverifyPkt(pkt, 1, 2, []byte("ping")) t.xverifyPkt(pkt, 1, 2, []byte("ping"))
pkt = _mkpkt(3, 4, []byte("pong")) pkt = _mkpkt(t.enc, 3, 4, []byte("pong"))
xsendPkt(nl2, pkt) xsendPkt(nl2, pkt)
okch <- 2 okch <- 2
}) })
...@@ -309,7 +355,7 @@ func TestNodeLink(t *testing.T) { ...@@ -309,7 +355,7 @@ func TestNodeLink(t *testing.T) {
// ---- connections on top of nodelink ---- // ---- connections on top of nodelink ----
// Close vs recvPkt // Close vs recvPkt
nl1, nl2 = _nodeLinkPipe(0, linkNoRecvSend) nl1, nl2 = t._nodeLinkPipe(0, linkNoRecvSend)
c = xnewconn(nl1) c = xnewconn(nl1)
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
...@@ -325,7 +371,7 @@ func TestNodeLink(t *testing.T) { ...@@ -325,7 +371,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl2) xclose(nl2)
// Close vs sendPkt // Close vs sendPkt
nl1, nl2 = _nodeLinkPipe(0, linkNoRecvSend) nl1, nl2 = t._nodeLinkPipe(0, linkNoRecvSend)
c = xnewconn(nl1) c = xnewconn(nl1)
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
...@@ -364,7 +410,7 @@ func TestNodeLink(t *testing.T) { ...@@ -364,7 +410,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl2) xclose(nl2)
// NodeLink.Close vs Conn.sendPkt/recvPkt and Accept on another side // NodeLink.Close vs Conn.sendPkt/recvPkt and Accept on another side
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, 0) nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, 0)
c21 := xnewconn(nl2) c21 := xnewconn(nl2)
c22 := xnewconn(nl2) c22 := xnewconn(nl2)
c23 := xnewconn(nl2) c23 := xnewconn(nl2)
...@@ -482,7 +528,7 @@ func TestNodeLink(t *testing.T) { ...@@ -482,7 +528,7 @@ func TestNodeLink(t *testing.T) {
connKeepClosed = 10 * time.Millisecond connKeepClosed = 10 * time.Millisecond
// Conn accept + exchange // Conn accept + exchange
nl1, nl2 = nodeLinkPipe() nl1, nl2 = t.nodeLinkPipe()
nl1.CloseAccept() nl1.CloseAccept()
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
closed := make(chan int) closed := make(chan int)
...@@ -490,14 +536,14 @@ func TestNodeLink(t *testing.T) { ...@@ -490,14 +536,14 @@ func TestNodeLink(t *testing.T) {
c := xaccept(nl2) c := xaccept(nl2)
pkt := xrecvPkt(c) pkt := xrecvPkt(c)
xverifyPkt(pkt, c.connId, 33, []byte("ping")) t.xverifyPkt(pkt, c.connId, 33, []byte("ping"))
// change pkt a bit and send it back // change pkt a bit and send it back
xsendPkt(c, c.mkpkt(34, []byte("pong"))) xsendPkt(c, c.mkpkt(34, []byte("pong")))
// one more time // one more time
pkt = xrecvPkt(c) pkt = xrecvPkt(c)
xverifyPkt(pkt, c.connId, 35, []byte("ping2")) t.xverifyPkt(pkt, c.connId, 35, []byte("ping2"))
xsendPkt(c, c.mkpkt(36, []byte("pong2"))) xsendPkt(c, c.mkpkt(36, []byte("pong2")))
xclose(c) xclose(c)
...@@ -506,7 +552,7 @@ func TestNodeLink(t *testing.T) { ...@@ -506,7 +552,7 @@ func TestNodeLink(t *testing.T) {
// once again as ^^^ but finish only with CloseRecv // once again as ^^^ but finish only with CloseRecv
c2 := xaccept(nl2) c2 := xaccept(nl2)
pkt = xrecvPkt(c2) pkt = xrecvPkt(c2)
xverifyPkt(pkt, c2.connId, 41, []byte("ping5")) t.xverifyPkt(pkt, c2.connId, 41, []byte("ping5"))
xsendPkt(c2, c2.mkpkt(42, []byte("pong5"))) xsendPkt(c2, c2.mkpkt(42, []byte("pong5")))
c2.CloseRecv() c2.CloseRecv()
...@@ -516,10 +562,10 @@ func TestNodeLink(t *testing.T) { ...@@ -516,10 +562,10 @@ func TestNodeLink(t *testing.T) {
c = xnewconn(nl2) // XXX should get error here? c = xnewconn(nl2) // XXX should get error here?
xsendPkt(c, c.mkpkt(38, []byte("pong3"))) xsendPkt(c, c.mkpkt(38, []byte("pong3")))
pkt = xrecvPkt(c) pkt = xrecvPkt(c)
xverifyPktMsg(pkt, c.connId, errConnRefused) t.xverifyPktMsg(pkt, c.connId, errConnRefused)
xsendPkt(c, c.mkpkt(40, []byte("pong4"))) // once again xsendPkt(c, c.mkpkt(40, []byte("pong4"))) // once again
pkt = xrecvPkt(c) pkt = xrecvPkt(c)
xverifyPktMsg(pkt, c.connId, errConnRefused) t.xverifyPktMsg(pkt, c.connId, errConnRefused)
xclose(c) xclose(c)
...@@ -528,30 +574,30 @@ func TestNodeLink(t *testing.T) { ...@@ -528,30 +574,30 @@ func TestNodeLink(t *testing.T) {
c1 := xnewconn(nl1) c1 := xnewconn(nl1)
xsendPkt(c1, c1.mkpkt(33, []byte("ping"))) xsendPkt(c1, c1.mkpkt(33, []byte("ping")))
pkt = xrecvPkt(c1) pkt = xrecvPkt(c1)
xverifyPkt(pkt, c1.connId, 34, []byte("pong")) t.xverifyPkt(pkt, c1.connId, 34, []byte("pong"))
xsendPkt(c1, c1.mkpkt(35, []byte("ping2"))) xsendPkt(c1, c1.mkpkt(35, []byte("ping2")))
pkt = xrecvPkt(c1) pkt = xrecvPkt(c1)
xverifyPkt(pkt, c1.connId, 36, []byte("pong2")) t.xverifyPkt(pkt, c1.connId, 36, []byte("pong2"))
// "connection closed" after peer closed its end // "connection closed" after peer closed its end
<-closed <-closed
xsendPkt(c1, c1.mkpkt(37, []byte("ping3"))) xsendPkt(c1, c1.mkpkt(37, []byte("ping3")))
pkt = xrecvPkt(c1) pkt = xrecvPkt(c1)
xverifyPktMsg(pkt, c1.connId, errConnClosed) t.xverifyPktMsg(pkt, c1.connId, errConnClosed)
xsendPkt(c1, c1.mkpkt(39, []byte("ping4"))) // once again xsendPkt(c1, c1.mkpkt(39, []byte("ping4"))) // once again
pkt = xrecvPkt(c1) pkt = xrecvPkt(c1)
xverifyPktMsg(pkt, c1.connId, errConnClosed) t.xverifyPktMsg(pkt, c1.connId, errConnClosed)
// XXX also should get EOF on recv // XXX also should get EOF on recv
// one more time but now peer does only .CloseRecv() // one more time but now peer does only .CloseRecv()
c2 := xnewconn(nl1) c2 := xnewconn(nl1)
xsendPkt(c2, c2.mkpkt(41, []byte("ping5"))) xsendPkt(c2, c2.mkpkt(41, []byte("ping5")))
pkt = xrecvPkt(c2) pkt = xrecvPkt(c2)
xverifyPkt(pkt, c2.connId, 42, []byte("pong5")) t.xverifyPkt(pkt, c2.connId, 42, []byte("pong5"))
<-closed <-closed
xsendPkt(c2, c2.mkpkt(41, []byte("ping6"))) xsendPkt(c2, c2.mkpkt(41, []byte("ping6")))
pkt = xrecvPkt(c2) pkt = xrecvPkt(c2)
xverifyPktMsg(pkt, c2.connId, errConnClosed) t.xverifyPktMsg(pkt, c2.connId, errConnClosed)
xwait(wg) xwait(wg)
...@@ -577,7 +623,7 @@ func TestNodeLink(t *testing.T) { ...@@ -577,7 +623,7 @@ func TestNodeLink(t *testing.T) {
connKeepClosed = saveKeepClosed connKeepClosed = saveKeepClosed
// test 2 channels with replies coming in reversed time order // test 2 channels with replies coming in reversed time order
nl1, nl2 = nodeLinkPipe() nl1, nl2 = t.nodeLinkPipe()
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
replyOrder := map[uint16]struct { // "order" in which to process requests replyOrder := map[uint16]struct { // "order" in which to process requests
start chan struct{} // processing starts when start chan is ready start chan struct{} // processing starts when start chan is ready
...@@ -594,6 +640,7 @@ func TestNodeLink(t *testing.T) { ...@@ -594,6 +640,7 @@ func TestNodeLink(t *testing.T) {
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
pkt := xrecvPkt(c) pkt := xrecvPkt(c)
// XXX encN-specific
n := packed.Ntoh16(pkt.Header().MsgCode) n := packed.Ntoh16(pkt.Header().MsgCode)
x := replyOrder[n] x := replyOrder[n]
...@@ -619,7 +666,7 @@ func TestNodeLink(t *testing.T) { ...@@ -619,7 +666,7 @@ func TestNodeLink(t *testing.T) {
// replies must be coming in reverse order // replies must be coming in reverse order
xechoWait := func(c *Conn, msgCode uint16) { xechoWait := func(c *Conn, msgCode uint16) {
pkt := xrecvPkt(c) pkt := xrecvPkt(c)
xverifyPkt(pkt, c.connId, msgCode, []byte("")) t.xverifyPkt(pkt, c.connId, msgCode, []byte(""))
} }
xechoWait(c2, 2) xechoWait(c2, 2)
xechoWait(c1, 1) xechoWait(c1, 1)
...@@ -663,10 +710,13 @@ func xverifyMsg(msg1, msg2 proto.Msg) { ...@@ -663,10 +710,13 @@ func xverifyMsg(msg1, msg2 proto.Msg) {
} }
func TestRecv1Mode(t *testing.T) { func TestRecv1Mode(t *testing.T) {
Verify(t, _TestRecv1Mode)
}
func _TestRecv1Mode(t *T) {
bg := context.Background() bg := context.Background()
// Send1 // Send1
nl1, nl2 := nodeLinkPipe() nl1, nl2 := t.nodeLinkPipe()
wg := xsync.NewWorkGroup(bg) wg := xsync.NewWorkGroup(bg)
sync := make(chan int) sync := make(chan int)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
...@@ -730,7 +780,10 @@ func TestRecv1Mode(t *testing.T) { ...@@ -730,7 +780,10 @@ func TestRecv1Mode(t *testing.T) {
// //
// bug triggers under -race. // bug triggers under -race.
func TestLightCloseVsLinkShutdown(t *testing.T) { func TestLightCloseVsLinkShutdown(t *testing.T) {
nl1, nl2 := nodeLinkPipe() Verify(t, _TestLightCloseVsLinkShutdown)
}
func _TestLightCloseVsLinkShutdown(t *T) {
nl1, nl2 := t.nodeLinkPipe()
wg := xsync.NewWorkGroup(context.Background()) wg := xsync.NewWorkGroup(context.Background())
c := xnewconn(nl1) c := xnewconn(nl1)
......
...@@ -21,18 +21,38 @@ package neonet ...@@ -21,18 +21,38 @@ package neonet
// link establishment // link establishment
import ( import (
"bytes"
"context" "context"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"net" "net"
"sync" "os"
"github.com/philhofer/fwd"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/go123/xerr"
"lab.nexedi.com/kirr/go123/xnet" "lab.nexedi.com/kirr/go123/xnet"
"lab.nexedi.com/kirr/neo/go/internal/xcontext"
"lab.nexedi.com/kirr/neo/go/internal/xio" "lab.nexedi.com/kirr/neo/go/internal/xio"
"lab.nexedi.com/kirr/neo/go/neo/proto" "lab.nexedi.com/kirr/neo/go/neo/proto"
) )
// encDefault is default encoding to use.
// XXX we don't need this? (just set encDefault = 'M')
var encDefault = proto.Encoding('N') // XXX = 'M' instead?
func init() {
e := os.Getenv("NEO_ENCODING")
switch e {
case "": // not set
case "N": fallthrough
case "M": encDefault = proto.Encoding(e[0])
default:
fmt.Fprintf(os.Stderr, "E: $NEO_ENCODING=%q - invalid -> abort", e)
os.Exit(1)
}
}
// ---- Handshake ---- // ---- Handshake ----
// XXX _Handshake may be needed to become public in case when we have already // XXX _Handshake may be needed to become public in case when we have already
...@@ -45,91 +65,215 @@ import ( ...@@ -45,91 +65,215 @@ import (
// On success raw connection is returned wrapped into NodeLink. // On success raw connection is returned wrapped into NodeLink.
// On error raw connection is closed. // On error raw connection is closed.
func _Handshake(ctx context.Context, conn net.Conn, role _LinkRole) (nl *NodeLink, err error) { func _Handshake(ctx context.Context, conn net.Conn, role _LinkRole) (nl *NodeLink, err error) {
err = handshake(ctx, conn, proto.Version) enc := encDefault // default encoding
var rxbuf *fwd.Reader
switch role &^ linkFlagsMask {
case _LinkServer:
enc, rxbuf, err = handshakeServer(ctx, conn, proto.Version)
case _LinkClient:
enc, rxbuf, err = handshakeClient(ctx, conn, proto.Version, enc)
default:
panic("bug")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
// handshake ok -> NodeLink // handshake ok -> NodeLink
return newNodeLink(conn, role), nil return newNodeLink(conn, enc, role, rxbuf), nil
} }
// _HandshakeError is returned when there is an error while performing handshake. // _HandshakeError is returned when there is an error while performing handshake.
type _HandshakeError struct { type _HandshakeError struct {
LocalRole _LinkRole
LocalAddr net.Addr LocalAddr net.Addr
RemoteAddr net.Addr RemoteAddr net.Addr
Err error Err error
} }
func (e *_HandshakeError) Error() string { func (e *_HandshakeError) Error() string {
return fmt.Sprintf("%s - %s: handshake: %s", e.LocalAddr, e.RemoteAddr, e.Err.Error()) role := ""
switch e.LocalRole {
case _LinkServer: role = "server"
case _LinkClient: role = "client"
default: panic("bug")
}
return fmt.Sprintf("%s - %s: handshake (%s): %s", e.LocalAddr, e.RemoteAddr, role, e.Err.Error())
} }
func handshake(ctx context.Context, conn net.Conn, version uint32) (err error) { func (e *_HandshakeError) Cause() error { return e.Err }
// XXX simplify -> errgroup func (e *_HandshakeError) Unwrap() error { return e.Err }
errch := make(chan error, 2)
// tx handshake word
txWg := sync.WaitGroup{}
txWg.Add(1)
go func() {
var b [4]byte
binary.BigEndian.PutUint32(b[:], version) // XXX -> hton32 ?
_, err := conn.Write(b[:])
// XXX EOF -> ErrUnexpectedEOF ?
errch <- err
txWg.Done()
}()
// rx handshake word // handshakeClient implements client-side handshake.
go func() { //
var b [4]byte // Client indicates its version and preferred encoding, but accepts any
_, err := io.ReadFull(conn, b[:]) // encoding choosen to use by server.
err = xio.NoEOF(err) // can be returned with n = 0 func handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPrefer proto.Encoding) (enc proto.Encoding, rxbuf *fwd.Reader, err error) {
if err == nil { defer func() {
peerVersion := binary.BigEndian.Uint32(b[:]) // XXX -> ntoh32 ? if err != nil {
if peerVersion != version { err = &_HandshakeError{_LinkClient, conn.LocalAddr(), conn.RemoteAddr(), err}
err = fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVersion, version)
}
} }
errch <- err
}() }()
connClosed := false rxbuf = fwd.NewReader(conn)
defer func() {
// make sure our version is always sent on the wire, if possible,
// so that peer does not see just closed connection when on rx we see version mismatch.
//
// NOTE if cancelled tx goroutine will wake up without delay.
txWg.Wait()
// don't forget to close conn if returning with error + add handshake err context var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
// tx client hello
err := txHello("tx hello", conn, version, encPrefer)
if err != nil { if err != nil {
err = &_HandshakeError{conn.LocalAddr(), conn.RemoteAddr(), err} return err
if !connClosed { }
conn.Close()
} // rx server hello reply
var peerVer uint32
peerEnc, peerVer, err = rxHello("rx hello reply", rxbuf)
if err != nil {
return err
}
// verify version
if peerVer != version {
return fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVer, version)
}
return nil
})
if err != nil {
return 0, nil, err
}
// use peer encoding (server should return the same, but we are ok if
// it asks to switch to different)
return peerEnc, rxbuf, nil
}
// handshakeServer implementss server-side handshake.
//
// Server verifies that its version matches Client and accepts client preferred encoding.
func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (enc proto.Encoding, rxbuf *fwd.Reader, err error) {
defer func() {
if err != nil {
err = &_HandshakeError{_LinkServer, conn.LocalAddr(), conn.RemoteAddr(), err}
} }
}() }()
for i := 0; i < 2; i++ { rxbuf = fwd.NewReader(conn)
select {
case <-ctx.Done(): var peerEnc proto.Encoding
conn.Close() // interrupt IO err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
connClosed = true // rx client hello
return ctx.Err() var peerVer uint32
var err error
case err = <-errch: peerEnc, peerVer, err = rxHello("rx hello", rxbuf)
if err != nil { if err != nil {
return err return err
} }
// tx server reply
//
// do it before version check so that client can also detect "version
// mismatch" instead of just getting "disconnect".
err = txHello("tx hello reply", conn, version, peerEnc)
if err != nil {
return err
}
// verify version
if peerVer != version {
return fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVer, version)
} }
return nil
})
if err != nil {
return 0, nil, err
}
return peerEnc, rxbuf, nil
}
func txHello(errctx string, conn net.Conn, version uint32, enc proto.Encoding) (err error) {
defer xerr.Context(&err, errctx)
var b []byte
switch enc {
case 'N':
// 00 00 00 <v>
b = make([]byte, 4)
if version > 0xff {
panic("encoding N supports versions only in range [0, 0xff]")
}
b[3] = uint8(version)
case 'M':
// (b"NEO", <V>) encoded as msgpack (= 92 c4 03 NEO int(<V>))
b = msgp.AppendArrayHeader(b, 2) // 92
b = msgp.AppendBytes(b, []byte("NEO")) // c4 03 NEO
b = msgp.AppendUint32(b, version) // u?intX version
default:
panic("bug")
}
_, err = conn.Write(b)
if err != nil {
return err
} }
// handshaked ok
return nil return nil
} }
func rxHello(errctx string, rx *fwd.Reader) (enc proto.Encoding, version uint32, err error) {
defer xerr.Context(&err, errctx)
b := make([]byte, 4)
_, err = io.ReadFull(rx, b)
err = xio.NoEOF(err)
if err != nil {
return 0, 0, err
}
var peerEnc proto.Encoding
var peerVer uint32
badMagic := false
switch {
case bytes.Equal(b[:3], []byte{0,0,0}):
peerEnc = encN
peerVer = uint32(b[3])
case bytes.Equal(b, []byte{0x92, 0xc4, 3, 'N'}): // start of "fixarray<2> bin8 'N | EO' ...
b = append(b, []byte{0,0}...)
_, err = io.ReadFull(rx, b[4:])
err = xio.NoEOF(err)
if err != nil {
return 0, 0, err
}
if !bytes.Equal(b[4:], []byte{'E','O'}) {
badMagic = true
break
}
peerEnc = encM
rxM := msgp.Reader{R: rx}
peerVer, err = rxM.ReadUint32()
if err != nil {
return 0, 0, fmt.Errorf("M: recv peer version: %s", err) // XXX + "read magic" ctx
}
default:
badMagic = true
}
if badMagic {
return 0, 0, fmt.Errorf("invalid magic %x", b)
}
return peerEnc, peerVer, nil
}
// ---- Dial & Listen at NodeLink level ---- // ---- Dial & Listen at NodeLink level ----
...@@ -141,6 +285,8 @@ func DialLink(ctx context.Context, net xnet.Networker, addr string) (*NodeLink, ...@@ -141,6 +285,8 @@ func DialLink(ctx context.Context, net xnet.Networker, addr string) (*NodeLink,
return nil, err return nil, err
} }
// TODO if handshake fails with "closed" (= might be unexpected encoding)
// -> try redial and handshaking with different encoding (= autodetect encoding)
return _Handshake(ctx, peerConn, _LinkClient) return _Handshake(ctx, peerConn, _LinkClient)
} }
......
...@@ -21,29 +21,47 @@ package neonet ...@@ -21,29 +21,47 @@ package neonet
import ( import (
"context" "context"
"errors"
"io" "io"
"net" "net"
"testing" "testing"
"lab.nexedi.com/kirr/go123/exc" "lab.nexedi.com/kirr/go123/exc"
"lab.nexedi.com/kirr/go123/xsync" "lab.nexedi.com/kirr/go123/xsync"
"lab.nexedi.com/kirr/neo/go/neo/proto"
) )
func xhandshake(ctx context.Context, c net.Conn, version uint32) { // xhandshakeClient handshakes as client with encPrefer encoding and verifies that server accepts it.
err := handshake(ctx, c, version) func xhandshakeClient(ctx context.Context, c net.Conn, version uint32, encPrefer proto.Encoding) {
enc, _, err := handshakeClient(ctx, c, version, encPrefer)
exc.Raiseif(err) exc.Raiseif(err)
if enc != encPrefer {
exc.Raisef("enc (%c) != encPrefer (%c)", enc, encPrefer)
}
}
// xhandshakeServer handshakes as server and verifies negotiated encoding to be encOK.
func xhandshakeServer(ctx context.Context, c net.Conn, version uint32, encOK proto.Encoding) {
enc, _, err := handshakeServer(ctx, c, version)
exc.Raiseif(err)
if enc != encOK {
exc.Raisef("enc (%c) != encOK (%c)", enc, encOK)
}
} }
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
Verify(t, _TestHandshake)
}
func _TestHandshake(t *T) {
bg := context.Background() bg := context.Background()
// handshake ok // handshake ok
p1, p2 := net.Pipe() p1, p2 := net.Pipe()
wg := xsync.NewWorkGroup(bg) wg := xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
xhandshake(ctx, p1, 1) xhandshakeClient(ctx, p1, 1, t.enc)
}) })
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
xhandshake(ctx, p2, 1) xhandshakeServer(ctx, p2, 1, t.enc)
}) })
xwait(wg) xwait(wg)
xclose(p1) xclose(p1)
...@@ -54,17 +72,17 @@ func TestHandshake(t *testing.T) { ...@@ -54,17 +72,17 @@ func TestHandshake(t *testing.T) {
var err1, err2 error var err1, err2 error
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1) _, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
}) })
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err2 = handshake(ctx, p2, 2) _, _, err2 = handshakeServer(ctx, p2, 2)
}) })
xwait(wg) xwait(wg)
xclose(p1) xclose(p1)
xclose(p2) xclose(p2)
err1Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000002 ; our side = 00000001" err1Want := "pipe - pipe: handshake (client): protocol version mismatch: peer = 00000002 ; our side = 00000001"
err2Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000001 ; our side = 00000002" err2Want := "pipe - pipe: handshake (server): protocol version mismatch: peer = 00000001 ; our side = 00000002"
if !(err1 != nil && err1.Error() == err1Want) { if !(err1 != nil && err1.Error() == err1Want) {
t.Errorf("handshake ver mismatch: p1: unexpected error:\nhave: %v\nwant: %v", err1, err1Want) t.Errorf("handshake ver mismatch: p1: unexpected error:\nhave: %v\nwant: %v", err1, err1Want)
...@@ -78,7 +96,7 @@ func TestHandshake(t *testing.T) { ...@@ -78,7 +96,7 @@ func TestHandshake(t *testing.T) {
err1, err2 = nil, nil err1, err2 = nil, nil
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1) _, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
}) })
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
xclose(p2) xclose(p2)
...@@ -88,16 +106,20 @@ func TestHandshake(t *testing.T) { ...@@ -88,16 +106,20 @@ func TestHandshake(t *testing.T) {
err11, ok := err1.(*_HandshakeError) err11, ok := err1.(*_HandshakeError)
if !ok || !(err11.Err == io.ErrClosedPipe /* on Write */ || err11.Err == io.ErrUnexpectedEOF /* on Read */) { if !ok || !(errors.Is(err11.Err, io.ErrClosedPipe /* on Write */) || errors.Is(err11.Err, io.ErrUnexpectedEOF /* on Read */)) {
t.Errorf("handshake peer close: unexpected error: %#v", err1) t.Errorf("handshake peer close: unexpected error: %#v", err1)
} }
// XXX same for handshakeServer
// ctx cancel // ctx cancel
// XXX same for handshakeServer
p1, p2 = net.Pipe() p1, p2 = net.Pipe()
ctx, cancel := context.WithCancel(bg) ctx, cancel := context.WithCancel(bg)
wg = xsync.NewWorkGroup(ctx) wg = xsync.NewWorkGroup(ctx)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1) _, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
}) })
tdelay() tdelay()
cancel() cancel()
...@@ -110,5 +132,4 @@ func TestHandshake(t *testing.T) { ...@@ -110,5 +132,4 @@ func TestHandshake(t *testing.T) {
if !ok || !(err11.Err == context.Canceled) { if !ok || !(err11.Err == context.Canceled) {
t.Errorf("handshake cancel: unexpected error: %#v", err1) t.Errorf("handshake cancel: unexpected error: %#v", err1)
} }
} }
...@@ -39,15 +39,17 @@ type pktBuf struct { ...@@ -39,15 +39,17 @@ type pktBuf struct {
data []byte // whole packet data including all headers data []byte // whole packet data including all headers
} }
// Header returns pointer to packet header. // HeaderN returns pointer to packet header in 'N'-encoding.
func (pkt *pktBuf) Header() *proto.PktHeader { func (pkt *pktBuf) Header() *proto.PktHeader { return pkt.HeaderN() } // XXX kill
func (pkt *pktBuf) HeaderN() *proto.PktHeader {
// NOTE no need to check len(.data) < PktHeader: // NOTE no need to check len(.data) < PktHeader:
// .data is always allocated with cap >= PktHeaderLen. // .data is always allocated with cap >= PktHeaderLen.
return (*proto.PktHeader)(unsafe.Pointer(&pkt.data[0])) return (*proto.PktHeader)(unsafe.Pointer(&pkt.data[0]))
} }
// Payload returns []byte representing packet payload. // PayloadN returns []byte representing packet payload in 'N'-encoding.
func (pkt *pktBuf) Payload() []byte { func (pkt *pktBuf) Payload() []byte { return pkt.PayloadN() } // XXX kill
func (pkt *pktBuf) PayloadN() []byte {
return pkt.data[proto.PktHeaderLen:] return pkt.data[proto.PktHeaderLen:]
} }
...@@ -87,6 +89,7 @@ func (pkt *pktBuf) String() string { ...@@ -87,6 +89,7 @@ func (pkt *pktBuf) String() string {
h := pkt.Header() h := pkt.Header()
s := fmt.Sprintf(".%d", packed.Ntoh32(h.ConnId)) s := fmt.Sprintf(".%d", packed.Ntoh32(h.ConnId))
// XXX encN-specific
msgCode := packed.Ntoh16(h.MsgCode) msgCode := packed.Ntoh16(h.MsgCode)
msgLen := packed.Ntoh32(h.MsgLen) msgLen := packed.Ntoh32(h.MsgLen)
data := pkt.Payload() data := pkt.Payload()
...@@ -98,7 +101,7 @@ func (pkt *pktBuf) String() string { ...@@ -98,7 +101,7 @@ func (pkt *pktBuf) String() string {
// XXX dup wrt Conn.Recv // XXX dup wrt Conn.Recv
msg := reflect.New(msgType).Interface().(proto.Msg) msg := reflect.New(msgType).Interface().(proto.Msg)
n, err := msg.NEOMsgDecode(data) n, err := encN.NEOMsgDecode(msg, data) // XXX encN hardcoded
if err != nil { if err != nil {
s += fmt.Sprintf(" (%s) %v; #%d [%d]: % x", msgType.Name(), err, msgLen, len(data), data) s += fmt.Sprintf(" (%s) %v; #%d [%d]: % x", msgType.Name(), err, msgLen, len(data), data)
} else { } else {
......
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
package proto
// runtime glue for msgpack support
import (
"fmt"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
)
// mstructDecodeError represents decode error when decoder was expecting
// tuple<nfield> for structure named path.
type mstructDecodeError struct {
path string // "Type.field.field"
op msgpack.Op // op we got
opOk msgpack.Op // op expected
}
func (e *mstructDecodeError) Error() string {
return fmt.Sprintf("decode: M: struct %s: got opcode %02x; expect %02x", e.path, e.op, e.opOk)
}
// mdecodeErr is called to normilize error when msgp.ReadXXX returns err when decoding path.
func mdecodeErr(path string, err error) error {
if err == msgp.ErrShortBytes {
return ErrDecodeOverflow
}
return &mdecodeError{path, err}
}
type mdecodeError struct {
path string // "Type.field.field"
err error
}
func (e *mdecodeError) Error() string {
return fmt.Sprintf("decode: M: %s: %s", e.path, e.err)
}
// mOpError represents decode error when decoder faces unexpected operation.
type mOpError struct {
op, opOk msgpack.Op // op we got and what was expected
}
func (e *mOpError) Error() string {
return fmt.Sprintf("expected opcode %02x; got %02x", e.opOk, e.op)
}
func mdecodeOpErr(path string, op, opOk msgpack.Op) error {
return mdecodeErr(path+"/op", &mOpError{op, opOk})
}
// mLen8Error represents decode error when decoder faces unexpected length in Bin8.
type mLen8Error struct {
l, lOk byte // len we got and expected
}
func (e *mLen8Error) Error() string {
return fmt.Sprintf("expected length %d; got %d", e.lOk, e.l)
}
func mdecodeLen8Err(path string, l, lOk uint8) error {
return mdecodeErr(path+"/len", &mLen8Error{l, lOk})
}
func mdecodeEnumTypeErr(path string, enumType, enumTypeOk byte) error {
return mdecodeErr(path+"/enumType",
fmt.Errorf("expected %d; got %d", enumTypeOk, enumType))
}
func mdecodeEnumValueErr(path string, v byte) error {
return mdecodeErr(path, fmt.Errorf("invalid enum payload %02x", v))
}
...@@ -32,6 +32,12 @@ import ( ...@@ -32,6 +32,12 @@ import (
"time" "time"
) )
// MsgCode returns code the corresponds to type of the message.
// XXX place - ok?
func MsgCode(msg Msg) uint16 {
return msg.neoMsgCode()
}
// MsgType looks up message type by message code. // MsgType looks up message type by message code.
// //
// Nil is returned if message code is not valid. // Nil is returned if message code is not valid.
......
...@@ -84,6 +84,7 @@ const ( ...@@ -84,6 +84,7 @@ const (
Version = 6 Version = 6
// length of packet header // length of packet header
// XXX encN-specific ?
PktHeaderLen = 10 // = unsafe.Sizeof(PktHeader{}), but latter gives typed constant (uintptr) PktHeaderLen = 10 // = unsafe.Sizeof(PktHeader{}), but latter gives typed constant (uintptr)
// packets larger than PktMaxSize are not allowed. // packets larger than PktMaxSize are not allowed.
...@@ -99,6 +100,7 @@ const ( ...@@ -99,6 +100,7 @@ const (
INVALID_OID zodb.Oid = 1<<64 - 1 INVALID_OID zodb.Oid = 1<<64 - 1
) )
// XXX encN-specific ?
// PktHeader represents header of a raw packet. // PktHeader represents header of a raw packet.
// //
// A packet contains connection ID and message. // A packet contains connection ID and message.
...@@ -110,31 +112,75 @@ type PktHeader struct { ...@@ -110,31 +112,75 @@ type PktHeader struct {
MsgLen packed.BE32 // payload message length (excluding packet header) MsgLen packed.BE32 // payload message length (excluding packet header)
} }
// Msg is the interface implemented by all NEO messages. // Msg is the interface representing a NEO message.
type Msg interface { type Msg interface {
// marshal/unmarshal into/from wire format: // marshal/unmarshal into/from wire format:
// NEOMsgCode returns message code needed to be used for particular message type // neoMsgCode returns message code needed to be used for particular message type
// on the wire. // on the wire.
NEOMsgCode() uint16 neoMsgCode() uint16
// NEOMsgEncodedLen returns how much space is needed to encode current message payload. // for encoding E:
NEOMsgEncodedLen() int //
// - neoMsgEncodedLen<E> returns how much space is needed to encode current message payload via E encoding.
// NEOMsgEncode encodes current message state into buf. //
// - neoMsgEncode<E> encodes current message state into buf via E encoding.
// //
// len(buf) must be >= neoMsgEncodedLen(). // len(buf) must be >= neoMsgEncodedLen<E>().
NEOMsgEncode(buf []byte) //
// - neoMsgDecode<E> decodes data via E encoding into message in-place.
// N encoding (original struct-based encoding)
neoMsgEncodedLenN() int
neoMsgEncodeN(buf []byte)
neoMsgDecodeN(data []byte) (nread int, err error)
// NEOMsgDecode decodes data into message in-place. // M encoding (via MessagePack)
NEOMsgDecode(data []byte) (nread int, err error) neoMsgEncodedLenM() int
neoMsgEncodeM(buf []byte)
neoMsgDecodeM(data []byte) (nread int, err error)
}
// Encoding represents messages encoding.
type Encoding byte
// XXX drop "NEO" prefix?
// NEOMsgEncodedLen returns how much space is needed to encode msg payload via encoding e.
func (e Encoding) NEOMsgEncodedLen(msg Msg) int {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgEncodedLenN()
case 'M': return msg.neoMsgEncodedLenM()
}
} }
// NEOMsgEncode encodes msg state into buf via encoding e.
//
// len(buf) must be >= e.NEOMsgEncodedLen(m).
func (e Encoding) NEOMsgEncode(msg Msg, buf []byte) {
switch e {
default: panic("bug")
case 'N': msg.neoMsgEncodeN(buf)
case 'M': msg.neoMsgEncodeM(buf)
}
}
// NEOMsgDecode decodes data via encoding e into msg in-place.
func (e Encoding) NEOMsgDecode(msg Msg, data []byte) (nread int, err error) {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgDecodeN(data)
case 'M': return msg.neoMsgDecodeM(data)
}
}
// ErrDecodeOverflow is the error returned by neoMsgDecode when decoding hits buffer overflow // ErrDecodeOverflow is the error returned by neoMsgDecode when decoding hits buffer overflow
var ErrDecodeOverflow = errors.New("decode: buffer overflow") var ErrDecodeOverflow = errors.New("decode: buffer overflow")
// ---- messages ---- // ---- messages ----
//neo:proto enum
type ErrorCode uint32 type ErrorCode uint32
const ( const (
ACK ErrorCode = iota ACK ErrorCode = iota
...@@ -155,6 +201,7 @@ const ( ...@@ -155,6 +201,7 @@ const (
// XXX move this to neo.clusterState wrapping proto.ClusterState? // XXX move this to neo.clusterState wrapping proto.ClusterState?
//trace:event traceClusterStateChanged(cs *ClusterState) //trace:event traceClusterStateChanged(cs *ClusterState)
//neo:proto enum
type ClusterState int8 type ClusterState int8
const ( const (
// The cluster is initially in the RECOVERING state, and it goes back to // The cluster is initially in the RECOVERING state, and it goes back to
...@@ -188,6 +235,7 @@ const ( ...@@ -188,6 +235,7 @@ const (
STOPPING_BACKUP STOPPING_BACKUP
) )
//neo:proto enum
type NodeType int8 type NodeType int8
const ( const (
MASTER NodeType = iota MASTER NodeType = iota
...@@ -196,6 +244,7 @@ const ( ...@@ -196,6 +244,7 @@ const (
ADMIN ADMIN
) )
//neo:proto enum
type NodeState int8 type NodeState int8
const ( const (
UNKNOWN NodeState = iota //short: U // XXX tag prefix name ? UNKNOWN NodeState = iota //short: U // XXX tag prefix name ?
...@@ -204,6 +253,7 @@ const ( ...@@ -204,6 +253,7 @@ const (
PENDING //short: P PENDING //short: P
) )
//neo:proto enum
type CellState int8 type CellState int8
const ( const (
// Write-only cell. Last transactions are missing because storage is/was down // Write-only cell. Last transactions are missing because storage is/was down
...@@ -255,7 +305,7 @@ type Address struct { ...@@ -255,7 +305,7 @@ type Address struct {
} }
// NOTE if Host == "" -> Port not added to wire (see py.PAddress): // NOTE if Host == "" -> Port not added to wire (see py.PAddress):
func (a *Address) neoEncodedLen() int { func (a *Address) neoEncodedLenN() int {
l := string_neoEncodedLen(a.Host) l := string_neoEncodedLen(a.Host)
if a.Host != "" { if a.Host != "" {
l += 2 l += 2
...@@ -263,7 +313,7 @@ func (a *Address) neoEncodedLen() int { ...@@ -263,7 +313,7 @@ func (a *Address) neoEncodedLen() int {
return l return l
} }
func (a *Address) neoEncode(b []byte) int { func (a *Address) neoEncodeN(b []byte) int {
n := string_neoEncode(a.Host, b[0:]) n := string_neoEncode(a.Host, b[0:])
if a.Host != "" { if a.Host != "" {
binary.BigEndian.PutUint16(b[n:], a.Port) binary.BigEndian.PutUint16(b[n:], a.Port)
...@@ -272,7 +322,7 @@ func (a *Address) neoEncode(b []byte) int { ...@@ -272,7 +322,7 @@ func (a *Address) neoEncode(b []byte) int {
return n return n
} }
func (a *Address) neoDecode(b []byte) (uint64, bool) { func (a *Address) neoDecodeN(b []byte) (uint64, bool) {
n, ok := string_neoDecode(&a.Host, b) n, ok := string_neoDecode(&a.Host, b)
if !ok { if !ok {
return 0, false return 0, false
...@@ -295,17 +345,17 @@ type Checksum [20]byte ...@@ -295,17 +345,17 @@ type Checksum [20]byte
// PTid is Partition Table identifier. // PTid is Partition Table identifier.
// //
// Zero value means "invalid id" (<-> None in py.PPTID) // Zero value means "invalid id" (<-> None in py.PPTID) XXX = nil in msgpack
type PTid uint64 type PTid uint64
// IdTime represents time of identification. // IdTime represents time of identification.
type IdTime float64 type IdTime float64
func (t IdTime) neoEncodedLen() int { func (t IdTime) neoEncodedLenN() int {
return 8 return 8
} }
func (t IdTime) neoEncode(b []byte) int { func (t IdTime) neoEncodeN(b []byte) int {
// use -inf as value for no data (NaN != NaN -> hard to use NaN in tests) // use -inf as value for no data (NaN != NaN -> hard to use NaN in tests)
// NOTE neo/py uses None for "no data"; we use 0 for "no data" to avoid pointer // NOTE neo/py uses None for "no data"; we use 0 for "no data" to avoid pointer
tt := float64(t) tt := float64(t)
...@@ -316,7 +366,7 @@ func (t IdTime) neoEncode(b []byte) int { ...@@ -316,7 +366,7 @@ func (t IdTime) neoEncode(b []byte) int {
return 8 return 8
} }
func (t *IdTime) neoDecode(data []byte) (uint64, bool) { func (t *IdTime) neoDecodeN(data []byte) (uint64, bool) {
if len(data) < 8 { if len(data) < 8 {
return 0, false return 0, false
} }
...@@ -438,8 +488,8 @@ type Recovery struct { ...@@ -438,8 +488,8 @@ type Recovery struct {
type AnswerRecovery struct { type AnswerRecovery struct {
PTid PTid
BackupTid zodb.Tid BackupTid zodb.Tid // XXX nil <-> 0
TruncateTid zodb.Tid TruncateTid zodb.Tid // XXX nil <-> 0
} }
// Ask the last OID/TID so that a master can initialize its TransactionManager. // Ask the last OID/TID so that a master can initialize its TransactionManager.
...@@ -1199,13 +1249,13 @@ type FlushLog struct {} ...@@ -1199,13 +1249,13 @@ type FlushLog struct {}
// ---- runtime support for protogen and custom codecs ---- // ---- runtime support for protogen and custom codecs ----
// customCodec is the interface that is implemented by types with custom encodings. // customCodecN is the interface that is implemented by types with custom N encodings.
// //
// its semantic is very similar to Msg. // its semantic is very similar to Msg.
type customCodec interface { type customCodecN interface {
neoEncodedLen() int neoEncodedLenN() int
neoEncode(buf []byte) (nwrote int) neoEncodeN(buf []byte) (nwrote int)
neoDecode(data []byte) (nread uint64, ok bool) // XXX uint64 or int here? neoDecodeN(data []byte) (nread uint64, ok bool) // XXX uint64 or int here?
} }
func byte2bool(b byte) bool { func byte2bool(b byte) bool {
......
...@@ -79,31 +79,32 @@ func TestPktHeader(t *testing.T) { ...@@ -79,31 +79,32 @@ func TestPktHeader(t *testing.T) {
} }
// test marshalling for one message type // test marshalling for one message type
func testMsgMarshal(t *testing.T, msg Msg, encoded string) { func testMsgMarshal(t *testing.T, enc Encoding, msg Msg, encoded string) {
typ := reflect.TypeOf(msg).Elem() // type of *msg typ := reflect.TypeOf(msg).Elem() // type of *msg
msg2 := reflect.New(typ).Interface().(Msg) msg2 := reflect.New(typ).Interface().(Msg)
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
t.Errorf("%v: panic ↓↓↓:", typ) t.Errorf("%c/%v: panic ↓↓↓:", enc, typ)
panic(e) // to show traceback panic(e) // to show traceback
} }
}() }()
// msg.encode() == expected // msg.encode() == expected
msgCode := msg.NEOMsgCode() msgCode := msg.neoMsgCode()
n := msg.NEOMsgEncodedLen() n := enc.NEOMsgEncodedLen(msg)
msgType := MsgType(msgCode) msgType := MsgType(msgCode)
if msgType != typ { if msgType != typ {
t.Errorf("%v: msgCode = %v which corresponds to %v", typ, msgCode, msgType) t.Errorf("%c/%v: msgCode = %v which corresponds to %v", enc, typ, msgCode, msgType)
} }
if n != len(encoded) { if n != len(encoded) {
t.Errorf("%v: encodedLen = %v ; want %v", typ, n, len(encoded)) t.Errorf("%c/%v: encodedLen = %v ; want %v", enc, typ, n, len(encoded))
} }
buf := make([]byte, n) buf := make([]byte, n)
msg.NEOMsgEncode(buf) enc.NEOMsgEncode(msg, buf)
if string(buf) != encoded { if string(buf) != encoded {
t.Errorf("%v: encode result unexpected:", typ) t.Errorf("%c/%v: encode result unexpected:", enc, typ)
t.Errorf("\thave: %s", hexpkg.EncodeToString(buf)) t.Errorf("\thave: %s", hexpkg.EncodeToString(buf))
t.Errorf("\twant: %s", hexpkg.EncodeToString([]byte(encoded))) t.Errorf("\twant: %s", hexpkg.EncodeToString([]byte(encoded)))
} }
...@@ -112,7 +113,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -112,7 +113,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
for l := len(buf) - 1; l >= 0; l-- { for l := len(buf) - 1; l >= 0; l-- {
func() { func() {
defer func() { defer func() {
subj := fmt.Sprintf("%v: encode(buf[:encodedLen-%v])", typ, len(encoded)-l) subj := fmt.Sprintf("%c/%v: encode(buf[:encodedLen-%v])", enc, typ, len(encoded)-l)
e := recover() e := recover()
if e == nil { if e == nil {
t.Errorf("%s did not panic", subj) t.Errorf("%s did not panic", subj)
...@@ -131,29 +132,29 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -131,29 +132,29 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
} }
}() }()
msg.NEOMsgEncode(buf[:l]) enc.NEOMsgEncode(msg, buf[:l])
}() }()
} }
// msg.decode() == expected // msg.decode() == expected
data := []byte(encoded + "noise") data := []byte(encoded + "noise")
n, err := msg2.NEOMsgDecode(data) n, err := enc.NEOMsgDecode(msg2, data)
if err != nil { if err != nil {
t.Errorf("%v: decode error %v", typ, err) t.Errorf("%c/%v: decode error %v", enc, typ, err)
} }
if n != len(encoded) { if n != len(encoded) {
t.Errorf("%v: nread = %v ; want %v", typ, n, len(encoded)) t.Errorf("%c/%v: nread = %v ; want %v", enc, typ, n, len(encoded))
} }
if !reflect.DeepEqual(msg2, msg) { if !reflect.DeepEqual(msg2, msg) {
t.Errorf("%v: decode result unexpected: %v ; want %v", typ, msg2, msg) t.Errorf("%c/%v: decode result unexpected: %v ; want %v", enc, typ, msg2, msg)
} }
// decode must detect buffer overflow // decode must detect buffer overflow
for l := len(encoded) - 1; l >= 0; l-- { for l := len(encoded) - 1; l >= 0; l-- {
n, err = msg2.NEOMsgDecode(data[:l]) n, err = enc.NEOMsgDecode(msg2, data[:l])
if !(n == 0 && err == ErrDecodeOverflow) { if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%v: decode overflow not detected on [:%v]", typ, l) t.Errorf("%c/%v: decode overflow not detected on [:%v]", enc, typ, l)
} }
} }
...@@ -162,14 +163,21 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -162,14 +163,21 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
// test encoding/decoding of messages // test encoding/decoding of messages
func TestMsgMarshal(t *testing.T) { func TestMsgMarshal(t *testing.T) {
var testv = []struct { var testv = []struct {
msg Msg msg Msg
encoded string // []byte encodedN string // []byte
encodedM string // []byte
}{ }{
// empty // empty
{&Ping{}, ""}, {&Ping{},
"",
"\x90",
},
// uint32, string // uint32(N)/enum(M), string
{&Error{Code: 0x01020304, Message: "hello"}, "\x01\x02\x03\x04\x00\x00\x00\x05hello"}, {&Error{Code: 0x00000045, Message: "hello"},
"\x00\x00\x00\x45\x00\x00\x00\x05hello",
hex("92") + hex("d40045") + "\xc4\x05hello",
},
// Oid, Tid, bool, Checksum, []byte // Oid, Tid, bool, Checksum, []byte
{&StoreObject{ {&StoreObject{
...@@ -185,7 +193,18 @@ func TestMsgMarshal(t *testing.T) { ...@@ -185,7 +193,18 @@ func TestMsgMarshal(t *testing.T) {
hex("01020304050607080a0b0c0d0e0f010200") + hex("01020304050607080a0b0c0d0e0f010200") +
hex("0102030405060708090a0b0c0d0e0f1011121314") + hex("0102030405060708090a0b0c0d0e0f1011121314") +
hex("0000000b") + "hello world" + hex("0000000b") + "hello world" +
hex("0a0b0c0d0e0f01030a0b0c0d0e0f0104")}, hex("0a0b0c0d0e0f01030a0b0c0d0e0f0104"),
// M
hex("97") +
hex("c408") + hex("0102030405060708") +
hex("c408") + hex("0a0b0c0d0e0f0102") +
hex("c2") +
hex("c414") + hex("0102030405060708090a0b0c0d0e0f1011121314") +
hex("c40b") + "hello world" +
hex("c408") + hex("0a0b0c0d0e0f0103") +
hex("c408") + hex("0a0b0c0d0e0f0104"),
},
// PTid, [] (of [] of {UUID, CellState}) // PTid, [] (of [] of {UUID, CellState})
{&AnswerPartitionTable{ {&AnswerPartitionTable{
...@@ -198,12 +217,22 @@ func TestMsgMarshal(t *testing.T) { ...@@ -198,12 +217,22 @@ func TestMsgMarshal(t *testing.T) {
}, },
}, },
// N
hex("0102030405060708") + hex("0102030405060708") +
hex("00000022") + hex("00000022") +
hex("00000003") + hex("00000003") +
hex("000000020000000b010000001100") + hex("000000020000000b010000001100") +
hex("000000010000000b02") + hex("000000010000000b02") +
hex("000000030000000b030000000f040000001701"), hex("000000030000000b030000000f040000001701"),
// M
hex("93") +
hex("cf0102030405060708") +
hex("22") +
hex("93") +
hex("91"+"92"+"920bd40401"+"9211d40400") +
hex("91"+"91"+"920bd40402") +
hex("91"+"93"+"920bd40403"+"920fd40404"+"9217d40401"),
}, },
// map[Oid]struct {Tid,Tid,bool} // map[Oid]struct {Tid,Tid,bool}
...@@ -219,11 +248,20 @@ func TestMsgMarshal(t *testing.T) { ...@@ -219,11 +248,20 @@ func TestMsgMarshal(t *testing.T) {
5: {4, 3, true}, 5: {4, 3, true},
}}, }},
// N
u32(4) + u32(4) +
u64(1) + u64(1) + u64(0) + hex("00") + u64(1) + u64(1) + u64(0) + hex("00") +
u64(2) + u64(7) + u64(1) + hex("01") + u64(2) + u64(7) + u64(1) + hex("01") +
u64(5) + u64(4) + u64(3) + hex("01") + u64(5) + u64(4) + u64(3) + hex("01") +
u64(8) + u64(7) + u64(1) + hex("00"), u64(8) + u64(7) + u64(1) + hex("00"),
// M
hex("91") +
hex("84") +
hex("c408")+u64(1) + hex("93") + hex("c408")+u64(1) + hex("c408")+u64(0) + hex("c2") +
hex("c408")+u64(2) + hex("93") + hex("c408")+u64(7) + hex("c408")+u64(1) + hex("c3") +
hex("c408")+u64(5) + hex("93") + hex("c408")+u64(4) + hex("c408")+u64(3) + hex("c3") +
hex("c408")+u64(8) + hex("93") + hex("c408")+u64(7) + hex("c408")+u64(1) + hex("c2"),
}, },
// map[uint32]UUID + trailing ... // map[uint32]UUID + trailing ...
...@@ -238,41 +276,86 @@ func TestMsgMarshal(t *testing.T) { ...@@ -238,41 +276,86 @@ func TestMsgMarshal(t *testing.T) {
MaxTID: 128, MaxTID: 128,
}, },
// N
u32(4) + u32(4) +
u32(1) + u32(7) + u32(1) + u32(7) +
u32(2) + u32(9) + u32(2) + u32(9) +
u32(4) + u32(17) + u32(4) + u32(17) +
u32(7) + u32(3) + u32(7) + u32(3) +
u64(23) + u64(128), u64(23) + u64(128),
// M
hex("93") +
hex("84") +
hex("01" + "07") +
hex("02" + "09") +
hex("04" + "11") +
hex("07" + "03") +
hex("c408") + u64(23) +
hex("c408") + u64(128),
}, },
// uint32, []uint32 // uint32, []uint32
{&PartitionCorrupted{7, []NodeUUID{1, 3, 9, 4}}, {&PartitionCorrupted{7, []NodeUUID{1, 3, 9, 4}},
// N
u32(7) + u32(4) + u32(1) + u32(3) + u32(9) + u32(4), u32(7) + u32(4) + u32(1) + u32(3) + u32(9) + u32(4),
// M
hex("92") +
hex("07") +
hex("94") +
hex("01030904"),
}, },
// uint32, Address, string, IdTime // uint32, Address, string, IdTime
{&RequestIdentification{CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678, []string{"room1", "rack234"}, []uint32{3,4,5} }, {&RequestIdentification{CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678, []string{"room1", "rack234"}, []uint32{3,4,5} },
// N
u8(2) + u32(17) + u32(9) + u8(2) + u32(17) + u32(9) +
"localhost" + u16(7777) + "localhost" + u16(7777) +
u32(6) + "myname" + u32(6) + "myname" +
hex("3fbf9add1091c895") + hex("3fbf9add1091c895") +
u32(2) + u32(5)+"room1" + u32(7)+"rack234" + u32(2) + u32(5)+"room1" + u32(7)+"rack234" +
u32(3) + u32(3)+u32(4)+u32(5), u32(3) + u32(3)+u32(4)+u32(5),
// M
hex("97") +
hex("d40202") +
hex("11") +
hex("92") + hex("c409")+"localhost" + hex("cd")+u16(7777) +
hex("c406")+"myname" +
hex("cb" + "3fbf9add1091c895") +
hex("92") + hex("c405")+"room1" + hex("c407")+"rack234" +
hex("93") + hex("030405"),
}, },
// IdTime, empty Address, int32 // IdTime, empty Address, int32
{&NotifyNodeInformation{1504466245.926185, []NodeInfo{ {&NotifyNodeInformation{1504466245.926185, []NodeInfo{
{CLIENT, Address{}, UUID(CLIENT, 1), RUNNING, 1504466245.925599}}}, {CLIENT, Address{}, UUID(CLIENT, 1), RUNNING, 1504466245.925599}}},
// N
hex("41d66b15517b469d") + u32(1) + hex("41d66b15517b469d") + u32(1) +
u8(2) + u32(0) /* <- ø Address */ + hex("e0000001") + u8(2) + u8(2) + u32(0) /* <- ø Address */ + hex("e0000001") + u8(2) +
hex("41d66b15517b3d04"), hex("41d66b15517b3d04"),
// M
hex("92") +
hex("cb" + "41d66b15517b469d") +
hex("91") +
hex("95") +
hex("d40202") +
hex("92" + "c400"+"" + "00") +
hex("d2" + "e0000001") +
hex("d40302") +
hex("cb" + "41d66b15517b3d04"),
}, },
// empty IdTime // empty IdTime
{&NotifyNodeInformation{IdTimeNone, []NodeInfo{}}, hex("ffffffffffffffff") + hex("00000000")}, {&NotifyNodeInformation{IdTimeNone, []NodeInfo{}},
// N
hex("ffffffffffffffff") + hex("00000000"),
// M
hex("92") +
hex("cb" + "fff0000000000000") + // XXX nan/-inf not handled yet
hex("90"),
},
// TODO we need tests for: // TODO we need tests for:
// []varsize + trailing // []varsize + trailing
...@@ -280,7 +363,8 @@ func TestMsgMarshal(t *testing.T) { ...@@ -280,7 +363,8 @@ func TestMsgMarshal(t *testing.T) {
} }
for _, tt := range testv { for _, tt := range testv {
testMsgMarshal(t, tt.msg, tt.encoded) testMsgMarshal(t, 'N', tt.msg, tt.encodedN)
testMsgMarshal(t, 'M', tt.msg, tt.encodedM)
} }
} }
...@@ -288,18 +372,23 @@ func TestMsgMarshal(t *testing.T) { ...@@ -288,18 +372,23 @@ func TestMsgMarshal(t *testing.T) {
// this way we additionally lightly check encode / decode overflow behaviour for all types. // this way we additionally lightly check encode / decode overflow behaviour for all types.
func TestMsgMarshalAllOverflowLightly(t *testing.T) { func TestMsgMarshalAllOverflowLightly(t *testing.T) {
for _, typ := range msgTypeRegistry { for _, typ := range msgTypeRegistry {
// zero-value for a type for _, enc := range []Encoding{'N', 'M'} {
msg := reflect.New(typ).Interface().(Msg) // zero-value for a type
l := msg.NEOMsgEncodedLen() msg := reflect.New(typ).Interface().(Msg)
zerol := make([]byte, l) l := enc.NEOMsgEncodedLen(msg)
// decoding will turn nil slice & map into empty allocated ones. zerol := make([]byte, l)
// we need it so that reflect.DeepEqual works for msg encode/decode comparison if enc != 'N' { // M-encoding of zero-value is not all zeros
n, err := msg.NEOMsgDecode(zerol) enc.NEOMsgEncode(msg, zerol)
if !(n == l && err == nil) { }
t.Errorf("%v: zero-decode unexpected: %v, %v ; want %v, nil", typ, n, err, l) // decoding will turn nil slice & map into empty allocated ones.
} // we need it so that reflect.DeepEqual works for msg encode/decode comparison
n, err := enc.NEOMsgDecode(msg, zerol)
if !(n == l && err == nil) {
t.Errorf("%c/%v: zero-decode unexpected: %v, %v ; want %v, nil", enc, typ, n, err, l)
}
testMsgMarshal(t, msg, string(zerol)) testMsgMarshal(t, enc, msg, string(zerol))
}
} }
} }
...@@ -316,6 +405,8 @@ func TestMsgDecodeLenOverflow(t *testing.T) { ...@@ -316,6 +405,8 @@ func TestMsgDecodeLenOverflow(t *testing.T) {
{&AnswerLockedTransactions{}, u32(0x10000000)}, {&AnswerLockedTransactions{}, u32(0x10000000)},
} }
enc := Encoding('N') // XXX hardcoded XXX + M-variants with big len?
for _, tt := range testv { for _, tt := range testv {
data := []byte(tt.data) data := []byte(tt.data)
func() { func() {
...@@ -325,7 +416,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) { ...@@ -325,7 +416,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) {
} }
}() }()
n, err := tt.msg.NEOMsgDecode(data) n, err := enc.NEOMsgDecode(tt.msg, data)
if !(n == 0 && err == ErrDecodeOverflow) { if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data, t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data,
n, err, 0, ErrDecodeOverflow) n, err, 0, ErrDecodeOverflow)
......
...@@ -25,10 +25,11 @@ NEO. Protocol module. Code generator ...@@ -25,10 +25,11 @@ NEO. Protocol module. Code generator
This program generates marshalling code for message types defined in proto.go . This program generates marshalling code for message types defined in proto.go .
For every type 4 methods are generated in accordance with neo.Msg interface: For every type 4 methods are generated in accordance with neo.Msg interface:
NEOMsgCode() uint16 // XXX update for 'N' and 'M'
NEOMsgEncodedLen() int neoMsgCode() uint16
NEOMsgEncode(buf []byte) neoMsgEncodedLenN() int
NEOMsgDecode(data []byte) (nread int, err error) neoMsgEncodeN(buf []byte)
neoMsgDecodeN(data []byte) (nread int, err error)
List of message types is obtained via searching through proto.go AST - looking List of message types is obtained via searching through proto.go AST - looking
for appropriate struct declarations there. for appropriate struct declarations there.
...@@ -40,7 +41,7 @@ maps, ...). ...@@ -40,7 +41,7 @@ maps, ...).
Top-level generation driver is in generateCodecCode(). It accepts type Top-level generation driver is in generateCodecCode(). It accepts type
specification and something that performs actual leaf-nodes code generation specification and something that performs actual leaf-nodes code generation
(CodeGenerator interface). There are 3 particular codegenerators implemented - (CodeGenerator interface). There are 3 particular codegenerators implemented -
- sizer, encoder & decoder - to generate each of the needed method functions. - sizerN, encoderN & decoder - to generate each of the needed method functions. XXX N/M
The structure of whole process is very similar to what would be happening at The structure of whole process is very similar to what would be happening at
runtime if marshalling was reflect based, but statically with go/types we don't runtime if marshalling was reflect based, but statically with go/types we don't
...@@ -77,6 +78,8 @@ import ( ...@@ -77,6 +78,8 @@ import (
"os" "os"
"sort" "sort"
"strings" "strings"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
) )
// parsed & typechecked input // parsed & typechecked input
...@@ -116,8 +119,16 @@ func typeName(typ types.Type) string { ...@@ -116,8 +119,16 @@ func typeName(typ types.Type) string {
return types.TypeString(typ, qf) return types.TypeString(typ, qf)
} }
var neo_customCodec *types.Interface // type of neo.customCodec // zodb.Tid and zodb.Oid types
var memBuf types.Type // type of mem.Buf var zodbTid types.Type
var zodbOid types.Type
var neo_customCodecN *types.Interface // type of neo.customCodecN
var memBuf types.Type // type of mem.Buf
// registry of enums
var enumRegistry = map[types.Type]int{} // type -> enum type serial
// bytes.Buffer + bell & whistles // bytes.Buffer + bell & whistles
type Buffer struct { type Buffer struct {
...@@ -181,6 +192,7 @@ func loadPkg(pkgPath string, sources ...string) *types.Package { ...@@ -181,6 +192,7 @@ func loadPkg(pkgPath string, sources ...string) *types.Package {
type Annotation struct { type Annotation struct {
typeonly bool typeonly bool
answer bool answer bool
enum bool
} }
// parse checks doc for specific comment annotations and, if present, loads them. // parse checks doc for specific comment annotations and, if present, loads them.
...@@ -211,6 +223,12 @@ func (a *Annotation) parse(doc *ast.CommentGroup) { ...@@ -211,6 +223,12 @@ func (a *Annotation) parse(doc *ast.CommentGroup) {
} }
a.answer = true a.answer = true
case "enum":
if a.enum {
log.Fatalf("%v: duplicate `enum`", cpos)
}
a.enum = true
default: default:
log.Fatalf("%v: unknown neo:proto directive %q", cpos, arg) log.Fatalf("%v: unknown neo:proto directive %q", cpos, arg)
} }
...@@ -243,6 +261,14 @@ func (v BySerial) Len() int { return len(v) } ...@@ -243,6 +261,14 @@ func (v BySerial) Len() int { return len(v) }
// ---------------------------------------- // ----------------------------------------
func xlookup(pkg *types.Package, name string) types.Object {
obj := pkg.Scope().Lookup(name)
if obj == nil {
log.Fatalf("cannot find `%s.%s`", pkg.Name(), name)
}
return obj
}
func main() { func main() {
var err error var err error
...@@ -252,15 +278,12 @@ func main() { ...@@ -252,15 +278,12 @@ func main() {
zodbPkg = loadPkg("lab.nexedi.com/kirr/neo/go/zodb", "../../zodb/zodb.go") zodbPkg = loadPkg("lab.nexedi.com/kirr/neo/go/zodb", "../../zodb/zodb.go")
protoPkg = loadPkg("lab.nexedi.com/kirr/neo/go/neo/proto", "proto.go") protoPkg = loadPkg("lab.nexedi.com/kirr/neo/go/neo/proto", "proto.go")
// extract neo.customCodec // extract neo.customCodecN
cc := protoPkg.Scope().Lookup("customCodec") cc := xlookup(protoPkg, "customCodecN")
if cc == nil {
log.Fatal("cannot find `customCodec`")
}
var ok bool var ok bool
neo_customCodec, ok = cc.Type().Underlying().(*types.Interface) neo_customCodecN, ok = cc.Type().Underlying().(*types.Interface)
if !ok { if !ok {
log.Fatal("customCodec is not interface (got %v)", cc.Type()) log.Fatal("customCodecN is not interface (got %v)", cc.Type())
} }
// extract mem.Buf // extract mem.Buf
...@@ -282,6 +305,10 @@ func main() { ...@@ -282,6 +305,10 @@ func main() {
} }
memBuf = __.Type() memBuf = __.Type()
// extract zodb.Tid and zodb.Oid
zodbTid = xlookup(zodbPkg, "Tid").Type()
zodbOid = xlookup(zodbPkg, "Oid").Type()
// prologue // prologue
f := fileMap["proto.go"] f := fileMap["proto.go"]
buf := Buffer{} buf := Buffer{}
...@@ -295,6 +322,9 @@ import ( ...@@ -295,6 +322,9 @@ import (
"reflect" "reflect"
"sort" "sort"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
"lab.nexedi.com/kirr/go123/mem" "lab.nexedi.com/kirr/go123/mem"
"lab.nexedi.com/kirr/neo/go/zodb" "lab.nexedi.com/kirr/neo/go/zodb"
)`) )`)
...@@ -304,6 +334,7 @@ import ( ...@@ -304,6 +334,7 @@ import (
// go over message types declaration and generate marshal code for them // go over message types declaration and generate marshal code for them
buf.emit("// messages marshalling\n") buf.emit("// messages marshalling\n")
msgSerial := 0 msgSerial := 0
enumSerial := 0
for _, decl := range f.Decls { for _, decl := range f.Decls {
// we look for types (which can be only under GenDecl) // we look for types (which can be only under GenDecl)
gendecl, ok := decl.(*ast.GenDecl) gendecl, ok := decl.(*ast.GenDecl)
...@@ -324,16 +355,25 @@ import ( ...@@ -324,16 +355,25 @@ import (
typespec := spec.(*ast.TypeSpec) // must be because tok = TYPE typespec := spec.(*ast.TypeSpec) // must be because tok = TYPE
typename := typespec.Name.Name typename := typespec.Name.Name
// we are only interested in struct types
if _, ok := typespec.Type.(*ast.StructType); !ok {
continue
}
// `//neo:proto ...` annotation for this particular type // `//neo:proto ...` annotation for this particular type
specAnnotation := declAnnotation // inheriting from decl specAnnotation := declAnnotation // inheriting from decl
specAnnotation.parse(typespec.Doc) specAnnotation.parse(typespec.Doc)
// type only -> don't generate message interface for it // remember enum types
// FIXME separate dedicated first pass to extract enums first
if specAnnotation.enum {
//typ := typeInfo.Types[typespec.Type].Type
typ := typeInfo.Defs[typespec.Name].Type()
// XXX verify typ is basic int XXX or byte?
enumRegistry[typ]= enumSerial
//fmt.Printf("// enum %s #%d\n", typeName(typ), enumSerial)
enumSerial++
}
// messages are only struct types without typeonly annotation
if _, ok := typespec.Type.(*ast.StructType); !ok {
continue
}
if specAnnotation.typeonly { if specAnnotation.typeonly {
continue continue
} }
...@@ -350,13 +390,18 @@ import ( ...@@ -350,13 +390,18 @@ import (
fmt.Fprintf(&buf, "// %s. %s\n\n", msgCode, typename) fmt.Fprintf(&buf, "// %s. %s\n\n", msgCode, typename)
buf.emit("func (*%s) NEOMsgCode() uint16 {", typename) buf.emit("func (*%s) neoMsgCode() uint16 {", typename)
buf.emit("return %s", msgCode) buf.emit("return %s", msgCode)
buf.emit("}\n") buf.emit("}\n")
buf.WriteString(generateCodecCode(typespec, &sizer{})) buf.WriteString(generateCodecCode(typespec, &sizerN{}))
buf.WriteString(generateCodecCode(typespec, &encoder{})) buf.WriteString(generateCodecCode(typespec, &encoderN{}))
buf.WriteString(generateCodecCode(typespec, &decoder{})) buf.WriteString(generateCodecCode(typespec, &decoderN{}))
// XXX keep all M routines separate from N for code locality
buf.WriteString(generateCodecCode(typespec, &sizerM{}))
buf.WriteString(generateCodecCode(typespec, &encoderM{}))
buf.WriteString(generateCodecCode(typespec, &decoderM{}))
msgTypeRegistry[msgCode] = typename msgTypeRegistry[msgCode] = typename
msgSerial++ msgSerial++
...@@ -382,9 +427,12 @@ import ( ...@@ -382,9 +427,12 @@ import (
// format & output generated code // format & output generated code
code, err := format.Source(buf.Bytes()) code, err := format.Source(buf.Bytes())
//code = buf.Bytes()
if true {
if err != nil { if err != nil {
panic(err) // should not happen panic(err) // should not happen
} }
}
_, err = os.Stdout.Write(code) _, err = os.Stdout.Write(code)
if err != nil { if err != nil {
...@@ -394,13 +442,13 @@ import ( ...@@ -394,13 +442,13 @@ import (
// info about encode/decode of a basic fixed-size type // info about encode/decode of a basic fixed-size type
type basicCodec struct { type basicCodecN struct {
wireSize int wireSize int
encode string encode string
decode string decode string
} }
var basicTypes = map[types.BasicKind]basicCodec{ var basicTypesN = map[types.BasicKind]basicCodecN{
// encode: %v %v will be `data[n:]`, value // encode: %v %v will be `data[n:]`, value
// decode: %v will be `data[n:]` (and already made sure data has more enough bytes to read) // decode: %v will be `data[n:]` (and already made sure data has more enough bytes to read)
types.Bool: {1, "(%v)[0] = bool2byte(%v)", "byte2bool((%v)[0])"}, types.Bool: {1, "(%v)[0] = bool2byte(%v)", "byte2bool((%v)[0])"},
...@@ -417,18 +465,47 @@ var basicTypes = map[types.BasicKind]basicCodec{ ...@@ -417,18 +465,47 @@ var basicTypes = map[types.BasicKind]basicCodec{
types.Float64: {8, "float64_neoEncode(%v, %v)", "float64_neoDecode(%v)"}, types.Float64: {8, "float64_neoEncode(%v, %v)", "float64_neoDecode(%v)"},
} }
// does a type have fixed wire size and, if yes, what it is? // does a type have fixed wire size when encoded and, if yes, what it is?
func typeSizeFixed(typ types.Type) (wireSize int, ok bool) { func typeEncodingSizeFixed(encoding byte, typ types.Type) (wireSize int, ok bool) {
return typeRxTxSizeFixed(encoding, typ, false)
}
// does a type always have fixed wire size when we are decoding it from
// a packet received from outside? if yes, what size it is?
func typeDecodingSizeFixed(encoding byte, typ types.Type) (wireSize int, ok bool) {
return typeRxTxSizeFixed(encoding, typ, true)
}
func typeRxTxSizeFixed(encoding byte, typ types.Type, rx bool) (wireSize int, ok bool) {
switch encoding {
default:
panic("bad encoding")
case 'M':
// pass typ through sizerM and see if encoded size is fixed or not
// XXX make something not fixed when rx=true?
s := &sizerM{}
codegenType("x", typ, nil, s)
if !s.size.IsNumeric() { // no symbolic part
return 0, false
}
return s.size.num, true
case 'N':
// implemented below
// XXX also pass through sizerN ?
}
switch u := typ.Underlying().(type) { switch u := typ.Underlying().(type) {
case *types.Basic: case *types.Basic:
basic, ok := basicTypes[u.Kind()] basic, ok := basicTypesN[u.Kind()]
if ok { if ok {
return basic.wireSize, ok return basic.wireSize, ok
} }
case *types.Struct: case *types.Struct:
for i := 0; i < u.NumFields(); i++ { for i := 0; i < u.NumFields(); i++ {
size, ok := typeSizeFixed(u.Field(i).Type()) size, ok := typeEncodingSizeFixed(encoding, u.Field(i).Type())
if !ok { if !ok {
goto notfixed goto notfixed
} }
...@@ -438,7 +515,7 @@ func typeSizeFixed(typ types.Type) (wireSize int, ok bool) { ...@@ -438,7 +515,7 @@ func typeSizeFixed(typ types.Type) (wireSize int, ok bool) {
return wireSize, true return wireSize, true
case *types.Array: case *types.Array:
elemSize, ok := typeSizeFixed(u.Elem()) elemSize, ok := typeEncodingSizeFixed(encoding, u.Elem())
if ok { if ok {
return int(u.Len()) * elemSize, ok return int(u.Len()) * elemSize, ok
} }
...@@ -449,17 +526,13 @@ notfixed: ...@@ -449,17 +526,13 @@ notfixed:
return 0, false return 0, false
} }
// does a type have fixed wire size == 1 ? // interface of a codegenerator (for sizer/encoder/decoder)
func typeSizeFixed1(typ types.Type) bool {
wireSize, _ := typeSizeFixed(typ)
return wireSize == 1
}
// interface of a codegenerator (for sizer/coder/decoder)
type CodeGenerator interface { type CodeGenerator interface {
// codegenerator generates code for this encoding
encoding() byte
// tell codegen it should generate code for which type & receiver name // tell codegen it should generate code for which type & receiver name
setFunc(recvName, typeName string, typ types.Type) setFunc(recvName, typeName string, typ types.Type, encoding byte)
// generate code to process a basic fixed type (not string) // generate code to process a basic fixed type (not string)
// userType is type actually used in source (for which typ is underlying), or nil // userType is type actually used in source (for which typ is underlying), or nil
...@@ -479,17 +552,36 @@ type CodeGenerator interface { ...@@ -479,17 +552,36 @@ type CodeGenerator interface {
genArray1(path string, typ *types.Array) genArray1(path string, typ *types.Array)
genSlice1(path string, typ types.Type) genSlice1(path string, typ types.Type)
// generate code to process header of struct
genStructHead(path string, typ *types.Struct, userType types.Type)
// mem.Buf // mem.Buf
genBuf(path string) genBuf(path string)
/*
// generate code for a custom type which implements its own // generate code for a custom type which implements its own
// encoding/decoding via implementing neo.customCodec interface. // encoding/decoding via implementing neo.customCodecN interface.
genCustom(path string) // XXX move out of common interface?
genCustomN(path string)
*/
// get generated code. // get generated code.
generatedCode() string generatedCode() string
} }
// interface for codegenerators to inject themselves into {sizer/encoder/decoder}Common.
type CodeGenCustomize interface {
CodeGenerator
// generate code to process slice or map header
genSliceHead(path string, typ *types.Slice, obj types.Object)
genMapHead(path string, typ *types.Map, obj types.Object)
}
// X reports encoding=X
type N struct{}; func (_ *N) encoding() byte { return 'N' }
type M struct{}; func (_ *M) encoding() byte { return 'M' }
// common part of codegenerators // common part of codegenerators
type commonCodeGen struct { type commonCodeGen struct {
buf Buffer // code is emitted here buf Buffer // code is emitted here
...@@ -497,6 +589,7 @@ type commonCodeGen struct { ...@@ -497,6 +589,7 @@ type commonCodeGen struct {
recvName string // receiver/type for top-level func recvName string // receiver/type for top-level func
typeName string // or empty typeName string // or empty
typ types.Type typ types.Type
enc byte // encoding variant
varUsed map[string]bool // whether a variable was used varUsed map[string]bool // whether a variable was used
} }
...@@ -505,10 +598,11 @@ func (c *commonCodeGen) emit(format string, a ...interface{}) { ...@@ -505,10 +598,11 @@ func (c *commonCodeGen) emit(format string, a ...interface{}) {
c.buf.emit(format, a...) c.buf.emit(format, a...)
} }
func (c *commonCodeGen) setFunc(recvName, typeName string, typ types.Type) { func (c *commonCodeGen) setFunc(recvName, typeName string, typ types.Type, encoding byte) {
c.recvName = recvName c.recvName = recvName
c.typeName = typeName c.typeName = typeName
c.typ = typ c.typ = typ
c.enc = encoding
} }
// get variable for varname (and automatically mark this var as used) // get variable for varname (and automatically mark this var as used)
...@@ -520,6 +614,12 @@ func (c *commonCodeGen) var_(varname string) string { ...@@ -520,6 +614,12 @@ func (c *commonCodeGen) var_(varname string) string {
return varname return varname
} }
// pathName returns name representing path or assignto.
func (c *commonCodeGen) pathName(path string) string {
// Type, p.f1.f2 -> Type.f1.f2
return strings.Join(append([]string{c.typeName}, strings.Split(path, ".")[1:]...), ".")
}
// symbolic size // symbolic size
// consists of numeric & symbolic expression parts // consists of numeric & symbolic expression parts
// size is num + expr1 + expr2 + ... // size is num + expr1 + expr2 + ...
...@@ -615,22 +715,24 @@ func (o *OverflowCheck) AddExpr(format string, a ...interface{}) { ...@@ -615,22 +715,24 @@ func (o *OverflowCheck) AddExpr(format string, a ...interface{}) {
} }
// sizer generates code to compute encoded size of a message // sizerX generates code to compute X-encoded size of a message.
// //
// when type is recursively walked, for every case symbolic size is added appropriately. // when type is recursively walked, for every case symbolic size is added appropriately.
// in case when it was needed to generate loops, runtime accumulator variable is additionally used. // in case when it was needed to generate loops, runtime accumulator variable is additionally used.
// result is: symbolic size + (optionally) runtime accumulator. // result is: symbolic size + (optionally) runtime accumulator.
type sizer struct { type sizerCommon struct {
commonCodeGen commonCodeGen
size SymSize // currently accumulated size size SymSize // currently accumulated size
} }
type sizerN struct { sizerCommon; N }
type sizerM struct { sizerCommon; M }
// encoder generates code to encode a message // encoderX generates code to X-encode a message.
// //
// when type is recursively walked, for every case code to update `data[n:]` is generated. // when type is recursively walked, for every case code to update `data[n:]` is generated.
// no overflow checks are generated as by neo.Msg interface provided data // no overflow checks are generated as by neo.Msg interface provided data
// buffer should have at least payloadLen length returned by NEOMsgEncodedLen() // buffer should have at least payloadLen length returned by neoMsgEncodedLenX()
// (the size computed by sizer). // (the size computed by sizerX).
// //
// the code emitted looks like: // the code emitted looks like:
// //
...@@ -638,14 +740,16 @@ type sizer struct { ...@@ -638,14 +740,16 @@ type sizer struct {
// encode<typ2>(data[n2:], path2) // encode<typ2>(data[n2:], path2)
// ... // ...
// //
// TODO encode have to care in NEOMsgEncode to emit preamble such that bound // TODO encode have to care in neoMsgEncodeX to emit preamble such that bound
// checking is performed only once (currently compiler emits many of them) // checking is performed only once (currently compiler emits many of them)
type encoder struct { type encoderCommon struct {
commonCodeGen commonCodeGen
n int // current write position in data n int // current write position in data
} }
type encoderN struct { encoderCommon; N }
type encoderM struct { encoderCommon; M }
// decoder generates code to decode a message // decoderX generates code to X-decode a message.
// //
// when type is recursively walked, for every case code to decode next item from // when type is recursively walked, for every case code to decode next item from
// `data[n:]` is generated. // `data[n:]` is generated.
...@@ -662,7 +766,7 @@ type encoder struct { ...@@ -662,7 +766,7 @@ type encoder struct {
// <assignto1> = decode<typ1>(data[n1:]) // <assignto1> = decode<typ1>(data[n1:])
// <assignto2> = decode<typ2>(data[n2:]) // <assignto2> = decode<typ2>(data[n2:])
// ... // ...
type decoder struct { type decoderCommon struct {
commonCodeGen commonCodeGen
// done buffer for generated code // done buffer for generated code
...@@ -677,16 +781,23 @@ type decoder struct { ...@@ -677,16 +781,23 @@ type decoder struct {
// current overflow check point // current overflow check point
overflow OverflowCheck overflow OverflowCheck
} }
type decoderN struct { decoderCommon; N }
type decoderM struct { decoderCommon; M }
var _ CodeGenerator = (*sizerN)(nil)
var _ CodeGenerator = (*encoderN)(nil)
var _ CodeGenerator = (*decoderN)(nil)
var _ CodeGenerator = (*sizer)(nil) var _ CodeGenerator = (*sizerM)(nil)
var _ CodeGenerator = (*encoder)(nil) var _ CodeGenerator = (*encoderM)(nil)
var _ CodeGenerator = (*decoder)(nil) var _ CodeGenerator = (*decoderM)(nil)
func (s *sizer) generatedCode() string { func (s *sizerCommon) generatedCode() string {
code := Buffer{} code := Buffer{}
// prologue // prologue
code.emit("func (%s *%s) NEOMsgEncodedLen() int {", s.recvName, s.typeName) code.emit("func (%s *%s) neoMsgEncodedLen%c() int {", s.recvName, s.typeName, s.enc)
if s.varUsed["size"] { if s.varUsed["size"] {
code.emit("var %s int", s.var_("size")) code.emit("var %s int", s.var_("size"))
} }
...@@ -704,10 +815,10 @@ func (s *sizer) generatedCode() string { ...@@ -704,10 +815,10 @@ func (s *sizer) generatedCode() string {
return code.String() return code.String()
} }
func (e *encoder) generatedCode() string { func (e *encoderCommon) generatedCode() string {
code := Buffer{} code := Buffer{}
// prologue // prologue
code.emit("func (%s *%s) NEOMsgEncode(data []byte) {", e.recvName, e.typeName) code.emit("func (%s *%s) neoMsgEncode%c(data []byte) {", e.recvName, e.typeName, e.enc)
code.Write(e.buf.Bytes()) code.Write(e.buf.Bytes())
...@@ -719,7 +830,7 @@ func (e *encoder) generatedCode() string { ...@@ -719,7 +830,7 @@ func (e *encoder) generatedCode() string {
// data = data[n:] // data = data[n:]
// n = 0 // n = 0
func (d *decoder) resetPos() { func (d *decoderCommon) resetPos() {
if d.n != 0 { if d.n != 0 {
d.emit("data = data[%v:]", d.n) d.emit("data = data[%v:]", d.n)
d.n = 0 d.n = 0
...@@ -743,7 +854,7 @@ func (d *decoder) resetPos() { ...@@ -743,7 +854,7 @@ func (d *decoder) resetPos() {
// - before reading a variable sized item // - before reading a variable sized item
// - in the beginning of a loop inside (via overflowCheckLoopEntry) // - in the beginning of a loop inside (via overflowCheckLoopEntry)
// - right after loop exit (via overflowCheckLoopExit) // - right after loop exit (via overflowCheckLoopExit)
func (d *decoder) overflowCheck() { func (d *decoderCommon) overflowCheck() {
// nop if we know overflow was already checked // nop if we know overflow was already checked
if d.overflow.checked { if d.overflow.checked {
return return
...@@ -781,7 +892,7 @@ func (d *decoder) overflowCheck() { ...@@ -781,7 +892,7 @@ func (d *decoder) overflowCheck() {
} }
// overflowCheck variant that should be inserted at the beginning of a loop inside // overflowCheck variant that should be inserted at the beginning of a loop inside
func (d *decoder) overflowCheckLoopEntry() { func (d *decoderCommon) overflowCheckLoopEntry() {
if d.overflow.checked { if d.overflow.checked {
return return
} }
...@@ -795,7 +906,7 @@ func (d *decoder) overflowCheckLoopEntry() { ...@@ -795,7 +906,7 @@ func (d *decoder) overflowCheckLoopEntry() {
} }
// overflowCheck variant that should be inserted right after loop exit // overflowCheck variant that should be inserted right after loop exit
func (d *decoder) overflowCheckLoopExit(loopLenExpr string) { func (d *decoderCommon) overflowCheckLoopExit(loopLenExpr string) {
if d.overflow.checked { if d.overflow.checked {
return return
} }
...@@ -813,13 +924,13 @@ func (d *decoder) overflowCheckLoopExit(loopLenExpr string) { ...@@ -813,13 +924,13 @@ func (d *decoder) overflowCheckLoopExit(loopLenExpr string) {
func (d *decoder) generatedCode() string { func (d *decoderCommon) generatedCode() string {
// flush for last overflow check point // flush for last overflow check point
d.overflowCheck() d.overflowCheck()
code := Buffer{} code := Buffer{}
// prologue // prologue
code.emit("func (%s *%s) NEOMsgDecode(data []byte) (int, error) {", d.recvName, d.typeName) code.emit("func (%s *%s) neoMsgDecode%c(data []byte) (int, error) {", d.recvName, d.typeName, d.enc)
if d.varUsed["nread"] { if d.varUsed["nread"] {
code.emit("var %v uint64", d.var_("nread")) code.emit("var %v uint64", d.var_("nread"))
} }
...@@ -827,6 +938,7 @@ func (d *decoder) generatedCode() string { ...@@ -827,6 +938,7 @@ func (d *decoder) generatedCode() string {
code.Write(d.bufDone.Bytes()) code.Write(d.bufDone.Bytes())
// epilogue // epilogue
// XXX M: return `n + (len0 - len(data))` without nread updates after every decode
retexpr := fmt.Sprintf("%v", d.nread) retexpr := fmt.Sprintf("%v", d.nread)
if d.varUsed["nread"] { if d.varUsed["nread"] {
// casting nread to int is ok even on 32 bit arches: // casting nread to int is ok even on 32 bit arches:
...@@ -839,7 +951,7 @@ func (d *decoder) generatedCode() string { ...@@ -839,7 +951,7 @@ func (d *decoder) generatedCode() string {
// `goto overflow` is not used only for empty structs // `goto overflow` is not used only for empty structs
// NOTE for >0 check actual X in StdSizes{X} does not particularly matter // NOTE for >0 check actual X in StdSizes{X} does not particularly matter
if (&types.StdSizes{8, 8}).Sizeof(d.typ) > 0 { if (&types.StdSizes{8, 8}).Sizeof(d.typ) > 0 || d.enc != 'N' {
code.emit("\noverflow:") code.emit("\noverflow:")
code.emit("return 0, ErrDecodeOverflow") code.emit("return 0, ErrDecodeOverflow")
} }
...@@ -848,14 +960,16 @@ func (d *decoder) generatedCode() string { ...@@ -848,14 +960,16 @@ func (d *decoder) generatedCode() string {
return code.String() return code.String()
} }
// emit code to size/encode/decode basic fixed type // ---- basic types ----
func (s *sizer) genBasic(path string, typ *types.Basic, userType types.Type) {
basic := basicTypes[typ.Kind()] // N: emit code to size/encode/decode basic fixed type
func (s *sizerN) genBasic(path string, typ *types.Basic, userType types.Type) {
basic := basicTypesN[typ.Kind()]
s.size.Add(basic.wireSize) s.size.Add(basic.wireSize)
} }
func (e *encoder) genBasic(path string, typ *types.Basic, userType types.Type) { func (e *encoderN) genBasic(path string, typ *types.Basic, userType types.Type) {
basic := basicTypes[typ.Kind()] basic := basicTypesN[typ.Kind()]
dataptr := fmt.Sprintf("data[%v:]", e.n) dataptr := fmt.Sprintf("data[%v:]", e.n)
if userType != typ && userType != nil { if userType != typ && userType != nil {
// userType is a named type over some basic, like // userType is a named type over some basic, like
...@@ -867,8 +981,8 @@ func (e *encoder) genBasic(path string, typ *types.Basic, userType types.Type) { ...@@ -867,8 +981,8 @@ func (e *encoder) genBasic(path string, typ *types.Basic, userType types.Type) {
e.n += basic.wireSize e.n += basic.wireSize
} }
func (d *decoder) genBasic(assignto string, typ *types.Basic, userType types.Type) { func (d *decoderN) genBasic(assignto string, typ *types.Basic, userType types.Type) {
basic := basicTypes[typ.Kind()] basic := basicTypesN[typ.Kind()]
// XXX specifying :hi is not needed - it is only a workaround to help BCE. // XXX specifying :hi is not needed - it is only a workaround to help BCE.
// see https://github.com/golang/go/issues/19126#issuecomment-358743715 // see https://github.com/golang/go/issues/19126#issuecomment-358743715
...@@ -887,33 +1001,311 @@ func (d *decoder) genBasic(assignto string, typ *types.Basic, userType types.Typ ...@@ -887,33 +1001,311 @@ func (d *decoder) genBasic(assignto string, typ *types.Basic, userType types.Typ
d.overflow.Add(basic.wireSize) d.overflow.Add(basic.wireSize)
} }
// M: XXX
func (s *sizerM) genBasic(path string, typ *types.Basic, userType types.Type) {
// upath casts path into basic type if needed
// e.g. p.x -> int32(p.x) if p.x is custom type with underlying int32
upath := path
if userType.Underlying() != userType {
upath = fmt.Sprintf("%s(%s)", typ.Name(), upath)
}
// zodb.Tid and zodb.Oid are encoded as [8]bin XXX or nil?
if userType == zodbTid || userType == zodbOid {
s.size.Add(1+1+8) // mbin8 + 8 + [8]data
return
}
// enums are encoded as extensions
if _, isEnum := enumRegistry[userType]; isEnum {
s.size.Add(1+1+1) // fixext1 enumType value
return
}
switch typ.Kind() {
case types.Bool: s.size.Add(1) // mfalse|mtrue
case types.Int8: s.size.AddExpr("msgpack.Int8Size(%s)", upath)
case types.Int16: s.size.AddExpr("msgpack.Int16Size(%s)", upath)
case types.Int32: s.size.AddExpr("msgpack.Int32Size(%s)", upath)
case types.Int64: s.size.AddExpr("msgpack.Int64Size(%s)", upath)
case types.Uint8: s.size.AddExpr("msgpack.Uint8Size(%s)", upath)
case types.Uint16: s.size.AddExpr("msgpack.Uint16Size(%s)", upath)
case types.Uint32: s.size.AddExpr("msgpack.Uint32Size(%s)", upath)
case types.Uint64: s.size.AddExpr("msgpack.Uint64Size(%s)", upath)
case types.Float64: s.size.Add(1+8) // mfloat64 + <value64>
}
}
func (e *encoderM) genBasic(path string, typ *types.Basic, userType types.Type) {
// upath casts path into basic type if needed
// e.g. p.x -> int32(p.x) if p.x is custom type with underlying int32
// XXX dup?
upath := path
if userType.Underlying() != userType {
upath = fmt.Sprintf("%s(%s)", typ.Name(), upath)
}
// zodb.Tid and zodb.Oid are encoded as [8]bin XXX or nil ?
if userType == zodbTid || userType == zodbOid {
e.emit("data[%v] = byte(msgpack.Bin8)", e.n); e.n++
e.emit("data[%v] = 8", e.n); e.n++
e.emit("binary.BigEndian.PutUint64(data[%v:], uint64(%s))", e.n, path)
e.n += 8
return
}
// enums are encoded as `fixext1 enumType fixint<value>`
if enum, ok := enumRegistry[userType]; ok {
e.emit("data[%v] = byte(msgpack.FixExt1)", e.n); e.n++
e.emit("data[%v] = %d", e.n, enum); e.n++
e.emit("if !(0 <= %s && %s <= 0x7f) {", path, path) // mposfixint
e.emit(` panic("%s: invalid %s enum value)")`, path, typeName(userType))
e.emit("}")
e.emit("data[%v] = byte(%s)", e.n, path); e.n++
return
}
// mputint emits mput<kind>int<size>(path)
mputint := func(kind string, size int) {
KI := "I" // I or <Kind>i
if kind != "" {
KI = strings.ToUpper(kind) + "i"
}
e.emit("{")
e.emit("n := msgpack.Put%snt%d(data[%v:], %s)", KI, size, e.n, upath)
e.emit("data = data[%v+n:]", e.n)
e.emit("}")
e.n = 0
}
switch typ.Kind() {
case types.Bool:
e.emit("data[%v] = byte(msgpack.Bool(%s))", e.n, path)
e.n += 1
case types.Int8: mputint("", 8)
case types.Int16: mputint("", 16)
case types.Int32: mputint("", 32)
case types.Int64: mputint("", 64)
case types.Uint8: mputint("u", 8)
case types.Uint16: mputint("u", 16)
case types.Uint32: mputint("u", 32)
case types.Uint64: mputint("u", 64)
case types.Float64:
// mfloat64 f64
e.emit("data[%v] = byte(msgpack.Float64)", e.n); e.n++
e.emit("float64_neoEncode(data[%v:], %s)", e.n, upath); e.n += 8
}
}
// decoder expects <op>
// XXX place
func (d *decoderM) expectOp(assignto string, op string) {
d.emit("if op := msgpack.Op(data[%v]); op != %s {", d.n, op); d.n++
d.emit(" return 0, mdecodeOpErr(%q, op, %s)", d.pathName(assignto), op)
d.emit("}")
d.overflow.Add(1)
}
// decoder expects mbin8 l
func (d *decoderM) expectBin8Fix(assignto string, l int) {
d.expectOp(assignto, "msgpack.Bin8")
d.emit("if l := data[%v]; l != %d {", d.n, l); d.n++
d.emit(" return 0, mdecodeLen8Err(%q, l, %d)", d.pathName(assignto), l)
d.emit("}")
d.overflow.Add(1)
}
// decoder expects mfixext1 <enumType>
func (d *decoderM) expectEnum(assignto string, enumType int) {
d.expectOp(assignto, "msgpack.FixExt1")
d.emit("if enumType := data[%v]; enumType != %d {", d.n, enumType); d.n++
d.emit(" return 0, mdecodeEnumTypeErr(%q, enumType, %d)", d.pathName(assignto), enumType)
d.emit("}")
d.overflow.Add(1)
}
func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Type) {
// zodb.Tid and zodb.Oid are encoded as [8]bin
if userType == zodbTid || userType == zodbOid {
d.expectBin8Fix(assignto, 8)
d.emit("%s= %s(binary.BigEndian.Uint64(data[%v:]))", assignto, typeName(userType), d.n)
d.n += 8
d.overflow.Add(8)
return
}
// enums are encoded as `fixext1 enumType fixint<value>`
if enum, ok := enumRegistry[userType]; ok {
d.expectEnum(assignto, enum)
d.emit("{")
d.emit("v := data[%v]", d.n); d.n++
d.emit("if !(0 <= v && v <= 0x7f) {") // mposfixint
d.emit(" return 0, mdecodeEnumValueErr(%q, v)", d.pathName(assignto))
d.emit("}")
d.emit("%s= %s(v)", assignto, typeName(userType))
d.emit("}")
d.overflow.Add(1)
return
}
// v represents basic decoded value casted to user type if needed
v := "v"
if userType.Underlying() != userType {
v = fmt.Sprintf("%s(v)", typeName(userType))
}
// mgetint emits assignto = mget<kind>int<size>()
mgetint := func(kind string, size int) {
// we are going to go into msgp - flush previously queued
// overflow checks; put place for next overflow check after
// msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
KI := "I" // I or <Kind>i
if kind != "" {
KI = strings.ToUpper(kind) + "i"
}
d.emit("{")
d.emit("v, tail, err := msgp.Read%snt%dBytes(data)", KI, size)
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
}
// mgetfloat emits mgetfloat<size>
mgetfloat := func(size int) {
// delving into msgp - flush/prepare next site for overflow check
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("{")
d.emit("v, tail, err := msgp.ReadFloat%dBytes(data)", size)
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
}
switch typ.Kind() {
case types.Bool:
// XXX move -> mgetbool ?
d.emit("switch op := msgpack.Op(data[%v]); op {", d.n)
// XXX vvv False also ok
d.emit("default: return 0, mdecodeOpErr(%q, op, msgpack.True)", d.pathName(assignto))
d.emit("case msgpack.True: %s = true", assignto)
d.emit("case msgpack.False: %s = false", assignto)
// XXX also support 0|1 ?
d.emit("}")
d.n++
d.overflow.Add(1)
/*
// -> msgp - flush queued overflow checks; put place for next
// overflow checks after msgp is done.
// XXX better directly compare against mtrue|mfalse|0|1 ?
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("{")
d.emit("v, tail, err := msgp.ReadBoolBytes(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
*/
case types.Int8: mgetint("", 8)
case types.Int16: mgetint("", 16)
case types.Int32: mgetint("", 32)
case types.Int64: mgetint("", 64)
case types.Uint8: mgetint("u", 8)
case types.Uint16: mgetint("u", 16)
case types.Uint32: mgetint("u", 32)
case types.Uint64: mgetint("u", 64)
case types.Float64: mgetfloat(64)
}
}
// emit code to size/encode/decode array with sizeof(elem)==1 // emit code to size/encode/decode array with sizeof(elem)==1
// [len(A)]byte // [len(A)]byte
func (s *sizer) genArray1(path string, typ *types.Array) { func (s *sizerN) genArray1(path string, typ *types.Array) {
s.size.Add(int(typ.Len())) s.size.Add(int(typ.Len()))
} }
func (e *encoder) genArray1(path string, typ *types.Array) { func (e *encoderN) genArray1(path string, typ *types.Array) {
e.emit("copy(data[%v:], %v[:])", e.n, path) e.emit("copy(data[%v:], %v[:])", e.n, path)
e.n += int(typ.Len()) e.n += int(typ.Len())
} }
func (d *decoder) genArray1(assignto string, typ *types.Array) { func (d *decoderN) genArray1(assignto string, typ *types.Array) {
typLen := int(typ.Len()) typLen := int(typ.Len())
d.emit("copy(%v[:], data[%v:%v])", assignto, d.n, d.n+typLen) d.emit("copy(%v[:], data[%v:%v])", assignto, d.n, d.n+typLen)
d.n += typLen d.n += typLen
d.overflow.Add(typLen) d.overflow.Add(typLen)
} }
// binX+lenX
// [len(A)]byte
func (s *sizerM) genArray1(path string, typ *types.Array) {
l := int(typ.Len())
s.size.Add(msgpack.BinHeadSize(l))
s.size.Add(l)
}
func (e *encoderM) genArray1(path string, typ *types.Array) {
l := int(typ.Len())
if l > 0xff {
panic("TODO: array1 with > 255 elements")
}
e.emit("data[%v] = byte(msgpack.Bin8)", e.n); e.n++
e.emit("data[%v] = %d", e.n, l); e.n++
e.emit("copy(data[%v:], %v[:])", e.n, path)
e.n += l
}
func (d *decoderM) genArray1(assignto string, typ *types.Array) {
l := int(typ.Len())
if l > 0xff {
panic("TODO: array1 with > 255 elements")
}
d.expectBin8Fix(assignto, l)
d.emit("copy(%v[:], data[%v:%v])", assignto, d.n, d.n+l)
d.n += l
d.overflow.Add(l)
}
// emit code to size/encode/decode string or []byte // emit code to size/encode/decode string or []byte
// len u32 // len u32
// [len]byte // [len]byte
func (s *sizer) genSlice1(path string, typ types.Type) { func (s *sizerN) genSlice1(path string, typ types.Type) {
s.size.Add(4) s.size.Add(4)
s.size.AddExpr("len(%s)", path) s.size.AddExpr("len(%s)", path)
} }
func (e *encoder) genSlice1(path string, typ types.Type) { func (e *encoderN) genSlice1(path string, typ types.Type) {
e.emit("{") e.emit("{")
e.emit("l := uint32(len(%s))", path) e.emit("l := uint32(len(%s))", path)
e.genBasic("l", types.Typ[types.Uint32], nil) e.genBasic("l", types.Typ[types.Uint32], nil)
...@@ -924,7 +1316,7 @@ func (e *encoder) genSlice1(path string, typ types.Type) { ...@@ -924,7 +1316,7 @@ func (e *encoder) genSlice1(path string, typ types.Type) {
e.n = 0 e.n = 0
} }
func (d *decoder) genSlice1(assignto string, typ types.Type) { func (d *decoderN) genSlice1(assignto string, typ types.Type) {
d.emit("{") d.emit("{")
d.genBasic("l:", types.Typ[types.Uint32], nil) d.genBasic("l:", types.Typ[types.Uint32], nil)
...@@ -953,17 +1345,75 @@ func (d *decoder) genSlice1(assignto string, typ types.Type) { ...@@ -953,17 +1345,75 @@ func (d *decoder) genSlice1(assignto string, typ types.Type) {
d.emit("}") d.emit("}")
} }
// bin8+len8|bin16+len16|bin32+len32
// [len]byte
func (s *sizerM) genSlice1(path string, typ types.Type) {
// XXX -> mbinsize(len(path)) ?
s.size.AddExpr("msgpack.BinHeadSize(len(%s))", path)
s.size.AddExpr("len(%s)", path)
}
func (e *encoderM) genSlice1(path string, typ types.Type) {
e.emit("{")
e.emit("l := len(%s)", path)
e.emit("n := msgpack.PutBinHead(data[%v:], l)", e.n)
e.emit("data = data[%v+n:]", e.n)
e.emit("copy(data, %v)", path)
e.emit("data = data[l:]")
e.emit("}")
e.n = 0
}
func (d *decoderM) genSlice1(assignto string, typ types.Type) {
// -> msgp: flush queued overflow checks; put place for next overflow
// checks after msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("{")
d.emit("b, tail, err := msgp.ReadBytesZC(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
// XXX dup wrt decoderN ?
switch t := typ.(type) {
case *types.Basic:
if t.Kind() != types.String {
log.Panicf("bad basic type in slice1: %v", t)
}
d.emit("%v= string(b)", assignto)
case *types.Slice:
// TODO eventually do not copy, but reference data from original
d.emit("%v= make(%v, len(b))", assignto, typeName(typ))
d.emit("copy(%v, b)", assignto)
default:
log.Panicf("bad type in slice1: %v", typ)
}
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
}
// emit code to size/encode/decode mem.Buf // emit code to size/encode/decode mem.Buf
// same as slice1 but buffer is allocated via mem.BufAlloc // same as slice1 but buffer is allocated via mem.BufAlloc
func (s *sizer) genBuf(path string) { func (s *sizerN) genBuf(path string) {
s.genSlice1(path+".XData()", nil /* typ unused */)
}
func (s *sizerM) genBuf(path string) {
s.genSlice1(path+".XData()", nil /* typ unused */) s.genSlice1(path+".XData()", nil /* typ unused */)
} }
func (e *encoder) genBuf(path string) { func (e *encoderN) genBuf(path string) {
e.genSlice1(path+".XData()", nil /* typ unused */)
}
func (e *encoderM) genBuf(path string) {
e.genSlice1(path+".XData()", nil /* typ unused */) e.genSlice1(path+".XData()", nil /* typ unused */)
} }
func (d *decoder) genBuf(path string) { func (d *decoderN) genBuf(assignto string) {
d.emit("{") d.emit("{")
d.genBasic("l:", types.Typ[types.Uint32], nil) d.genBasic("l:", types.Typ[types.Uint32], nil)
...@@ -973,21 +1423,54 @@ func (d *decoder) genBuf(path string) { ...@@ -973,21 +1423,54 @@ func (d *decoder) genBuf(path string) {
d.overflow.AddExpr("uint64(l)") d.overflow.AddExpr("uint64(l)")
// TODO eventually do not copy but reference original // TODO eventually do not copy but reference original
d.emit("%v= mem.BufAlloc(int(l))", path) d.emit("%v= mem.BufAlloc(int(l))", assignto)
d.emit("copy(%v.Data, data[:l])", path) d.emit("copy(%v.Data, data[:l])", assignto)
d.emit("data = data[l:]") d.emit("data = data[l:]")
d.emit("}") d.emit("}")
} }
func (d *decoderM) genBuf(assignto string) {
// XXX dup wrt decoderM.genSlice1 ?
// -> msgp: flush queued overflow checks; put place for next overflow
// checks after msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("{")
d.emit("b, tail, err := msgp.ReadBytesZC(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
// XXX dup wrt decoderN.genBuf ?
// TODO eventually do not copy but reference original
d.emit("%v= mem.BufAlloc(len(b))", assignto)
d.emit("copy(%v.Data, b)", assignto)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
}
// emit code to size/encode/decode slice // emit code to size/encode/decode slice
// len u32 // len u32
// [len]item // [len]item
func (s *sizer) genSlice(path string, typ *types.Slice, obj types.Object) { func (s *sizerN) genSliceHead(path string, typ *types.Slice, obj types.Object) {
s.size.Add(4) s.size.Add(4)
}
func (s *sizerN) genSlice(path string, typ *types.Slice, obj types.Object) {
s.genSliceCommon(s, path, typ, obj)
}
func (s *sizerCommon) genSliceCommon(xs CodeGenCustomize, path string, typ *types.Slice, obj types.Object) {
xs.genSliceHead(path, typ, obj)
// if size(item)==const - size update in one go // if size(item)==const - size update in one go
elemSize, ok := typeSizeFixed(typ.Elem()) elemSize, ok := typeEncodingSizeFixed(xs.encoding(), typ.Elem())
if ok { if ok {
s.size.AddExpr("len(%v) * %v", path, elemSize) s.size.AddExpr("len(%v) * %v", path, elemSize)
return return
...@@ -999,7 +1482,7 @@ func (s *sizer) genSlice(path string, typ *types.Slice, obj types.Object) { ...@@ -999,7 +1482,7 @@ func (s *sizer) genSlice(path string, typ *types.Slice, obj types.Object) {
s.emit("for i := 0; i < len(%v); i++ {", path) s.emit("for i := 0; i < len(%v); i++ {", path)
s.emit("a := &%s[i]", path) s.emit("a := &%s[i]", path)
codegenType("(*a)", typ.Elem(), obj, s) codegenType("(*a)", typ.Elem(), obj, xs)
// merge-in size updates // merge-in size updates
s.emit("%v += %v", s.var_("size"), s.size.ExprString()) s.emit("%v += %v", s.var_("size"), s.size.ExprString())
...@@ -1010,29 +1493,45 @@ func (s *sizer) genSlice(path string, typ *types.Slice, obj types.Object) { ...@@ -1010,29 +1493,45 @@ func (s *sizer) genSlice(path string, typ *types.Slice, obj types.Object) {
s.size = curSize s.size = curSize
} }
func (e *encoder) genSlice(path string, typ *types.Slice, obj types.Object) { func (e *encoderN) genSliceHead(path string, typ *types.Slice, obj types.Object) {
e.emit("{") e.emit("l := len(%s)", path)
e.emit("l := uint32(len(%s))", path) e.genBasic("l", types.Typ[types.Uint32], types.Typ[types.Int])
e.genBasic("l", types.Typ[types.Uint32], nil)
e.emit("data = data[%v:]", e.n) e.emit("data = data[%v:]", e.n)
e.n = 0 e.n = 0
e.emit("for i := 0; uint32(i) <l; i++ {") }
func (e *encoderN) genSlice(path string, typ *types.Slice, obj types.Object) {
e.genSliceCommon(e, path, typ, obj)
}
func (e *encoderCommon) genSliceCommon(xe CodeGenCustomize, path string, typ *types.Slice, obj types.Object) {
e.emit("{")
xe.genSliceHead(path, typ, obj)
e.emit("for i := 0; i <l; i++ {")
e.emit("a := &%s[i]", path) e.emit("a := &%s[i]", path)
codegenType("(*a)", typ.Elem(), obj, e) codegenType("(*a)", typ.Elem(), obj, xe)
e.emit("data = data[%v:]", e.n) if e.n != 0 {
e.emit("data = data[%v:]", e.n)
e.n = 0
}
e.emit("}") e.emit("}")
e.emit("}") e.emit("}")
e.n = 0
} }
func (d *decoder) genSlice(assignto string, typ *types.Slice, obj types.Object) { func (d *decoderN) genSliceHead(assignto string, typ *types.Slice, obj types.Object) {
d.emit("{")
d.genBasic("l:", types.Typ[types.Uint32], nil) d.genBasic("l:", types.Typ[types.Uint32], nil)
}
func (d *decoderN) genSlice(assignto string, typ *types.Slice, obj types.Object) {
d.genSliceCommon(d, assignto, typ, obj)
}
func (d *decoderCommon) genSliceCommon(xd CodeGenCustomize, assignto string, typ *types.Slice, obj types.Object) {
d.emit("{")
xd.genSliceHead(assignto, typ, obj)
d.resetPos() d.resetPos()
// if size(item)==const - check overflow in one go // if size(item)==const - check overflow in one go
elemSize, elemFixed := typeSizeFixed(typ.Elem()) elemSize, elemFixed := typeDecodingSizeFixed(xd.encoding(), typ.Elem())
if elemFixed { if elemFixed {
d.overflowCheck() d.overflowCheck()
d.overflow.AddExpr("uint64(l) * %v", elemSize) d.overflow.AddExpr("uint64(l) * %v", elemSize)
...@@ -1045,7 +1544,7 @@ func (d *decoder) genSlice(assignto string, typ *types.Slice, obj types.Object) ...@@ -1045,7 +1544,7 @@ func (d *decoder) genSlice(assignto string, typ *types.Slice, obj types.Object)
d.emit("a := &%s[i]", assignto) d.emit("a := &%s[i]", assignto)
d.overflowCheckLoopEntry() d.overflowCheckLoopEntry()
codegenType("(*a)", typ.Elem(), obj, d) codegenType("(*a)", typ.Elem(), obj, xd)
d.resetPos() d.resetPos()
d.emit("}") d.emit("}")
...@@ -1054,27 +1553,72 @@ func (d *decoder) genSlice(assignto string, typ *types.Slice, obj types.Object) ...@@ -1054,27 +1553,72 @@ func (d *decoder) genSlice(assignto string, typ *types.Slice, obj types.Object)
d.emit("}") d.emit("}")
} }
// fixarray|array16+YYYY|array32+ZZZZ
// [len]item
func (s *sizerM) genSliceHead(path string, typ *types.Slice, obj types.Object) {
s.size.AddExpr("msgpack.ArrayHeadSize(len(%s))", path)
}
func (s *sizerM) genSlice(path string, typ *types.Slice, obj types.Object) {
s.genSliceCommon(s, path, typ, obj)
}
func (e *encoderM) genSliceHead(path string, typ *types.Slice, obj types.Object) {
e.emit("l := len(%s)", path)
e.emit("n := msgpack.PutArrayHead(data[%v:], l)", e.n)
e.emit("data = data[%v+n:]", e.n)
e.n = 0
}
func (e *encoderM) genSlice(path string, typ *types.Slice, obj types.Object) {
e.genSliceCommon(e, path, typ, obj)
}
func (d *decoderM) genSliceHead(assignto string, typ *types.Slice, obj types.Object) {
// -> msgp: flush queued overflow checks; put place for next overflow
// checks after msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("l, tail, err := msgp.ReadArrayHeaderBytes(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
}
func (d *decoderM) genSlice(assignto string, typ *types.Slice, obj types.Object) {
d.genSliceCommon(d, assignto, typ, obj)
}
// generate code to encode/decode map // generate code to encode/decode map
// len u32 // len u32
// [len](key, value) // [len](key, value)
func (s *sizer) genMap(path string, typ *types.Map, obj types.Object) { func (s *sizerN) genMapHead(path string, typ *types.Map, obj types.Object) {
keySize, keyFixed := typeSizeFixed(typ.Key()) s.size.Add(4)
elemSize, elemFixed := typeSizeFixed(typ.Elem()) }
func (s *sizerN) genMap(path string, typ *types.Map, obj types.Object) {
s.genMapCommon(s, path, typ, obj)
}
func (s *sizerCommon) genMapCommon(xs CodeGenCustomize, path string, typ *types.Map, obj types.Object) {
xs.genMapHead(path, typ, obj)
keySize, keyFixed := typeEncodingSizeFixed(xs.encoding(), typ.Key())
elemSize, elemFixed := typeEncodingSizeFixed(xs.encoding(), typ.Elem())
if keyFixed && elemFixed { if keyFixed && elemFixed {
s.size.Add(4)
s.size.AddExpr("len(%v) * %v", path, keySize+elemSize) s.size.AddExpr("len(%v) * %v", path, keySize+elemSize)
return return
} }
s.size.Add(4)
curSize := s.size curSize := s.size
s.size.Reset() s.size.Reset()
// FIXME for map of map gives ...[key][key] => key -> different variables // FIXME for map of map gives ...[key][key] => key -> different variables
s.emit("for key := range %s {", path) s.emit("for key := range %s {", path)
codegenType("key", typ.Key(), obj, s) codegenType("key", typ.Key(), obj, xs)
codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, s) codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, xs)
// merge-in size updates // merge-in size updates
s.emit("%v += %v", s.var_("size"), s.size.ExprString()) s.emit("%v += %v", s.var_("size"), s.size.ExprString())
...@@ -1085,12 +1629,19 @@ func (s *sizer) genMap(path string, typ *types.Map, obj types.Object) { ...@@ -1085,12 +1629,19 @@ func (s *sizer) genMap(path string, typ *types.Map, obj types.Object) {
s.size = curSize s.size = curSize
} }
func (e *encoder) genMap(path string, typ *types.Map, obj types.Object) { func (e *encoderN) genMapHead(path string, typ *types.Map, obj types.Object) {
e.emit("{") e.emit("l := len(%s)", path)
e.emit("l := uint32(len(%s))", path) e.genBasic("l", types.Typ[types.Uint32], types.Typ[types.Int])
e.genBasic("l", types.Typ[types.Uint32], nil)
e.emit("data = data[%v:]", e.n) e.emit("data = data[%v:]", e.n)
e.n = 0 e.n = 0
}
func (e *encoderN) genMap(path string, typ *types.Map, obj types.Object) {
e.genMapCommon(e, path, typ, obj)
}
func (e *encoderCommon) genMapCommon(xe CodeGenCustomize, path string, typ *types.Map, obj types.Object) {
e.emit("{")
xe.genMapHead(path, typ, obj)
// output keys in sorted order on the wire // output keys in sorted order on the wire
// (easier for debugging & deterministic for testing) // (easier for debugging & deterministic for testing)
...@@ -1101,23 +1652,32 @@ func (e *encoder) genMap(path string, typ *types.Map, obj types.Object) { ...@@ -1101,23 +1652,32 @@ func (e *encoder) genMap(path string, typ *types.Map, obj types.Object) {
e.emit("sort.Slice(keyv, func (i, j int) bool { return keyv[i] < keyv[j] })") e.emit("sort.Slice(keyv, func (i, j int) bool { return keyv[i] < keyv[j] })")
e.emit("for _, key := range keyv {") e.emit("for _, key := range keyv {")
codegenType("key", typ.Key(), obj, e) codegenType("key", typ.Key(), obj, xe)
codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, e) codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, xe)
e.emit("data = data[%v:]", e.n) // XXX wrt map of map? if e.n != 0 {
e.emit("data = data[%v:]", e.n) // XXX wrt map of map?
e.n = 0
}
e.emit("}") e.emit("}")
e.emit("}") e.emit("}")
e.n = 0
} }
func (d *decoder) genMap(assignto string, typ *types.Map, obj types.Object) { func (d *decoderN) genMapHead(assignto string, typ *types.Map, obj types.Object) {
d.emit("{")
d.genBasic("l:", types.Typ[types.Uint32], nil) d.genBasic("l:", types.Typ[types.Uint32], nil)
}
func (d *decoderN) genMap(assignto string, typ *types.Map, obj types.Object) {
d.genMapCommon(d, assignto, typ, obj)
}
func (d *decoderCommon) genMapCommon(xd CodeGenCustomize, assignto string, typ *types.Map, obj types.Object) {
d.emit("{")
xd.genMapHead(assignto, typ, obj)
d.resetPos() d.resetPos()
// if size(key,item)==const - check overflow in one go // if size(key,item)==const - check overflow in one go
keySize, keyFixed := typeSizeFixed(typ.Key()) keySize, keyFixed := typeDecodingSizeFixed(xd.encoding(), typ.Key())
elemSize, elemFixed := typeSizeFixed(typ.Elem()) elemSize, elemFixed := typeDecodingSizeFixed(xd.encoding(), typ.Elem())
if keyFixed && elemFixed { if keyFixed && elemFixed {
d.overflowCheck() d.overflowCheck()
d.overflow.AddExpr("uint64(l) * %v", keySize+elemSize) d.overflow.AddExpr("uint64(l) * %v", keySize+elemSize)
...@@ -1130,18 +1690,19 @@ func (d *decoder) genMap(assignto string, typ *types.Map, obj types.Object) { ...@@ -1130,18 +1690,19 @@ func (d *decoder) genMap(assignto string, typ *types.Map, obj types.Object) {
d.emit("for i := 0; uint32(i) < l; i++ {") d.emit("for i := 0; uint32(i) < l; i++ {")
d.overflowCheckLoopEntry() d.overflowCheckLoopEntry()
codegenType("key:", typ.Key(), obj, d) d.emit("var key %s", typeName(typ.Key()))
codegenType("key", typ.Key(), obj, xd)
switch typ.Elem().Underlying().(type) { switch typ.Elem().Underlying().(type) {
// basic types can be directly assigned to map entry // basic types can be directly assigned to map entry
case *types.Basic: case *types.Basic:
codegenType("m[key]", typ.Elem(), obj, d) codegenType("m[key]", typ.Elem(), obj, xd)
// otherwise assign via temporary // otherwise assign via temporary
default: default:
d.emit("var v %v", typeName(typ.Elem())) d.emit("var mv %v", typeName(typ.Elem()))
codegenType("v", typ.Elem(), obj, d) codegenType("mv", typ.Elem(), obj, xd)
d.emit("m[key] = v") d.emit("m[key] = mv")
} }
d.resetPos() d.resetPos()
...@@ -1151,27 +1712,65 @@ func (d *decoder) genMap(assignto string, typ *types.Map, obj types.Object) { ...@@ -1151,27 +1712,65 @@ func (d *decoder) genMap(assignto string, typ *types.Map, obj types.Object) {
d.emit("}") d.emit("}")
} }
// fixmap|map16+YYYY|map32+ZZZZ
// [len]key/value
func (s *sizerM) genMapHead(path string, typ *types.Map, obj types.Object) {
s.size.AddExpr("msgpack.MapHeadSize(len(%s))", path)
}
func (s *sizerM) genMap(path string, typ *types.Map, obj types.Object) {
s.genMapCommon(s, path, typ, obj)
}
func (e *encoderM) genMapHead(path string, typ *types.Map, obj types.Object) {
e.emit("l := len(%s)", path)
e.emit("n := msgpack.PutMapHead(data[%v:], l)", e.n)
e.emit("data = data[%v+n:]", e.n)
e.n = 0
}
func (e *encoderM) genMap(path string, typ *types.Map, obj types.Object) {
e.genMapCommon(e, path, typ, obj)
}
func (d *decoderM) genMapHead(assignto string, typ *types.Map, obj types.Object) {
// -> msgp: flush queued overflow checks; put place for next overflow
// checks after msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("l, tail, err := msgp.ReadMapHeaderBytes(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
}
func (d *decoderM) genMap(assignto string, typ *types.Map, obj types.Object) {
d.genMapCommon(d, assignto, typ, obj)
}
// emit code to size/encode/decode custom type // emit code to size/encode/decode custom type
func (s *sizer) genCustom(path string) { func (s *sizerN) genCustomN(path string) {
s.size.AddExpr("%s.neoEncodedLen()", path) s.size.AddExpr("%s.neoEncodedLenN()", path)
} }
func (e *encoder) genCustom(path string) { func (e *encoderN) genCustomN(path string) {
e.emit("{") e.emit("{")
e.emit("n := %s.neoEncode(data[%v:])", path, e.n) e.emit("n := %s.neoEncodeN(data[%v:])", path, e.n)
e.emit("data = data[%v + n:]", e.n) e.emit("data = data[%v + n:]", e.n)
e.emit("}") e.emit("}")
e.n = 0 e.n = 0
} }
func (d *decoder) genCustom(path string) { func (d *decoderN) genCustomN(path string) {
d.resetPos() d.resetPos()
// make sure we check for overflow previous-code before proceeding to custom decoder. // make sure we check for overflow previous-code before proceeding to custom decoder.
d.overflowCheck() d.overflowCheck()
d.emit("{") d.emit("{")
d.emit("n, ok := %s.neoDecode(data)", path) d.emit("n, ok := %s.neoDecodeN(data)", path)
d.emit("if !ok { goto overflow }") d.emit("if !ok { goto overflow }")
d.emit("data = data[n:]") d.emit("data = data[n:]")
d.emit("%v += n", d.var_("nread")) d.emit("%v += n", d.var_("nread"))
...@@ -1182,15 +1781,53 @@ func (d *decoder) genCustom(path string) { ...@@ -1182,15 +1781,53 @@ func (d *decoder) genCustom(path string) {
d.overflowCheck() d.overflowCheck()
} }
// ---- struct head ----
// N: nothing
func (s *sizerN) genStructHead(path string, typ *types.Struct, userType types.Type) {}
func (e *encoderN) genStructHead(path string, typ *types.Struct, userType types.Type) {}
func (d *decoderN) genStructHead(path string, typ *types.Struct, userType types.Type) {}
// M: array<nfields>
func (s *sizerM) genStructHead(path string, typ *types.Struct, userType types.Type) {
s.size.Add(1) // mfixarray|marray16|marray32
if typ.NumFields() > 0x0f {
panic("TODO: struct with > 15 elements")
}
}
func (e *encoderM) genStructHead(path string, typ *types.Struct, userType types.Type) {
if typ.NumFields() > 0x0f {
panic("TODO: struct with > 15 elements")
}
e.emit("data[%v] = byte(msgpack.FixArray_4 | %d)", e.n, typ.NumFields())
e.n += 1
}
func (d *decoderM) genStructHead(path string, typ *types.Struct, userType types.Type) {
if typ.NumFields() > 0x0f {
panic("TODO: struct with > 15 elements")
}
d.emit("if op, opOk := msgpack.Op(data[%v]), msgpack.FixArray_4 | %d ; op != opOk {", d.n, typ.NumFields())
d.emit("return 0, &mstructDecodeError{%q, op, opOk}", d.pathName(path))
d.emit("}")
d.n += 1
d.overflow.Add(1)
}
// top-level driver for emitting size/encode/decode code for a type // top-level driver for emitting size/encode/decode code for a type
// //
// obj is object that uses this type in source program (so in case of an error // obj is object that uses this type in source program (so in case of an error
// we can point to source location for where it happened) // we can point to source location for where it happened)
func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGenerator) { func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGenerator) {
// neo.customCodec // neo.customCodecN
if types.Implements(typ, neo_customCodec) || ccCustomN, ok := codegen.(interface { genCustomN(path string) })
types.Implements(types.NewPointer(typ), neo_customCodec) { if ok && (types.Implements(typ, neo_customCodecN) ||
codegen.genCustom(path) types.Implements(types.NewPointer(typ), neo_customCodecN)) {
ccCustomN.genCustomN(path)
return return
} }
...@@ -1208,13 +1845,14 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene ...@@ -1208,13 +1845,14 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene
break break
} }
_, ok := basicTypes[u.Kind()] _, ok := basicTypesN[u.Kind()] // ok to check N to see if supported for both N and M
if !ok { if !ok {
log.Fatalf("%v: %v: basic type %v not supported", pos(obj), obj.Name(), u) log.Fatalf("%v: %v: basic type %v not supported", pos(obj), obj.Name(), u)
} }
codegen.genBasic(path, u, typ) codegen.genBasic(path, u, typ)
case *types.Struct: case *types.Struct:
codegen.genStructHead(path, u, typ)
for i := 0; i < u.NumFields(); i++ { for i := 0; i < u.NumFields(); i++ {
v := u.Field(i) v := u.Field(i)
codegenType(path+"."+v.Name(), v.Type(), v, codegen) codegenType(path+"."+v.Name(), v.Type(), v, codegen)
...@@ -1222,7 +1860,7 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene ...@@ -1222,7 +1860,7 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene
case *types.Array: case *types.Array:
// [...]byte or [...]uint8 - just straight copy // [...]byte or [...]uint8 - just straight copy
if typeSizeFixed1(u.Elem()) { if isByte(u.Elem()) {
codegen.genArray1(path, u) codegen.genArray1(path, u)
} else { } else {
var i int64 var i int64
...@@ -1232,7 +1870,7 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene ...@@ -1232,7 +1870,7 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene
} }
case *types.Slice: case *types.Slice:
if typeSizeFixed1(u.Elem()) { if isByte(u.Elem()) {
codegen.genSlice1(path, u) codegen.genSlice1(path, u)
} else { } else {
codegen.genSlice(path, u, obj) codegen.genSlice(path, u, obj)
...@@ -1242,6 +1880,7 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene ...@@ -1242,6 +1880,7 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene
codegen.genMap(path, u, obj) codegen.genMap(path, u, obj)
case *types.Pointer: case *types.Pointer:
panic("XXX") // XXX what here?
default: default:
log.Fatalf("%v: %v has unsupported type %v (%v)", pos(obj), log.Fatalf("%v: %v has unsupported type %v (%v)", pos(obj),
...@@ -1255,8 +1894,14 @@ func generateCodecCode(typespec *ast.TypeSpec, codegen CodeGenerator) string { ...@@ -1255,8 +1894,14 @@ func generateCodecCode(typespec *ast.TypeSpec, codegen CodeGenerator) string {
typ := typeInfo.Types[typespec.Type].Type typ := typeInfo.Types[typespec.Type].Type
obj := typeInfo.Defs[typespec.Name] obj := typeInfo.Defs[typespec.Name]
codegen.setFunc("p", typespec.Name.Name, typ) codegen.setFunc("p", typespec.Name.Name, typ, codegen.encoding())
codegenType("p", typ, obj, codegen) codegenType("p", typ, obj, codegen)
return codegen.generatedCode() return codegen.generatedCode()
} }
// isByte returns whether typ represents byte.
func isByte(typ types.Type) bool {
t, ok := typ.(*types.Basic)
return ok && t.Kind() == types.Byte
}
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -387,6 +387,8 @@ func Verify(t *testing.T, f func(*tEnv)) { ...@@ -387,6 +387,8 @@ func Verify(t *testing.T, f func(*tEnv)) {
// TODO verify M=(go|py) x S=(go|py) x ... // TODO verify M=(go|py) x S=(go|py) x ...
// for now we only verify for all combinations of network // for now we only verify for all combinations of network
// TODO verify enc=(M|N)
// for all networks // for all networks
for _, network := range []string{"pipenet", "lonet"} { for _, network := range []string{"pipenet", "lonet"} {
opt := tClusterOptions{ opt := tClusterOptions{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment