Commit 1cf2712f authored by Kirill Smelkov's avatar Kirill Smelkov

X on msgpack support

parent d2697535
......@@ -27,6 +27,7 @@ package xcontext
import (
"context"
"errors"
"io"
)
// Cancelled reports whether an error is due to a canceled context.
......@@ -72,3 +73,41 @@ func WhenDone(ctx context.Context, f func()) func() {
close(done)
}
}
// WithCloseOnErrCancel closes c on ctx cancel while f is run, or if f returns with an error.
//
// It is usually handy to propagate cancellation to interrupt IO.
// XXX naming?
// XXX don't close on f return?
func WithCloseOnErrCancel(ctx context.Context, c io.Closer, f func() error) (err error) {
closed := false
fdone := make(chan error)
defer func() {
<-fdone // wait for f to complete
if err != nil {
if !closed {
c.Close()
}
}
}()
go func() (err error) {
defer func() {
fdone <- err
close(fdone)
}()
return f()
}()
select {
case <-ctx.Done():
c.Close() // interrupt IO
closed = true
return ctx.Err()
case err := <-fdone:
return err
}
}
......@@ -489,6 +489,7 @@ func withNEO(t *testing.T, f func(t *testing.T, nsrv NEOSrv, ndrv *Client), optv
withNEOSrv(t, func(t *testing.T, nsrv NEOSrv) {
t.Helper()
X := xtesting.FatalIf(t)
// TODO test for enc=(M|N) (XXX M|N only for NEO/go as NEO/py does not support autodetect)
ndrv, _, err := neoOpen(nsrv.URL(),
&zodb.DriverOptions{ReadOnly: true}); X(err)
defer func() {
......
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
// Package msgpack complements tinylib/msgp in providing runtime support for MessagePack.
//
// https://github.com/msgpack/msgpack/blob/master/spec.md
package msgpack
import (
"encoding/binary"
"math"
)
// Op represents a MessagePack opcode.
type Op byte
const (
FixMap_4 Op = 0b1000_0000 // 1000_XXXX
FixArray_4 Op = 0b1001_0000 // 1001_XXXX
False Op = 0xc2
True Op = 0xc3
Bin8 Op = 0xc4
Bin16 Op = 0xc5
Bin32 Op = 0xc6
Float32 Op = 0xca
Float64 Op = 0xcb
Uint8 Op = 0xcc
Uint16 Op = 0xcd
Uint32 Op = 0xce
Uint64 Op = 0xcf
Int8 Op = 0xd0
Int16 Op = 0xd1
Int32 Op = 0xd2
Int64 Op = 0xd3
FixExt1 Op = 0xd4
FixExt2 Op = 0xd5
FixExt4 Op = 0xd6
Array16 Op = 0xdc
Array32 Op = 0xdd
Map16 Op = 0xde
Map32 Op = 0xdf
)
// op converts Op into byte.
// it is used internally to make sure that only Op is put into encoded data.
func op(x Op) byte {
return byte(x)
}
// Bool returns op corresponding to bool value v.
func Bool(v bool) Op {
if v {
return True
} else {
return False
}
}
// u?intXSize(i) returns size needed to encode i.
func Int8Size (i int8) int { return Int64Size(int64(i)) }
func Int16Size(i int16) int { return Int64Size(int64(i)) }
func Int32Size(i int32) int { return Int64Size(int64(i)) }
func Uint8Size (i uint8) int { return Uint64Size(uint64(i)) }
func Uint16Size(i uint16) int { return Uint64Size(uint64(i)) }
func Uint32Size(i uint32) int { return Uint64Size(uint64(i)) }
// Putu?intX(data, i X) encodes i into data and returns encoded size.
func PutInt8 (data []byte, i int8) int { return PutInt64(data, int64(i)) }
func PutInt16(data []byte, i int16) int { return PutInt64(data, int64(i)) }
func PutInt32(data []byte, i int32) int { return PutInt64(data, int64(i)) }
func PutUint8 (data []byte, i uint8) int { return PutUint64(data, uint64(i)) }
func PutUint16(data []byte, i uint16) int { return PutUint64(data, uint64(i)) }
func PutUint32(data []byte, i uint32) int { return PutUint64(data, uint64(i)) }
func Int64Size(i int64) int {
switch {
case -32 <= i && i <= 0b0_1111111: return 1 // posfixint | negfixint
case int64(int8(i)) == i: return 1+1 // int8 + i8
case int64(int16(i)) == i: return 1+2 // int16 + i16
case int64(int32(i)) == i: return 1+4 // int32 + i32
default: return 1+8 // int64 + u64
}
}
func PutInt64(data []byte, i int64) (n int) {
switch {
// posfixint | negfixint
case -32 <= i && i <= 0b0_1111111:
data[0] = uint8(i)
return 1
// int8 + s8
case int64(int8(i)) == i:
data[0] = op(Int8)
data[1] = uint8(i)
return 1+1
// int16 + s16
case int64(int16(i)) == i:
data[0] = op(Int16)
binary.BigEndian.PutUint16(data[1:], uint16(i))
return 1+2
// int32 + s32
case int64(int32(i)) == i:
data[0] = op(Int32)
binary.BigEndian.PutUint32(data[1:], uint32(i))
return 1+4
// int64 + s64
default:
data[0] = op(Int64)
binary.BigEndian.PutUint64(data[1:], uint64(i))
return 1+8
}
}
func Uint64Size(i uint64) int {
switch {
case i <= 0x7f: return 1 // posfixint
case i <= 0xff: return 1+1 // uint8 + u8
case i <= 0xffff: return 1+2 // uint16 + u16
case i <= 0xffffffff: return 1+4 // uint32 + u32
default: return 1+8 // uint64 + u64
}
}
func PutUint64(data []byte, i uint64) (n int) {
switch {
// posfixint
case i <= 0x7f:
data[0] = uint8(i)
return 1
// uint8 + u8
case i <= math.MaxUint8:
data[0] = op(Uint8)
data[1] = uint8(i)
return 1+1
// uint16 + be16
case i <= math.MaxUint16:
data[0] = op(Uint16)
binary.BigEndian.PutUint16(data[1:], uint16(i))
return 1+2
// uint32 + be32
case i <= math.MaxUint32:
data[0] = op(Uint32)
binary.BigEndian.PutUint32(data[1:], uint32(i))
return 1+4
// uint64 + be64
default:
data[0] = op(Uint64)
binary.BigEndian.PutUint64(data[1:], i)
return 1+8
}
}
// BinHeadSize return number of bytes needed to encode header for [l]bin.
func BinHeadSize(l int) int {
switch {
case l < 0: panic("len < 0")
case l <= math.MaxUint8: return 1+1 // bin8 + len8
case l <= math.MaxUint16: return 1+2 // bin16 + len16
case l <= math.MaxUint32: return 1+4 // bin32 + len32
default: panic("len overflows uint32")
}
}
// PutBinHead puts binary header for [size]bin.
func PutBinHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// bin8 + len8
case l <= 0xff:
data[0] = op(Bin8)
data[1] = uint8(l)
return 1+1
// bin16 + len16
case l <= math.MaxUint16:
data[0] = op(Bin16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// bin32 + len32
case l <= math.MaxUint32:
data[0] = op(Bin32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// ArrayHeadSize returns size for array header for [size]array.
func ArrayHeadSize(l int) int {
switch {
case l < 0: panic("len < 0")
case l <= 0x0f: return 1 // fixarray
case l <= math.MaxUint16: return 1+2 // array16 + len16
case l <= math.MaxUint32: return 1+4 // array32 + len32
default: panic("len overflows uint32")
}
}
// PutArrayHead puts array header for [size]array.
func PutArrayHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// fixarray
case l <= 0x0f:
data[0] = op(FixArray_4 | Op(l))
return 1
// array16 + len16
case l <= math.MaxUint16:
data[0] = op(Array16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// array32 + len32
case l <= math.MaxUint32:
data[0] = op(Array32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// MapHeadSize returns size for map header for [size]map.
func MapHeadSize(l int) int {
return ArrayHeadSize(l) // the same 0x0f/len16/len32 scheme
}
// PutMapHead puts map header for [size]map.
func PutMapHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// fixmap
case l <= 0x0f:
data[0] = op(FixMap_4 | Op(l))
return 1
// map16 + len16
case l <= math.MaxUint16:
data[0] = op(Map16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// map32 + len32
case l <= math.MaxUint32:
data[0] = op(Map32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
package msgpack
import (
hexpkg "encoding/hex"
"testing"
)
// hex decodes string as hex; panics on error.
func hex(s string) string {
b, err := hexpkg.DecodeString(s)
if err != nil {
panic(err)
}
return string(b)
}
// tGetPutSize is interface with Get/Put/Size methods, e.g. with
// GetBinHead/PutBinHead/BinHeadSize.
type tGetPutSize interface {
// XXX Get(data []byte) (n int, ret interface{})
Size(arg interface{}) int
Put(data []byte, arg interface{}) int
}
// test1 verifies enc functions on one argument.
func test1(t *testing.T, enc tGetPutSize, arg interface{}, encoded string) {
t.Helper()
data := make([]byte, 16)
n := enc.Put(data, arg)
got := string(data[:n])
if got != encoded {
t.Errorf("%v -> %x ; want %x", arg, got, encoded)
}
if sz := enc.Size(arg); sz != n {
t.Errorf("size(%v) -> %d ; len(data)=%d", arg, sz, n)
}
// XXX decode == arg, n
// XXX decode([:n-1]) -> overflow
}
type tEncUint64 struct{}
func (_ *tEncUint64) Size(xi interface{}) int { return Uint64Size(xi.(uint64)) }
func (_ *tEncUint64) Put(data []byte, xi interface{}) int { return PutUint64(data, xi.(uint64)) }
func TestUint(t *testing.T) {
h := hex
testv := []struct{i uint64; encoded string}{
{0, h("00")}, // posfixint
{1, h("01")},
{0x7f, h("7f")},
{0x80, h("cc80")}, // uint8
{0xff, h("ccff")},
{0x100, h("cd0100")}, // uint16
{0xffff, h("cdffff")},
{0x10000, h("ce00010000")}, // uint32
{0xffffffff, h("ceffffffff")},
{0x100000000, h("cf0000000100000000")}, // uint64
{0xffffffffffffffff, h("cfffffffffffffffff")},
}
for _, tt := range testv {
test1(t, &tEncUint64{}, tt.i, tt.encoded)
}
}
type tEncInt64 struct{}
func (_ *tEncInt64) Size(xi interface{}) int { return Int64Size(xi.(int64)) }
func (_ *tEncInt64) Put(data []byte, xi interface{}) int { return PutInt64(data, xi.(int64)) }
func TestInt(t *testing.T) {
h := hex
testv := []struct{i int64; encoded string}{
{0, h("00")}, // posfixint
{1, h("01")},
{0x7f, h("7f")},
{-1, h("ff")}, // negfixint
{-2, h("fe")},
{-31, h("e1")},
{-32, h("e0")},
{-33, h("d0df")}, // int8
{-0x7f, h("d081")},
{-0x80, h("d080")},
{0x80, h("d10080")}, // int16
{0x7fff, h("d17fff")},
{-0x7fff, h("d18001")},
{-0x8000, h("d18000")},
{0x8000, h("d200008000")}, // int32
{0x7fffffff, h("d27fffffff")},
{-0x8001, h("d2ffff7fff")},
{-0x7fffffff, h("d280000001")},
{-0x80000000, h("d280000000")},
{0x80000000, h("d30000000080000000")}, // int64
{0x7fffffffffffffff, h("d37fffffffffffffff")},
{-0x80000001, h("d3ffffffff7fffffff")},
{-0x7fffffffffffffff, h("d38000000000000001")},
{-0x8000000000000000, h("d38000000000000000")},
}
for _, tt := range testv {
test1(t, &tEncInt64{}, tt.i, tt.encoded)
}
}
type tEncBinHead struct{}
func (_ *tEncBinHead) Size(xl interface{}) int { return BinHeadSize(xl.(int)) }
func (_ *tEncBinHead) Put(data []byte, xl interface{}) int { return PutBinHead(data, xl.(int)) }
func TestBin(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("c400")}, // bin8
{1, h("c401")},
{0xff, h("c4ff")},
{0x100, h("c50100")}, // bin16
{0xffff, h("c5ffff")},
{0x10000, h("c600010000")}, // bin32
{0xffffffff, h("c6ffffffff")},
}
for _, tt := range testv {
test1(t, &tEncBinHead{}, tt.l, tt.encoded)
}
}
type tEncArrayHead struct{}
func (_ *tEncArrayHead) Size(xl interface{}) int { return ArrayHeadSize(xl.(int)) }
func (_ *tEncArrayHead) Put(data []byte, xl interface{}) int { return PutArrayHead(data, xl.(int)) }
func TestArray(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("90")}, // fixarray
{1, h("91")},
{14, h("9e")},
{15, h("9f")},
{0x10, h("dc0010")}, // array16
{0x11, h("dc0011")},
{0x100, h("dc0100")},
{0xffff, h("dcffff")},
{0x10000, h("dd00010000")}, // array32
{0xffffffff, h("ddffffffff")},
}
for _, tt := range testv {
test1(t, &tEncArrayHead{}, tt.l, tt.encoded)
}
}
type tEncMapHead struct{}
func (_ *tEncMapHead) Size(xl interface{}) int { return MapHeadSize(xl.(int)) }
func (_ *tEncMapHead) Put(data []byte, xl interface{}) int { return PutMapHead(data, xl.(int)) }
func TestMap(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("80")}, // fixmap
{1, h("81")},
{14, h("8e")},
{15, h("8f")},
{0x10, h("de0010")}, // map16
{0x11, h("de0011")},
{0x100, h("de0100")},
{0xffff, h("deffff")},
{0x10000, h("df00010000")}, // map32
{0xffffffff, h("dfffffffff")},
}
for _, tt := range testv {
test1(t, &tEncMapHead{}, tt.l, tt.encoded)
}
}
......@@ -102,9 +102,12 @@ import (
"lab.nexedi.com/kirr/neo/go/internal/packed"
"lab.nexedi.com/kirr/neo/go/internal/xio"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
"lab.nexedi.com/kirr/neo/go/neo/proto"
"github.com/philhofer/fwd"
"github.com/someonegg/gocontainer/rbuf"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/go123/xbytes"
)
......@@ -124,7 +127,8 @@ import (
//
// It is safe to use NodeLink from multiple goroutines simultaneously.
type NodeLink struct {
peerLink net.Conn // raw conn to peer
peerLink net.Conn // raw conn to peer
enc proto.Encoding // protocol encoding in use ('N' or 'M')
connMu sync.Mutex
connTab map[uint32]*Conn // connId -> Conn associated with connId
......@@ -153,7 +157,8 @@ type NodeLink struct {
axclosed atomic32 // whether CloseAccept was called
closed atomic32 // whether Close was called
rxbuf rbuf.RingBuf // buffer for reading from peerLink
rxbufN rbuf.RingBuf // buffer for reading from peerLink (N encoding)
rxbufM *msgp.Reader // ----//---- (M encoding)
// scheduling optimization: whenever serveRecv sends to Conn.rxq
// receiving side must ack here to receive G handoff.
......@@ -250,6 +255,8 @@ const (
// newNodeLink makes a new NodeLink from already established net.Conn .
//
// On the wire messages will be encoded according to enc.
//
// Role specifies how to treat our role on the link - either as client or
// server. The difference in between client and server roles is in:
//
......@@ -262,7 +269,9 @@ const (
//
// Though it is possible to wrap just-established raw connection into NodeLink,
// users should always use Handshake which performs protocol handshaking first.
func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink {
//
// rxbuf if != nil indicates what was already read-buffered from conn.
func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole, rxbuf *fwd.Reader) *NodeLink {
var nextConnId uint32
switch role &^ linkFlagsMask {
case _LinkServer:
......@@ -275,6 +284,7 @@ func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink {
nl := &NodeLink{
peerLink: conn,
enc: enc,
connTab: map[uint32]*Conn{},
nextConnId: nextConnId,
acceptq: make(chan *Conn), // XXX +buf ?
......@@ -283,6 +293,25 @@ func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink {
// axdown: make(chan struct{}),
down: make(chan struct{}),
}
if rxbuf == nil {
rxbuf = fwd.NewReader(conn)
}
switch enc {
case 'N':
// rxbufN <- rxbufM (what was preread)
b, err := rxbuf.Next(rxbuf.Buffered())
if err != nil {
panic(err) // must not fail
}
nl.rxbufN.Write(b)
case 'M':
nl.rxbufM = &msgp.Reader{R: rxbuf}
default:
panic("bug")
}
if role&linkNoRecvSend == 0 {
nl.serveWg.Add(2)
go nl.serveRecv()
......@@ -1038,12 +1067,14 @@ func (c *Conn) sendPkt(pkt *pktBuf) error {
func (c *Conn) sendPkt2(pkt *pktBuf) error {
// connId must be set to one associated with this connection
if pkt.Header().ConnId != packed.Hton32(c.connId) {
connID, _, _, err := pktDecodeHead(c.link.enc, pkt)
if err != nil {
panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err))
}
if connID != c.connId {
panic("Conn.sendPkt: connId wrong")
}
var err error
select {
case <-c.txdown:
return c.errSendShutdown()
......@@ -1173,8 +1204,24 @@ var ErrPktTooBig = errors.New("packet too big")
// rx error, if any, is returned as is and is analyzed in serveRecv
//
// XXX dup in ZEO.
func (nl *NodeLink) recvPkt() (*pktBuf, error) {
// FIXME if rxbuf is non-empty - first look there for header and then if
func (nl *NodeLink) recvPkt() (pkt *pktBuf, err error) {
switch nl.enc {
case 'N': pkt, err = nl.recvPktN()
case 'M': pkt, err = nl.recvPktM()
default: panic("bug")
}
if dumpio {
// XXX -> log
fmt.Printf("%v < %v: %v\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pkt)
}
return pkt, err
}
func (nl *NodeLink) recvPktN() (*pktBuf, error) {
// FIXME if rxbufN is non-empty - first look there for header and then if
// we know size -> allocate pkt with that size.
pkt := pktAlloc(4096)
// len=4K but cap can be more since pkt is from pool - use all space to buffer reads
......@@ -1184,8 +1231,8 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n := 0 // number of pkt bytes obtained so far
// next packet could be already prefetched in part by previous read
if nl.rxbuf.Len() > 0 {
δn, _ := nl.rxbuf.Read(data[:proto.PktHeaderLen])
if nl.rxbufN.Len() > 0 {
δn, _ := nl.rxbufN.Read(data[:proto.PktHeaderLen])
n += δn
}
......@@ -1198,7 +1245,7 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n += δn
}
pkth := pkt.Header()
pkth := pkt.HeaderN()
msgLen := packed.Ntoh32(pkth.MsgLen)
if msgLen > proto.PktMaxSize - proto.PktHeaderLen {
......@@ -1210,9 +1257,9 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
data = xbytes.Resize(data, pktLen)
data = data[:cap(data)]
// we might have more data already prefetched in rxbuf
if nl.rxbuf.Len() > 0 {
δn, _ := nl.rxbuf.Read(data[n:pktLen])
// we might have more data already prefetched in rxbufN
if nl.rxbufN.Len() > 0 {
δn, _ := nl.rxbufN.Read(data[n:pktLen])
n += δn
}
......@@ -1225,20 +1272,26 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n += δn
}
// put overread data into rxbuf for next reader
// put overread data into rxbufN for next reader
if n > pktLen {
nl.rxbuf.Write(data[pktLen:n])
nl.rxbufN.Write(data[pktLen:n])
}
// fixup data/pkt
data = data[:n]
pkt.data = data
if dumpio {
// XXX -> log
fmt.Printf("%v < %v: %v\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pkt)
}
return pkt, nil
}
func (nl *NodeLink) recvPktM() (*pktBuf, error) {
pkt := pktAlloc(4096)
mraw := msgp.Raw(pkt.data)
err := mraw.DecodeMsg(nl.rxbufM) // XXX limit size of one packet to e.g. 0x4000000 (= UNPACK_BUFFER_SIZE in NEO/py speak)
if err != nil {
return nil, err
}
pkt.data = []byte(mraw)
return pkt, nil
}
......@@ -1313,21 +1366,104 @@ func (c *Conn) err(op string, e error) error {
//trace:event traceMsgSendPre(l *NodeLink, connId uint32, msg proto.Msg)
// XXX do we also need traceConnSend?
// msgPack allocates pktBuf and encodes msg into it.
func msgPack(connId uint32, msg proto.Msg) *pktBuf {
l := msg.NEOMsgEncodedLen()
// XXX think again; XXX move to proto?
const (
encN = proto.Encoding('N')
encM = proto.Encoding('M')
)
// pktEncode allocates pktBuf and encodes msg into it.
//func (e Encoding) pktEncode(connId uint32, msg proto.Msg) *pktBuf {
// XXX move to proto ? -> YES: Encoding.PktEncode + .PktDecode
func pktEncode(e proto.Encoding, connId uint32, msg proto.Msg) *pktBuf {
switch e {
case 'N': return pktEncodeN(connId, msg)
case 'M': return pktEncodeM(connId, msg)
default: panic("bug")
}
}
// pktDecodeHead decodes header of a packet.
func pktDecodeHead(e proto.Encoding, pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
switch e {
case 'N': return pktDecodeHeadN(pkt)
case 'M': return pktDecodeHeadM(pkt)
default: panic("bug")
}
}
func pktEncodeN(connId uint32, msg proto.Msg) *pktBuf {
l := encN.NEOMsgEncodedLen(msg)
buf := pktAlloc(proto.PktHeaderLen + l)
h := buf.Header()
h.ConnId = packed.Hton32(connId)
h.MsgCode = packed.Hton16(msg.NEOMsgCode())
h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again
h := buf.HeaderN()
h.ConnId = packed.Hton32(connId)
h.MsgCode = packed.Hton16(proto.MsgCode(msg))
h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again
msg.NEOMsgEncode(buf.Payload())
encN.NEOMsgEncode(msg, buf.PayloadN())
return buf
}
// TODO msgUnpack
func pktDecodeHeadN(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
pkth := pkt.HeaderN()
connID = packed.Ntoh32(pkth.ConnId)
msgCode = packed.Ntoh16(pkth.MsgCode)
msgLen := packed.Ntoh32(pkth.MsgLen)
payload = pkt.PayloadN()
if len(payload) != int(msgLen) {
return 0, 0, nil, fmt.Errorf("len(payload) != msgLen")
}
return
}
func pktEncodeM(connId uint32, msg proto.Msg) *pktBuf {
// [3](connID, msgCode, argv)
msgCode := proto.MsgCode(msg)
hroom := msgpack.ArrayHeadSize(3) +
msgpack.Uint32Size(connId) +
msgpack.Uint16Size(msgCode)
l := encM.NEOMsgEncodedLen(msg)
buf := pktAlloc(hroom + l)
b := buf.data
i := 0
i += msgpack.PutArrayHead (b[i:], 3)
i += msgpack.PutUint32 (b[i:], connId)
i += msgpack.PutUint16 (b[i:], msgCode)
if i != hroom {
panic("bug")
}
encM.NEOMsgEncode(msg, b[hroom:])
return buf
}
func pktDecodeHeadM(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
// XXX errctx = "decode header"
b := pkt.data
sz, b, err := msgp.ReadArrayHeaderBytes(b)
if err != nil {
return 0, 0, nil, err
}
if sz != 3 {
return 0, 0, nil, fmt.Errorf("expected [3]tuple, got [%d]tuple", sz)
}
connID, b, err = msgp.ReadUint32Bytes(b)
if err != nil {
return 0, 0, nil, fmt.Errorf("connID: %s", err)
}
msgCode, b, err = msgp.ReadUint16Bytes(b)
if err != nil {
return 0, 0, nil, fmt.Errorf("msgCode: %s", err)
}
return connID, msgCode, b, nil
}
// Recv receives message from the connection.
func (c *Conn) Recv() (proto.Msg, error) {
......@@ -1338,8 +1474,11 @@ func (c *Conn) Recv() (proto.Msg, error) {
defer pkt.Free()
// decode packet
pkth := pkt.Header()
msgCode := packed.Ntoh16(pkth.MsgCode)
_, msgCode, payload, err := pktDecodeHead(c.link.enc, pkt)
if err != nil {
return nil, err
}
msgType := proto.MsgType(msgCode)
if msgType == nil {
err := fmt.Errorf("invalid msgCode (%d)", msgCode)
......@@ -1352,7 +1491,7 @@ func (c *Conn) Recv() (proto.Msg, error) {
// msg := reflect.NewAt(msgType, bufAlloc(msgType.Size())
_, err = msg.NEOMsgDecode(pkt.Payload())
_, err = c.link.enc.NEOMsgDecode(msg, payload)
if err != nil {
return nil, c.err("decode", err) // XXX "decode:" is already in ErrDecodeOverflow
}
......@@ -1369,7 +1508,8 @@ func (c *Conn) Recv() (proto.Msg, error) {
func (link *NodeLink) sendMsg(connId uint32, msg proto.Msg) error {
traceMsgSendPre(link, connId, msg)
buf := msgPack(connId, msg)
//buf := msgPack(connId, msg)
buf := pktEncode(link.enc, connId, msg)
return link.sendPkt(buf) // XXX more context in err? (msg type)
// FIXME ^^^ shutdown whole link on error
}
......@@ -1378,7 +1518,8 @@ func (link *NodeLink) sendMsg(connId uint32, msg proto.Msg) error {
func (c *Conn) Send(msg proto.Msg) error {
traceMsgSendPre(c.link, c.connId, msg)
buf := msgPack(c.connId, msg)
//buf := msgPack(c.connId, msg)
buf := pktEncode(c.link.enc, c.connId, msg)
return c.sendPkt(buf) // XXX more context in err? (msg type)
}
......@@ -1401,12 +1542,13 @@ func (c *Conn) Expect(msgv ...proto.Msg) (which int, err error) {
}
defer pkt.Free()
// XXX encN-specific
pkth := pkt.Header()
msgCode := packed.Ntoh16(pkth.MsgCode)
for i, msg := range msgv {
if msg.NEOMsgCode() == msgCode {
_, err := msg.NEOMsgDecode(pkt.Payload())
if proto.MsgCode(msg) == msgCode {
_, err := c.link.enc.NEOMsgDecode(msg, pkt.Payload())
if err != nil {
return -1, c.err("decode", err)
}
......
......@@ -22,6 +22,7 @@ package neonet
import (
"bytes"
"context"
"fmt"
"io"
"net"
"reflect"
......@@ -38,10 +39,30 @@ import (
"lab.nexedi.com/kirr/neo/go/neo/proto"
"lab.nexedi.com/kirr/neo/go/zodb"
"github.com/tinylib/msgp/msgp"
"github.com/kylelemons/godebug/pretty"
"github.com/pkg/errors"
)
// T is neonet testing environment.
type T struct {
*testing.T
enc proto.Encoding // encoding to use for messages exchange
}
// Verify tests f for all possible environments.
func Verify(t *testing.T, f func(*T)) {
// for each encoding
for _, enc := range []proto.Encoding{'N', 'M'} {
t.Run(fmt.Sprintf("enc=%c", enc), func(t *testing.T) {
f(&T{t, enc})
})
}
}
func xclose(c io.Closer) {
err := c.Close()
exc.Raiseif(err)
......@@ -102,48 +123,70 @@ func xconnError(err error) error {
}
// Prepare pktBuf with content.
func _mkpkt(connid uint32, msgcode uint16, payload []byte) *pktBuf {
pkt := &pktBuf{make([]byte, proto.PktHeaderLen+len(payload))}
h := pkt.Header()
h.ConnId = packed.Hton32(connid)
h.MsgCode = packed.Hton16(msgcode)
h.MsgLen = packed.Hton32(uint32(len(payload)))
copy(pkt.Payload(), payload)
return pkt
func _mkpkt(enc proto.Encoding, connid uint32, msgcode uint16, payload []byte) *pktBuf {
switch enc {
case 'N':
pkt := &pktBuf{make([]byte, proto.PktHeaderLen+len(payload))}
h := pkt.HeaderN()
h.ConnId = packed.Hton32(connid)
h.MsgCode = packed.Hton16(msgcode)
h.MsgLen = packed.Hton32(uint32(len(payload)))
copy(pkt.PayloadN(), payload)
return pkt
case 'M':
var b []byte
b = msgp.AppendArrayHeader (b, 3)
b = msgp.AppendUint32 (b, connid)
b = msgp.AppendUint16 (b, msgcode)
// NOTE payload is appended wrapped into bin object. We need
// this not to break framing, because in M-encoding whole
// packet must be a valid msgpack object.
b = msgp.AppendBytes (b, payload)
return &pktBuf{b}
default:
panic("bug")
}
}
func (c *Conn) mkpkt(msgcode uint16, payload []byte) *pktBuf {
// in Conn exchange connid is automatically set by Conn.sendPkt
return _mkpkt(c.connId, msgcode, payload)
return _mkpkt(c.link.enc, c.connId, msgcode, payload)
}
// Verify pktBuf is as expected.
func xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byte) {
func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byte) {
errv := xerr.Errorv{}
h := pkt.Header()
pktConnID, pktMsgCode, pktPayload, err := pktDecodeHead(t.enc, pkt)
exc.Raiseif(err)
// TODO include caller location
if packed.Ntoh32(h.ConnId) != connid {
errv.Appendf("header: unexpected connid %v (want %v)", packed.Ntoh32(h.ConnId), connid)
if pktConnID != connid {
errv.Appendf("header: unexpected connid %v (want %v)", pktConnID, connid)
}
if packed.Ntoh16(h.MsgCode) != msgcode {
errv.Appendf("header: unexpected msgcode %v (want %v)", packed.Ntoh16(h.MsgCode), msgcode)
if pktMsgCode != msgcode {
errv.Appendf("header: unexpected msgcode %v (want %v)", pktMsgCode, msgcode)
}
if packed.Ntoh32(h.MsgLen) != uint32(len(payload)) {
errv.Appendf("header: unexpected msglen %v (want %v)", packed.Ntoh32(h.MsgLen), len(payload))
// M-encoding -> wrap payloadOK into bin (see _mkpkt ^^^ for why)
if t.enc == 'M' {
payload = msgp.AppendBytes(nil, payload)
}
if !bytes.Equal(pkt.Payload(), payload) {
if !bytes.Equal(pktPayload, payload) {
errv.Appendf("payload differ:\n%s",
pretty.Compare(string(payload), string(pkt.Payload())))
pretty.Compare(string(payload), string(pktPayload)))
}
exc.Raiseif(errv.Err())
}
// Verify pktBuf to match expected message.
func xverifyPktMsg(pkt *pktBuf, connid uint32, msg proto.Msg) {
data := make([]byte, msg.NEOMsgEncodedLen())
msg.NEOMsgEncode(data)
xverifyPkt(pkt, connid, msg.NEOMsgCode(), data)
func (t *T) xverifyPktMsg(pkt *pktBuf, connid uint32, msg proto.Msg) {
data := make([]byte, t.enc.NEOMsgEncodedLen(msg))
t.enc.NEOMsgEncode(msg, data)
t.xverifyPkt(pkt, connid, proto.MsgCode(msg), data)
}
// delay a bit.
......@@ -160,24 +203,27 @@ func tdelay() {
time.Sleep(1 * time.Millisecond)
}
// create NodeLinks connected via net.Pipe
func _nodeLinkPipe(flags1, flags2 _LinkRole) (nl1, nl2 *NodeLink) {
// create NodeLinks connected via net.Pipe; messages are encoded via t.enc.
func (t *T) _nodeLinkPipe(flags1, flags2 _LinkRole) (nl1, nl2 *NodeLink) {
node1, node2 := net.Pipe()
nl1 = newNodeLink(node1, _LinkClient|flags1)
nl2 = newNodeLink(node2, _LinkServer|flags2)
nl1 = newNodeLink(node1, t.enc, _LinkClient|flags1, nil)
nl2 = newNodeLink(node2, t.enc, _LinkServer|flags2, nil)
return nl1, nl2
}
func nodeLinkPipe() (nl1, nl2 *NodeLink) {
return _nodeLinkPipe(0, 0)
func (t *T) nodeLinkPipe() (nl1, nl2 *NodeLink) {
return t._nodeLinkPipe(0, 0)
}
func TestNodeLink(t *testing.T) {
Verify(t, _TestNodeLink)
}
func _TestNodeLink(t *T) {
// TODO catch exception -> add proper location from it -> t.Fatal (see git-backup)
bg := context.Background()
// Close vs recvPkt
nl1, nl2 := _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
nl1, nl2 := t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg := xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) {
tdelay()
......@@ -191,7 +237,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl2)
// Close vs sendPkt
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) {
tdelay()
......@@ -206,7 +252,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl2)
// {Close,CloseAccept} vs Accept
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) {
tdelay()
......@@ -234,7 +280,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl1)
// Close vs recvPkt on another side
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) {
tdelay()
......@@ -248,7 +294,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl1)
// Close vs sendPkt on another side
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) {
tdelay()
......@@ -263,23 +309,23 @@ func TestNodeLink(t *testing.T) {
xclose(nl1)
// raw exchange
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, linkNoRecvSend)
wg = xsync.NewWorkGroup(bg)
okch := make(chan int, 2)
gox(wg, func(_ context.Context) {
// send ping; wait for pong
pkt := _mkpkt(1, 2, []byte("ping"))
pkt := _mkpkt(t.enc, 1, 2, []byte("ping"))
xsendPkt(nl1, pkt)
pkt = xrecvPkt(nl1)
xverifyPkt(pkt, 3, 4, []byte("pong"))
t.xverifyPkt(pkt, 3, 4, []byte("pong"))
okch <- 1
})
gox(wg, func(_ context.Context) {
// wait for ping; send pong
pkt = xrecvPkt(nl2)
xverifyPkt(pkt, 1, 2, []byte("ping"))
pkt = _mkpkt(3, 4, []byte("pong"))
t.xverifyPkt(pkt, 1, 2, []byte("ping"))
pkt = _mkpkt(t.enc, 3, 4, []byte("pong"))
xsendPkt(nl2, pkt)
okch <- 2
})
......@@ -309,7 +355,7 @@ func TestNodeLink(t *testing.T) {
// ---- connections on top of nodelink ----
// Close vs recvPkt
nl1, nl2 = _nodeLinkPipe(0, linkNoRecvSend)
nl1, nl2 = t._nodeLinkPipe(0, linkNoRecvSend)
c = xnewconn(nl1)
wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) {
......@@ -325,7 +371,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl2)
// Close vs sendPkt
nl1, nl2 = _nodeLinkPipe(0, linkNoRecvSend)
nl1, nl2 = t._nodeLinkPipe(0, linkNoRecvSend)
c = xnewconn(nl1)
wg = xsync.NewWorkGroup(bg)
gox(wg, func(_ context.Context) {
......@@ -364,7 +410,7 @@ func TestNodeLink(t *testing.T) {
xclose(nl2)
// NodeLink.Close vs Conn.sendPkt/recvPkt and Accept on another side
nl1, nl2 = _nodeLinkPipe(linkNoRecvSend, 0)
nl1, nl2 = t._nodeLinkPipe(linkNoRecvSend, 0)
c21 := xnewconn(nl2)
c22 := xnewconn(nl2)
c23 := xnewconn(nl2)
......@@ -482,7 +528,7 @@ func TestNodeLink(t *testing.T) {
connKeepClosed = 10 * time.Millisecond
// Conn accept + exchange
nl1, nl2 = nodeLinkPipe()
nl1, nl2 = t.nodeLinkPipe()
nl1.CloseAccept()
wg = xsync.NewWorkGroup(bg)
closed := make(chan int)
......@@ -490,14 +536,14 @@ func TestNodeLink(t *testing.T) {
c := xaccept(nl2)
pkt := xrecvPkt(c)
xverifyPkt(pkt, c.connId, 33, []byte("ping"))
t.xverifyPkt(pkt, c.connId, 33, []byte("ping"))
// change pkt a bit and send it back
xsendPkt(c, c.mkpkt(34, []byte("pong")))
// one more time
pkt = xrecvPkt(c)
xverifyPkt(pkt, c.connId, 35, []byte("ping2"))
t.xverifyPkt(pkt, c.connId, 35, []byte("ping2"))
xsendPkt(c, c.mkpkt(36, []byte("pong2")))
xclose(c)
......@@ -506,7 +552,7 @@ func TestNodeLink(t *testing.T) {
// once again as ^^^ but finish only with CloseRecv
c2 := xaccept(nl2)
pkt = xrecvPkt(c2)
xverifyPkt(pkt, c2.connId, 41, []byte("ping5"))
t.xverifyPkt(pkt, c2.connId, 41, []byte("ping5"))
xsendPkt(c2, c2.mkpkt(42, []byte("pong5")))
c2.CloseRecv()
......@@ -516,10 +562,10 @@ func TestNodeLink(t *testing.T) {
c = xnewconn(nl2) // XXX should get error here?
xsendPkt(c, c.mkpkt(38, []byte("pong3")))
pkt = xrecvPkt(c)
xverifyPktMsg(pkt, c.connId, errConnRefused)
t.xverifyPktMsg(pkt, c.connId, errConnRefused)
xsendPkt(c, c.mkpkt(40, []byte("pong4"))) // once again
pkt = xrecvPkt(c)
xverifyPktMsg(pkt, c.connId, errConnRefused)
t.xverifyPktMsg(pkt, c.connId, errConnRefused)
xclose(c)
......@@ -528,30 +574,30 @@ func TestNodeLink(t *testing.T) {
c1 := xnewconn(nl1)
xsendPkt(c1, c1.mkpkt(33, []byte("ping")))
pkt = xrecvPkt(c1)
xverifyPkt(pkt, c1.connId, 34, []byte("pong"))
t.xverifyPkt(pkt, c1.connId, 34, []byte("pong"))
xsendPkt(c1, c1.mkpkt(35, []byte("ping2")))
pkt = xrecvPkt(c1)
xverifyPkt(pkt, c1.connId, 36, []byte("pong2"))
t.xverifyPkt(pkt, c1.connId, 36, []byte("pong2"))
// "connection closed" after peer closed its end
<-closed
xsendPkt(c1, c1.mkpkt(37, []byte("ping3")))
pkt = xrecvPkt(c1)
xverifyPktMsg(pkt, c1.connId, errConnClosed)
t.xverifyPktMsg(pkt, c1.connId, errConnClosed)
xsendPkt(c1, c1.mkpkt(39, []byte("ping4"))) // once again
pkt = xrecvPkt(c1)
xverifyPktMsg(pkt, c1.connId, errConnClosed)
t.xverifyPktMsg(pkt, c1.connId, errConnClosed)
// XXX also should get EOF on recv
// one more time but now peer does only .CloseRecv()
c2 := xnewconn(nl1)
xsendPkt(c2, c2.mkpkt(41, []byte("ping5")))
pkt = xrecvPkt(c2)
xverifyPkt(pkt, c2.connId, 42, []byte("pong5"))
t.xverifyPkt(pkt, c2.connId, 42, []byte("pong5"))
<-closed
xsendPkt(c2, c2.mkpkt(41, []byte("ping6")))
pkt = xrecvPkt(c2)
xverifyPktMsg(pkt, c2.connId, errConnClosed)
t.xverifyPktMsg(pkt, c2.connId, errConnClosed)
xwait(wg)
......@@ -577,7 +623,7 @@ func TestNodeLink(t *testing.T) {
connKeepClosed = saveKeepClosed
// test 2 channels with replies coming in reversed time order
nl1, nl2 = nodeLinkPipe()
nl1, nl2 = t.nodeLinkPipe()
wg = xsync.NewWorkGroup(bg)
replyOrder := map[uint16]struct { // "order" in which to process requests
start chan struct{} // processing starts when start chan is ready
......@@ -594,6 +640,7 @@ func TestNodeLink(t *testing.T) {
gox(wg, func(_ context.Context) {
pkt := xrecvPkt(c)
// XXX encN-specific
n := packed.Ntoh16(pkt.Header().MsgCode)
x := replyOrder[n]
......@@ -619,7 +666,7 @@ func TestNodeLink(t *testing.T) {
// replies must be coming in reverse order
xechoWait := func(c *Conn, msgCode uint16) {
pkt := xrecvPkt(c)
xverifyPkt(pkt, c.connId, msgCode, []byte(""))
t.xverifyPkt(pkt, c.connId, msgCode, []byte(""))
}
xechoWait(c2, 2)
xechoWait(c1, 1)
......@@ -663,10 +710,13 @@ func xverifyMsg(msg1, msg2 proto.Msg) {
}
func TestRecv1Mode(t *testing.T) {
Verify(t, _TestRecv1Mode)
}
func _TestRecv1Mode(t *T) {
bg := context.Background()
// Send1
nl1, nl2 := nodeLinkPipe()
nl1, nl2 := t.nodeLinkPipe()
wg := xsync.NewWorkGroup(bg)
sync := make(chan int)
gox(wg, func(_ context.Context) {
......@@ -730,7 +780,10 @@ func TestRecv1Mode(t *testing.T) {
//
// bug triggers under -race.
func TestLightCloseVsLinkShutdown(t *testing.T) {
nl1, nl2 := nodeLinkPipe()
Verify(t, _TestLightCloseVsLinkShutdown)
}
func _TestLightCloseVsLinkShutdown(t *T) {
nl1, nl2 := t.nodeLinkPipe()
wg := xsync.NewWorkGroup(context.Background())
c := xnewconn(nl1)
......
......@@ -21,18 +21,38 @@ package neonet
// link establishment
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"net"
"sync"
"os"
"github.com/philhofer/fwd"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/go123/xerr"
"lab.nexedi.com/kirr/go123/xnet"
"lab.nexedi.com/kirr/neo/go/internal/xcontext"
"lab.nexedi.com/kirr/neo/go/internal/xio"
"lab.nexedi.com/kirr/neo/go/neo/proto"
)
// encDefault is default encoding to use.
// XXX we don't need this? (just set encDefault = 'M')
var encDefault = proto.Encoding('N') // XXX = 'M' instead?
func init() {
e := os.Getenv("NEO_ENCODING")
switch e {
case "": // not set
case "N": fallthrough
case "M": encDefault = proto.Encoding(e[0])
default:
fmt.Fprintf(os.Stderr, "E: $NEO_ENCODING=%q - invalid -> abort", e)
os.Exit(1)
}
}
// ---- Handshake ----
// XXX _Handshake may be needed to become public in case when we have already
......@@ -45,91 +65,215 @@ import (
// On success raw connection is returned wrapped into NodeLink.
// On error raw connection is closed.
func _Handshake(ctx context.Context, conn net.Conn, role _LinkRole) (nl *NodeLink, err error) {
err = handshake(ctx, conn, proto.Version)
enc := encDefault // default encoding
var rxbuf *fwd.Reader
switch role &^ linkFlagsMask {
case _LinkServer:
enc, rxbuf, err = handshakeServer(ctx, conn, proto.Version)
case _LinkClient:
enc, rxbuf, err = handshakeClient(ctx, conn, proto.Version, enc)
default:
panic("bug")
}
if err != nil {
return nil, err
}
// handshake ok -> NodeLink
return newNodeLink(conn, role), nil
return newNodeLink(conn, enc, role, rxbuf), nil
}
// _HandshakeError is returned when there is an error while performing handshake.
type _HandshakeError struct {
LocalRole _LinkRole
LocalAddr net.Addr
RemoteAddr net.Addr
Err error
}
func (e *_HandshakeError) Error() string {
return fmt.Sprintf("%s - %s: handshake: %s", e.LocalAddr, e.RemoteAddr, e.Err.Error())
role := ""
switch e.LocalRole {
case _LinkServer: role = "server"
case _LinkClient: role = "client"
default: panic("bug")
}
return fmt.Sprintf("%s - %s: handshake (%s): %s", e.LocalAddr, e.RemoteAddr, role, e.Err.Error())
}
func handshake(ctx context.Context, conn net.Conn, version uint32) (err error) {
// XXX simplify -> errgroup
errch := make(chan error, 2)
// tx handshake word
txWg := sync.WaitGroup{}
txWg.Add(1)
go func() {
var b [4]byte
binary.BigEndian.PutUint32(b[:], version) // XXX -> hton32 ?
_, err := conn.Write(b[:])
// XXX EOF -> ErrUnexpectedEOF ?
errch <- err
txWg.Done()
}()
func (e *_HandshakeError) Cause() error { return e.Err }
func (e *_HandshakeError) Unwrap() error { return e.Err }
// rx handshake word
go func() {
var b [4]byte
_, err := io.ReadFull(conn, b[:])
err = xio.NoEOF(err) // can be returned with n = 0
if err == nil {
peerVersion := binary.BigEndian.Uint32(b[:]) // XXX -> ntoh32 ?
if peerVersion != version {
err = fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVersion, version)
}
// handshakeClient implements client-side handshake.
//
// Client indicates its version and preferred encoding, but accepts any
// encoding choosen to use by server.
func handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPrefer proto.Encoding) (enc proto.Encoding, rxbuf *fwd.Reader, err error) {
defer func() {
if err != nil {
err = &_HandshakeError{_LinkClient, conn.LocalAddr(), conn.RemoteAddr(), err}
}
errch <- err
}()
connClosed := false
defer func() {
// make sure our version is always sent on the wire, if possible,
// so that peer does not see just closed connection when on rx we see version mismatch.
//
// NOTE if cancelled tx goroutine will wake up without delay.
txWg.Wait()
rxbuf = fwd.NewReader(conn)
// don't forget to close conn if returning with error + add handshake err context
var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
// tx client hello
err := txHello("tx hello", conn, version, encPrefer)
if err != nil {
err = &_HandshakeError{conn.LocalAddr(), conn.RemoteAddr(), err}
if !connClosed {
conn.Close()
}
return err
}
// rx server hello reply
var peerVer uint32
peerEnc, peerVer, err = rxHello("rx hello reply", rxbuf)
if err != nil {
return err
}
// verify version
if peerVer != version {
return fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVer, version)
}
return nil
})
if err != nil {
return 0, nil, err
}
// use peer encoding (server should return the same, but we are ok if
// it asks to switch to different)
return peerEnc, rxbuf, nil
}
// handshakeServer implementss server-side handshake.
//
// Server verifies that its version matches Client and accepts client preferred encoding.
func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (enc proto.Encoding, rxbuf *fwd.Reader, err error) {
defer func() {
if err != nil {
err = &_HandshakeError{_LinkServer, conn.LocalAddr(), conn.RemoteAddr(), err}
}
}()
for i := 0; i < 2; i++ {
select {
case <-ctx.Done():
conn.Close() // interrupt IO
connClosed = true
return ctx.Err()
case err = <-errch:
if err != nil {
return err
}
rxbuf = fwd.NewReader(conn)
var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
// rx client hello
var peerVer uint32
var err error
peerEnc, peerVer, err = rxHello("rx hello", rxbuf)
if err != nil {
return err
}
// tx server reply
//
// do it before version check so that client can also detect "version
// mismatch" instead of just getting "disconnect".
err = txHello("tx hello reply", conn, version, peerEnc)
if err != nil {
return err
}
// verify version
if peerVer != version {
return fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVer, version)
}
return nil
})
if err != nil {
return 0, nil, err
}
return peerEnc, rxbuf, nil
}
func txHello(errctx string, conn net.Conn, version uint32, enc proto.Encoding) (err error) {
defer xerr.Context(&err, errctx)
var b []byte
switch enc {
case 'N':
// 00 00 00 <v>
b = make([]byte, 4)
if version > 0xff {
panic("encoding N supports versions only in range [0, 0xff]")
}
b[3] = uint8(version)
case 'M':
// (b"NEO", <V>) encoded as msgpack (= 92 c4 03 NEO int(<V>))
b = msgp.AppendArrayHeader(b, 2) // 92
b = msgp.AppendBytes(b, []byte("NEO")) // c4 03 NEO
b = msgp.AppendUint32(b, version) // u?intX version
default:
panic("bug")
}
_, err = conn.Write(b)
if err != nil {
return err
}
// handshaked ok
return nil
}
func rxHello(errctx string, rx *fwd.Reader) (enc proto.Encoding, version uint32, err error) {
defer xerr.Context(&err, errctx)
b := make([]byte, 4)
_, err = io.ReadFull(rx, b)
err = xio.NoEOF(err)
if err != nil {
return 0, 0, err
}
var peerEnc proto.Encoding
var peerVer uint32
badMagic := false
switch {
case bytes.Equal(b[:3], []byte{0,0,0}):
peerEnc = encN
peerVer = uint32(b[3])
case bytes.Equal(b, []byte{0x92, 0xc4, 3, 'N'}): // start of "fixarray<2> bin8 'N | EO' ...
b = append(b, []byte{0,0}...)
_, err = io.ReadFull(rx, b[4:])
err = xio.NoEOF(err)
if err != nil {
return 0, 0, err
}
if !bytes.Equal(b[4:], []byte{'E','O'}) {
badMagic = true
break
}
peerEnc = encM
rxM := msgp.Reader{R: rx}
peerVer, err = rxM.ReadUint32()
if err != nil {
return 0, 0, fmt.Errorf("M: recv peer version: %s", err) // XXX + "read magic" ctx
}
default:
badMagic = true
}
if badMagic {
return 0, 0, fmt.Errorf("invalid magic %x", b)
}
return peerEnc, peerVer, nil
}
// ---- Dial & Listen at NodeLink level ----
......@@ -141,6 +285,8 @@ func DialLink(ctx context.Context, net xnet.Networker, addr string) (*NodeLink,
return nil, err
}
// TODO if handshake fails with "closed" (= might be unexpected encoding)
// -> try redial and handshaking with different encoding (= autodetect encoding)
return _Handshake(ctx, peerConn, _LinkClient)
}
......
......@@ -21,29 +21,47 @@ package neonet
import (
"context"
"errors"
"io"
"net"
"testing"
"lab.nexedi.com/kirr/go123/exc"
"lab.nexedi.com/kirr/go123/xsync"
"lab.nexedi.com/kirr/neo/go/neo/proto"
)
func xhandshake(ctx context.Context, c net.Conn, version uint32) {
err := handshake(ctx, c, version)
// xhandshakeClient handshakes as client with encPrefer encoding and verifies that server accepts it.
func xhandshakeClient(ctx context.Context, c net.Conn, version uint32, encPrefer proto.Encoding) {
enc, _, err := handshakeClient(ctx, c, version, encPrefer)
exc.Raiseif(err)
if enc != encPrefer {
exc.Raisef("enc (%c) != encPrefer (%c)", enc, encPrefer)
}
}
// xhandshakeServer handshakes as server and verifies negotiated encoding to be encOK.
func xhandshakeServer(ctx context.Context, c net.Conn, version uint32, encOK proto.Encoding) {
enc, _, err := handshakeServer(ctx, c, version)
exc.Raiseif(err)
if enc != encOK {
exc.Raisef("enc (%c) != encOK (%c)", enc, encOK)
}
}
func TestHandshake(t *testing.T) {
Verify(t, _TestHandshake)
}
func _TestHandshake(t *T) {
bg := context.Background()
// handshake ok
p1, p2 := net.Pipe()
wg := xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) {
xhandshake(ctx, p1, 1)
xhandshakeClient(ctx, p1, 1, t.enc)
})
gox(wg, func(ctx context.Context) {
xhandshake(ctx, p2, 1)
xhandshakeServer(ctx, p2, 1, t.enc)
})
xwait(wg)
xclose(p1)
......@@ -54,17 +72,17 @@ func TestHandshake(t *testing.T) {
var err1, err2 error
wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1)
_, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
})
gox(wg, func(ctx context.Context) {
err2 = handshake(ctx, p2, 2)
_, _, err2 = handshakeServer(ctx, p2, 2)
})
xwait(wg)
xclose(p1)
xclose(p2)
err1Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000002 ; our side = 00000001"
err2Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000001 ; our side = 00000002"
err1Want := "pipe - pipe: handshake (client): protocol version mismatch: peer = 00000002 ; our side = 00000001"
err2Want := "pipe - pipe: handshake (server): protocol version mismatch: peer = 00000001 ; our side = 00000002"
if !(err1 != nil && err1.Error() == err1Want) {
t.Errorf("handshake ver mismatch: p1: unexpected error:\nhave: %v\nwant: %v", err1, err1Want)
......@@ -78,7 +96,7 @@ func TestHandshake(t *testing.T) {
err1, err2 = nil, nil
wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1)
_, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
})
gox(wg, func(_ context.Context) {
xclose(p2)
......@@ -88,16 +106,20 @@ func TestHandshake(t *testing.T) {
err11, ok := err1.(*_HandshakeError)
if !ok || !(err11.Err == io.ErrClosedPipe /* on Write */ || err11.Err == io.ErrUnexpectedEOF /* on Read */) {
if !ok || !(errors.Is(err11.Err, io.ErrClosedPipe /* on Write */) || errors.Is(err11.Err, io.ErrUnexpectedEOF /* on Read */)) {
t.Errorf("handshake peer close: unexpected error: %#v", err1)
}
// XXX same for handshakeServer
// ctx cancel
// XXX same for handshakeServer
p1, p2 = net.Pipe()
ctx, cancel := context.WithCancel(bg)
wg = xsync.NewWorkGroup(ctx)
gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1)
_, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
})
tdelay()
cancel()
......@@ -110,5 +132,4 @@ func TestHandshake(t *testing.T) {
if !ok || !(err11.Err == context.Canceled) {
t.Errorf("handshake cancel: unexpected error: %#v", err1)
}
}
......@@ -39,15 +39,17 @@ type pktBuf struct {
data []byte // whole packet data including all headers
}
// Header returns pointer to packet header.
func (pkt *pktBuf) Header() *proto.PktHeader {
// HeaderN returns pointer to packet header in 'N'-encoding.
func (pkt *pktBuf) Header() *proto.PktHeader { return pkt.HeaderN() } // XXX kill
func (pkt *pktBuf) HeaderN() *proto.PktHeader {
// NOTE no need to check len(.data) < PktHeader:
// .data is always allocated with cap >= PktHeaderLen.
return (*proto.PktHeader)(unsafe.Pointer(&pkt.data[0]))
}
// Payload returns []byte representing packet payload.
func (pkt *pktBuf) Payload() []byte {
// PayloadN returns []byte representing packet payload in 'N'-encoding.
func (pkt *pktBuf) Payload() []byte { return pkt.PayloadN() } // XXX kill
func (pkt *pktBuf) PayloadN() []byte {
return pkt.data[proto.PktHeaderLen:]
}
......@@ -87,6 +89,7 @@ func (pkt *pktBuf) String() string {
h := pkt.Header()
s := fmt.Sprintf(".%d", packed.Ntoh32(h.ConnId))
// XXX encN-specific
msgCode := packed.Ntoh16(h.MsgCode)
msgLen := packed.Ntoh32(h.MsgLen)
data := pkt.Payload()
......@@ -98,7 +101,7 @@ func (pkt *pktBuf) String() string {
// XXX dup wrt Conn.Recv
msg := reflect.New(msgType).Interface().(proto.Msg)
n, err := msg.NEOMsgDecode(data)
n, err := encN.NEOMsgDecode(msg, data) // XXX encN hardcoded
if err != nil {
s += fmt.Sprintf(" (%s) %v; #%d [%d]: % x", msgType.Name(), err, msgLen, len(data), data)
} else {
......
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
package proto
// runtime glue for msgpack support
import (
"fmt"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
)
// mstructDecodeError represents decode error when decoder was expecting
// tuple<nfield> for structure named path.
type mstructDecodeError struct {
path string // "Type.field.field"
op msgpack.Op // op we got
opOk msgpack.Op // op expected
}
func (e *mstructDecodeError) Error() string {
return fmt.Sprintf("decode: M: struct %s: got opcode %02x; expect %02x", e.path, e.op, e.opOk)
}
// mdecodeErr is called to normilize error when msgp.ReadXXX returns err when decoding path.
func mdecodeErr(path string, err error) error {
if err == msgp.ErrShortBytes {
return ErrDecodeOverflow
}
return &mdecodeError{path, err}
}
type mdecodeError struct {
path string // "Type.field.field"
err error
}
func (e *mdecodeError) Error() string {
return fmt.Sprintf("decode: M: %s: %s", e.path, e.err)
}
// mOpError represents decode error when decoder faces unexpected operation.
type mOpError struct {
op, opOk msgpack.Op // op we got and what was expected
}
func (e *mOpError) Error() string {
return fmt.Sprintf("expected opcode %02x; got %02x", e.opOk, e.op)
}
func mdecodeOpErr(path string, op, opOk msgpack.Op) error {
return mdecodeErr(path+"/op", &mOpError{op, opOk})
}
// mLen8Error represents decode error when decoder faces unexpected length in Bin8.
type mLen8Error struct {
l, lOk byte // len we got and expected
}
func (e *mLen8Error) Error() string {
return fmt.Sprintf("expected length %d; got %d", e.lOk, e.l)
}
func mdecodeLen8Err(path string, l, lOk uint8) error {
return mdecodeErr(path+"/len", &mLen8Error{l, lOk})
}
func mdecodeEnumTypeErr(path string, enumType, enumTypeOk byte) error {
return mdecodeErr(path+"/enumType",
fmt.Errorf("expected %d; got %d", enumTypeOk, enumType))
}
func mdecodeEnumValueErr(path string, v byte) error {
return mdecodeErr(path, fmt.Errorf("invalid enum payload %02x", v))
}
......@@ -32,6 +32,12 @@ import (
"time"
)
// MsgCode returns code the corresponds to type of the message.
// XXX place - ok?
func MsgCode(msg Msg) uint16 {
return msg.neoMsgCode()
}
// MsgType looks up message type by message code.
//
// Nil is returned if message code is not valid.
......
......@@ -84,6 +84,7 @@ const (
Version = 6
// length of packet header
// XXX encN-specific ?
PktHeaderLen = 10 // = unsafe.Sizeof(PktHeader{}), but latter gives typed constant (uintptr)
// packets larger than PktMaxSize are not allowed.
......@@ -99,6 +100,7 @@ const (
INVALID_OID zodb.Oid = 1<<64 - 1
)
// XXX encN-specific ?
// PktHeader represents header of a raw packet.
//
// A packet contains connection ID and message.
......@@ -110,31 +112,75 @@ type PktHeader struct {
MsgLen packed.BE32 // payload message length (excluding packet header)
}
// Msg is the interface implemented by all NEO messages.
// Msg is the interface representing a NEO message.
type Msg interface {
// marshal/unmarshal into/from wire format:
// NEOMsgCode returns message code needed to be used for particular message type
// neoMsgCode returns message code needed to be used for particular message type
// on the wire.
NEOMsgCode() uint16
neoMsgCode() uint16
// NEOMsgEncodedLen returns how much space is needed to encode current message payload.
NEOMsgEncodedLen() int
// NEOMsgEncode encodes current message state into buf.
// for encoding E:
//
// - neoMsgEncodedLen<E> returns how much space is needed to encode current message payload via E encoding.
//
// - neoMsgEncode<E> encodes current message state into buf via E encoding.
//
// len(buf) must be >= neoMsgEncodedLen().
NEOMsgEncode(buf []byte)
// len(buf) must be >= neoMsgEncodedLen<E>().
//
// - neoMsgDecode<E> decodes data via E encoding into message in-place.
// N encoding (original struct-based encoding)
neoMsgEncodedLenN() int
neoMsgEncodeN(buf []byte)
neoMsgDecodeN(data []byte) (nread int, err error)
// NEOMsgDecode decodes data into message in-place.
NEOMsgDecode(data []byte) (nread int, err error)
// M encoding (via MessagePack)
neoMsgEncodedLenM() int
neoMsgEncodeM(buf []byte)
neoMsgDecodeM(data []byte) (nread int, err error)
}
// Encoding represents messages encoding.
type Encoding byte
// XXX drop "NEO" prefix?
// NEOMsgEncodedLen returns how much space is needed to encode msg payload via encoding e.
func (e Encoding) NEOMsgEncodedLen(msg Msg) int {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgEncodedLenN()
case 'M': return msg.neoMsgEncodedLenM()
}
}
// NEOMsgEncode encodes msg state into buf via encoding e.
//
// len(buf) must be >= e.NEOMsgEncodedLen(m).
func (e Encoding) NEOMsgEncode(msg Msg, buf []byte) {
switch e {
default: panic("bug")
case 'N': msg.neoMsgEncodeN(buf)
case 'M': msg.neoMsgEncodeM(buf)
}
}
// NEOMsgDecode decodes data via encoding e into msg in-place.
func (e Encoding) NEOMsgDecode(msg Msg, data []byte) (nread int, err error) {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgDecodeN(data)
case 'M': return msg.neoMsgDecodeM(data)
}
}
// ErrDecodeOverflow is the error returned by neoMsgDecode when decoding hits buffer overflow
var ErrDecodeOverflow = errors.New("decode: buffer overflow")
// ---- messages ----
//neo:proto enum
type ErrorCode uint32
const (
ACK ErrorCode = iota
......@@ -155,6 +201,7 @@ const (
// XXX move this to neo.clusterState wrapping proto.ClusterState?
//trace:event traceClusterStateChanged(cs *ClusterState)
//neo:proto enum
type ClusterState int8
const (
// The cluster is initially in the RECOVERING state, and it goes back to
......@@ -188,6 +235,7 @@ const (
STOPPING_BACKUP
)
//neo:proto enum
type NodeType int8
const (
MASTER NodeType = iota
......@@ -196,6 +244,7 @@ const (
ADMIN
)
//neo:proto enum
type NodeState int8
const (
UNKNOWN NodeState = iota //short: U // XXX tag prefix name ?
......@@ -204,6 +253,7 @@ const (
PENDING //short: P
)
//neo:proto enum
type CellState int8
const (
// Write-only cell. Last transactions are missing because storage is/was down
......@@ -255,7 +305,7 @@ type Address struct {
}
// NOTE if Host == "" -> Port not added to wire (see py.PAddress):
func (a *Address) neoEncodedLen() int {
func (a *Address) neoEncodedLenN() int {
l := string_neoEncodedLen(a.Host)
if a.Host != "" {
l += 2
......@@ -263,7 +313,7 @@ func (a *Address) neoEncodedLen() int {
return l
}
func (a *Address) neoEncode(b []byte) int {
func (a *Address) neoEncodeN(b []byte) int {
n := string_neoEncode(a.Host, b[0:])
if a.Host != "" {
binary.BigEndian.PutUint16(b[n:], a.Port)
......@@ -272,7 +322,7 @@ func (a *Address) neoEncode(b []byte) int {
return n
}
func (a *Address) neoDecode(b []byte) (uint64, bool) {
func (a *Address) neoDecodeN(b []byte) (uint64, bool) {
n, ok := string_neoDecode(&a.Host, b)
if !ok {
return 0, false
......@@ -295,17 +345,17 @@ type Checksum [20]byte
// PTid is Partition Table identifier.
//
// Zero value means "invalid id" (<-> None in py.PPTID)
// Zero value means "invalid id" (<-> None in py.PPTID) XXX = nil in msgpack
type PTid uint64
// IdTime represents time of identification.
type IdTime float64
func (t IdTime) neoEncodedLen() int {
func (t IdTime) neoEncodedLenN() int {
return 8
}
func (t IdTime) neoEncode(b []byte) int {
func (t IdTime) neoEncodeN(b []byte) int {
// use -inf as value for no data (NaN != NaN -> hard to use NaN in tests)
// NOTE neo/py uses None for "no data"; we use 0 for "no data" to avoid pointer
tt := float64(t)
......@@ -316,7 +366,7 @@ func (t IdTime) neoEncode(b []byte) int {
return 8
}
func (t *IdTime) neoDecode(data []byte) (uint64, bool) {
func (t *IdTime) neoDecodeN(data []byte) (uint64, bool) {
if len(data) < 8 {
return 0, false
}
......@@ -438,8 +488,8 @@ type Recovery struct {
type AnswerRecovery struct {
PTid
BackupTid zodb.Tid
TruncateTid zodb.Tid
BackupTid zodb.Tid // XXX nil <-> 0
TruncateTid zodb.Tid // XXX nil <-> 0
}
// Ask the last OID/TID so that a master can initialize its TransactionManager.
......@@ -1199,13 +1249,13 @@ type FlushLog struct {}
// ---- runtime support for protogen and custom codecs ----
// customCodec is the interface that is implemented by types with custom encodings.
// customCodecN is the interface that is implemented by types with custom N encodings.
//
// its semantic is very similar to Msg.
type customCodec interface {
neoEncodedLen() int
neoEncode(buf []byte) (nwrote int)
neoDecode(data []byte) (nread uint64, ok bool) // XXX uint64 or int here?
type customCodecN interface {
neoEncodedLenN() int
neoEncodeN(buf []byte) (nwrote int)
neoDecodeN(data []byte) (nread uint64, ok bool) // XXX uint64 or int here?
}
func byte2bool(b byte) bool {
......
......@@ -79,31 +79,32 @@ func TestPktHeader(t *testing.T) {
}
// test marshalling for one message type
func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
func testMsgMarshal(t *testing.T, enc Encoding, msg Msg, encoded string) {
typ := reflect.TypeOf(msg).Elem() // type of *msg
msg2 := reflect.New(typ).Interface().(Msg)
defer func() {
if e := recover(); e != nil {
t.Errorf("%v: panic ↓↓↓:", typ)
t.Errorf("%c/%v: panic ↓↓↓:", enc, typ)
panic(e) // to show traceback
}
}()
// msg.encode() == expected
msgCode := msg.NEOMsgCode()
n := msg.NEOMsgEncodedLen()
msgCode := msg.neoMsgCode()
n := enc.NEOMsgEncodedLen(msg)
msgType := MsgType(msgCode)
if msgType != typ {
t.Errorf("%v: msgCode = %v which corresponds to %v", typ, msgCode, msgType)
t.Errorf("%c/%v: msgCode = %v which corresponds to %v", enc, typ, msgCode, msgType)
}
if n != len(encoded) {
t.Errorf("%v: encodedLen = %v ; want %v", typ, n, len(encoded))
t.Errorf("%c/%v: encodedLen = %v ; want %v", enc, typ, n, len(encoded))
}
buf := make([]byte, n)
msg.NEOMsgEncode(buf)
enc.NEOMsgEncode(msg, buf)
if string(buf) != encoded {
t.Errorf("%v: encode result unexpected:", typ)
t.Errorf("%c/%v: encode result unexpected:", enc, typ)
t.Errorf("\thave: %s", hexpkg.EncodeToString(buf))
t.Errorf("\twant: %s", hexpkg.EncodeToString([]byte(encoded)))
}
......@@ -112,7 +113,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
for l := len(buf) - 1; l >= 0; l-- {
func() {
defer func() {
subj := fmt.Sprintf("%v: encode(buf[:encodedLen-%v])", typ, len(encoded)-l)
subj := fmt.Sprintf("%c/%v: encode(buf[:encodedLen-%v])", enc, typ, len(encoded)-l)
e := recover()
if e == nil {
t.Errorf("%s did not panic", subj)
......@@ -131,29 +132,29 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
}
}()
msg.NEOMsgEncode(buf[:l])
enc.NEOMsgEncode(msg, buf[:l])
}()
}
// msg.decode() == expected
data := []byte(encoded + "noise")
n, err := msg2.NEOMsgDecode(data)
n, err := enc.NEOMsgDecode(msg2, data)
if err != nil {
t.Errorf("%v: decode error %v", typ, err)
t.Errorf("%c/%v: decode error %v", enc, typ, err)
}
if n != len(encoded) {
t.Errorf("%v: nread = %v ; want %v", typ, n, len(encoded))
t.Errorf("%c/%v: nread = %v ; want %v", enc, typ, n, len(encoded))
}
if !reflect.DeepEqual(msg2, msg) {
t.Errorf("%v: decode result unexpected: %v ; want %v", typ, msg2, msg)
t.Errorf("%c/%v: decode result unexpected: %v ; want %v", enc, typ, msg2, msg)
}
// decode must detect buffer overflow
for l := len(encoded) - 1; l >= 0; l-- {
n, err = msg2.NEOMsgDecode(data[:l])
n, err = enc.NEOMsgDecode(msg2, data[:l])
if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%v: decode overflow not detected on [:%v]", typ, l)
t.Errorf("%c/%v: decode overflow not detected on [:%v]", enc, typ, l)
}
}
......@@ -162,14 +163,21 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
// test encoding/decoding of messages
func TestMsgMarshal(t *testing.T) {
var testv = []struct {
msg Msg
encoded string // []byte
msg Msg
encodedN string // []byte
encodedM string // []byte
}{
// empty
{&Ping{}, ""},
{&Ping{},
"",
"\x90",
},
// uint32, string
{&Error{Code: 0x01020304, Message: "hello"}, "\x01\x02\x03\x04\x00\x00\x00\x05hello"},
// uint32(N)/enum(M), string
{&Error{Code: 0x00000045, Message: "hello"},
"\x00\x00\x00\x45\x00\x00\x00\x05hello",
hex("92") + hex("d40045") + "\xc4\x05hello",
},
// Oid, Tid, bool, Checksum, []byte
{&StoreObject{
......@@ -185,7 +193,18 @@ func TestMsgMarshal(t *testing.T) {
hex("01020304050607080a0b0c0d0e0f010200") +
hex("0102030405060708090a0b0c0d0e0f1011121314") +
hex("0000000b") + "hello world" +
hex("0a0b0c0d0e0f01030a0b0c0d0e0f0104")},
hex("0a0b0c0d0e0f01030a0b0c0d0e0f0104"),
// M
hex("97") +
hex("c408") + hex("0102030405060708") +
hex("c408") + hex("0a0b0c0d0e0f0102") +
hex("c2") +
hex("c414") + hex("0102030405060708090a0b0c0d0e0f1011121314") +
hex("c40b") + "hello world" +
hex("c408") + hex("0a0b0c0d0e0f0103") +
hex("c408") + hex("0a0b0c0d0e0f0104"),
},
// PTid, [] (of [] of {UUID, CellState})
{&AnswerPartitionTable{
......@@ -198,12 +217,22 @@ func TestMsgMarshal(t *testing.T) {
},
},
// N
hex("0102030405060708") +
hex("00000022") +
hex("00000003") +
hex("000000020000000b010000001100") +
hex("000000010000000b02") +
hex("000000030000000b030000000f040000001701"),
// M
hex("93") +
hex("cf0102030405060708") +
hex("22") +
hex("93") +
hex("91"+"92"+"920bd40401"+"9211d40400") +
hex("91"+"91"+"920bd40402") +
hex("91"+"93"+"920bd40403"+"920fd40404"+"9217d40401"),
},
// map[Oid]struct {Tid,Tid,bool}
......@@ -219,11 +248,20 @@ func TestMsgMarshal(t *testing.T) {
5: {4, 3, true},
}},
// N
u32(4) +
u64(1) + u64(1) + u64(0) + hex("00") +
u64(2) + u64(7) + u64(1) + hex("01") +
u64(5) + u64(4) + u64(3) + hex("01") +
u64(8) + u64(7) + u64(1) + hex("00"),
// M
hex("91") +
hex("84") +
hex("c408")+u64(1) + hex("93") + hex("c408")+u64(1) + hex("c408")+u64(0) + hex("c2") +
hex("c408")+u64(2) + hex("93") + hex("c408")+u64(7) + hex("c408")+u64(1) + hex("c3") +
hex("c408")+u64(5) + hex("93") + hex("c408")+u64(4) + hex("c408")+u64(3) + hex("c3") +
hex("c408")+u64(8) + hex("93") + hex("c408")+u64(7) + hex("c408")+u64(1) + hex("c2"),
},
// map[uint32]UUID + trailing ...
......@@ -238,41 +276,86 @@ func TestMsgMarshal(t *testing.T) {
MaxTID: 128,
},
// N
u32(4) +
u32(1) + u32(7) +
u32(2) + u32(9) +
u32(4) + u32(17) +
u32(7) + u32(3) +
u64(23) + u64(128),
// M
hex("93") +
hex("84") +
hex("01" + "07") +
hex("02" + "09") +
hex("04" + "11") +
hex("07" + "03") +
hex("c408") + u64(23) +
hex("c408") + u64(128),
},
// uint32, []uint32
{&PartitionCorrupted{7, []NodeUUID{1, 3, 9, 4}},
// N
u32(7) + u32(4) + u32(1) + u32(3) + u32(9) + u32(4),
// M
hex("92") +
hex("07") +
hex("94") +
hex("01030904"),
},
// uint32, Address, string, IdTime
{&RequestIdentification{CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678, []string{"room1", "rack234"}, []uint32{3,4,5} },
// N
u8(2) + u32(17) + u32(9) +
"localhost" + u16(7777) +
u32(6) + "myname" +
hex("3fbf9add1091c895") +
u32(2) + u32(5)+"room1" + u32(7)+"rack234" +
u32(3) + u32(3)+u32(4)+u32(5),
// M
hex("97") +
hex("d40202") +
hex("11") +
hex("92") + hex("c409")+"localhost" + hex("cd")+u16(7777) +
hex("c406")+"myname" +
hex("cb" + "3fbf9add1091c895") +
hex("92") + hex("c405")+"room1" + hex("c407")+"rack234" +
hex("93") + hex("030405"),
},
// IdTime, empty Address, int32
{&NotifyNodeInformation{1504466245.926185, []NodeInfo{
{CLIENT, Address{}, UUID(CLIENT, 1), RUNNING, 1504466245.925599}}},
// N
hex("41d66b15517b469d") + u32(1) +
u8(2) + u32(0) /* <- ø Address */ + hex("e0000001") + u8(2) +
hex("41d66b15517b3d04"),
// M
hex("92") +
hex("cb" + "41d66b15517b469d") +
hex("91") +
hex("95") +
hex("d40202") +
hex("92" + "c400"+"" + "00") +
hex("d2" + "e0000001") +
hex("d40302") +
hex("cb" + "41d66b15517b3d04"),
},
// empty IdTime
{&NotifyNodeInformation{IdTimeNone, []NodeInfo{}}, hex("ffffffffffffffff") + hex("00000000")},
{&NotifyNodeInformation{IdTimeNone, []NodeInfo{}},
// N
hex("ffffffffffffffff") + hex("00000000"),
// M
hex("92") +
hex("cb" + "fff0000000000000") + // XXX nan/-inf not handled yet
hex("90"),
},
// TODO we need tests for:
// []varsize + trailing
......@@ -280,7 +363,8 @@ func TestMsgMarshal(t *testing.T) {
}
for _, tt := range testv {
testMsgMarshal(t, tt.msg, tt.encoded)
testMsgMarshal(t, 'N', tt.msg, tt.encodedN)
testMsgMarshal(t, 'M', tt.msg, tt.encodedM)
}
}
......@@ -288,18 +372,23 @@ func TestMsgMarshal(t *testing.T) {
// this way we additionally lightly check encode / decode overflow behaviour for all types.
func TestMsgMarshalAllOverflowLightly(t *testing.T) {
for _, typ := range msgTypeRegistry {
// zero-value for a type
msg := reflect.New(typ).Interface().(Msg)
l := msg.NEOMsgEncodedLen()
zerol := make([]byte, l)
// decoding will turn nil slice & map into empty allocated ones.
// we need it so that reflect.DeepEqual works for msg encode/decode comparison
n, err := msg.NEOMsgDecode(zerol)
if !(n == l && err == nil) {
t.Errorf("%v: zero-decode unexpected: %v, %v ; want %v, nil", typ, n, err, l)
}
for _, enc := range []Encoding{'N', 'M'} {
// zero-value for a type
msg := reflect.New(typ).Interface().(Msg)
l := enc.NEOMsgEncodedLen(msg)
zerol := make([]byte, l)
if enc != 'N' { // M-encoding of zero-value is not all zeros
enc.NEOMsgEncode(msg, zerol)
}
// decoding will turn nil slice & map into empty allocated ones.
// we need it so that reflect.DeepEqual works for msg encode/decode comparison
n, err := enc.NEOMsgDecode(msg, zerol)
if !(n == l && err == nil) {
t.Errorf("%c/%v: zero-decode unexpected: %v, %v ; want %v, nil", enc, typ, n, err, l)
}
testMsgMarshal(t, msg, string(zerol))
testMsgMarshal(t, enc, msg, string(zerol))
}
}
}
......@@ -316,6 +405,8 @@ func TestMsgDecodeLenOverflow(t *testing.T) {
{&AnswerLockedTransactions{}, u32(0x10000000)},
}
enc := Encoding('N') // XXX hardcoded XXX + M-variants with big len?
for _, tt := range testv {
data := []byte(tt.data)
func() {
......@@ -325,7 +416,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) {
}
}()
n, err := tt.msg.NEOMsgDecode(data)
n, err := enc.NEOMsgDecode(tt.msg, data)
if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data,
n, err, 0, ErrDecodeOverflow)
......
......@@ -25,10 +25,11 @@ NEO. Protocol module. Code generator
This program generates marshalling code for message types defined in proto.go .
For every type 4 methods are generated in accordance with neo.Msg interface:
NEOMsgCode() uint16
NEOMsgEncodedLen() int
NEOMsgEncode(buf []byte)
NEOMsgDecode(data []byte) (nread int, err error)
// XXX update for 'N' and 'M'
neoMsgCode() uint16
neoMsgEncodedLenN() int
neoMsgEncodeN(buf []byte)
neoMsgDecodeN(data []byte) (nread int, err error)
List of message types is obtained via searching through proto.go AST - looking
for appropriate struct declarations there.
......@@ -40,7 +41,7 @@ maps, ...).
Top-level generation driver is in generateCodecCode(). It accepts type
specification and something that performs actual leaf-nodes code generation
(CodeGenerator interface). There are 3 particular codegenerators implemented -
- sizer, encoder & decoder - to generate each of the needed method functions.
- sizerN, encoderN & decoder - to generate each of the needed method functions. XXX N/M
The structure of whole process is very similar to what would be happening at
runtime if marshalling was reflect based, but statically with go/types we don't
......@@ -77,6 +78,8 @@ import (
"os"
"sort"
"strings"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
)
// parsed & typechecked input
......@@ -116,8 +119,16 @@ func typeName(typ types.Type) string {
return types.TypeString(typ, qf)
}
var neo_customCodec *types.Interface // type of neo.customCodec
var memBuf types.Type // type of mem.Buf
// zodb.Tid and zodb.Oid types
var zodbTid types.Type
var zodbOid types.Type
var neo_customCodecN *types.Interface // type of neo.customCodecN
var memBuf types.Type // type of mem.Buf
// registry of enums
var enumRegistry = map[types.Type]int{} // type -> enum type serial
// bytes.Buffer + bell & whistles
type Buffer struct {
......@@ -181,6 +192,7 @@ func loadPkg(pkgPath string, sources ...string) *types.Package {
type Annotation struct {
typeonly bool
answer bool
enum bool
}
// parse checks doc for specific comment annotations and, if present, loads them.
......@@ -211,6 +223,12 @@ func (a *Annotation) parse(doc *ast.CommentGroup) {
}
a.answer = true
case "enum":
if a.enum {
log.Fatalf("%v: duplicate `enum`", cpos)
}
a.enum = true
default:
log.Fatalf("%v: unknown neo:proto directive %q", cpos, arg)
}
......@@ -243,6 +261,14 @@ func (v BySerial) Len() int { return len(v) }
// ----------------------------------------
func xlookup(pkg *types.Package, name string) types.Object {
obj := pkg.Scope().Lookup(name)
if obj == nil {
log.Fatalf("cannot find `%s.%s`", pkg.Name(), name)
}
return obj
}
func main() {
var err error
......@@ -252,15 +278,12 @@ func main() {
zodbPkg = loadPkg("lab.nexedi.com/kirr/neo/go/zodb", "../../zodb/zodb.go")
protoPkg = loadPkg("lab.nexedi.com/kirr/neo/go/neo/proto", "proto.go")
// extract neo.customCodec
cc := protoPkg.Scope().Lookup("customCodec")
if cc == nil {
log.Fatal("cannot find `customCodec`")
}
// extract neo.customCodecN
cc := xlookup(protoPkg, "customCodecN")
var ok bool
neo_customCodec, ok = cc.Type().Underlying().(*types.Interface)
neo_customCodecN, ok = cc.Type().Underlying().(*types.Interface)
if !ok {
log.Fatal("customCodec is not interface (got %v)", cc.Type())
log.Fatal("customCodecN is not interface (got %v)", cc.Type())
}
// extract mem.Buf
......@@ -282,6 +305,10 @@ func main() {
}
memBuf = __.Type()
// extract zodb.Tid and zodb.Oid
zodbTid = xlookup(zodbPkg, "Tid").Type()
zodbOid = xlookup(zodbPkg, "Oid").Type()
// prologue
f := fileMap["proto.go"]
buf := Buffer{}
......@@ -295,6 +322,9 @@ import (
"reflect"
"sort"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
"lab.nexedi.com/kirr/go123/mem"
"lab.nexedi.com/kirr/neo/go/zodb"
)`)
......@@ -304,6 +334,7 @@ import (
// go over message types declaration and generate marshal code for them
buf.emit("// messages marshalling\n")
msgSerial := 0
enumSerial := 0
for _, decl := range f.Decls {
// we look for types (which can be only under GenDecl)
gendecl, ok := decl.(*ast.GenDecl)
......@@ -324,16 +355,25 @@ import (
typespec := spec.(*ast.TypeSpec) // must be because tok = TYPE
typename := typespec.Name.Name
// we are only interested in struct types
if _, ok := typespec.Type.(*ast.StructType); !ok {
continue
}
// `//neo:proto ...` annotation for this particular type
specAnnotation := declAnnotation // inheriting from decl
specAnnotation.parse(typespec.Doc)
// type only -> don't generate message interface for it
// remember enum types
// FIXME separate dedicated first pass to extract enums first
if specAnnotation.enum {
//typ := typeInfo.Types[typespec.Type].Type
typ := typeInfo.Defs[typespec.Name].Type()
// XXX verify typ is basic int XXX or byte?
enumRegistry[typ]= enumSerial
//fmt.Printf("// enum %s #%d\n", typeName(typ), enumSerial)
enumSerial++
}
// messages are only struct types without typeonly annotation
if _, ok := typespec.Type.(*ast.StructType); !ok {
continue
}
if specAnnotation.typeonly {
continue
}
......@@ -350,13 +390,18 @@ import (
fmt.Fprintf(&buf, "// %s. %s\n\n", msgCode, typename)
buf.emit("func (*%s) NEOMsgCode() uint16 {", typename)
buf.emit("func (*%s) neoMsgCode() uint16 {", typename)
buf.emit("return %s", msgCode)
buf.emit("}\n")
buf.WriteString(generateCodecCode(typespec, &sizer{}))
buf.WriteString(generateCodecCode(typespec, &encoder{}))
buf.WriteString(generateCodecCode(typespec, &decoder{}))
buf.WriteString(generateCodecCode(typespec, &sizerN{}))
buf.WriteString(generateCodecCode(typespec, &encoderN{}))
buf.WriteString(generateCodecCode(typespec, &decoderN{}))
// XXX keep all M routines separate from N for code locality
buf.WriteString(generateCodecCode(typespec, &sizerM{}))
buf.WriteString(generateCodecCode(typespec, &encoderM{}))
buf.WriteString(generateCodecCode(typespec, &decoderM{}))
msgTypeRegistry[msgCode] = typename
msgSerial++
......@@ -382,9 +427,12 @@ import (
// format & output generated code
code, err := format.Source(buf.Bytes())
//code = buf.Bytes()
if true {
if err != nil {
panic(err) // should not happen
}
}
_, err = os.Stdout.Write(code)
if err != nil {
......@@ -394,13 +442,13 @@ import (
// info about encode/decode of a basic fixed-size type
type basicCodec struct {
type basicCodecN struct {
wireSize int
encode string
decode string
}
var basicTypes = map[types.BasicKind]basicCodec{
var basicTypesN = map[types.BasicKind]basicCodecN{
// encode: %v %v will be `data[n:]`, value
// decode: %v will be `data[n:]` (and already made sure data has more enough bytes to read)
types.Bool: {1, "(%v)[0] = bool2byte(%v)", "byte2bool((%v)[0])"},
......@@ -417,18 +465,47 @@ var basicTypes = map[types.BasicKind]basicCodec{
types.Float64: {8, "float64_neoEncode(%v, %v)", "float64_neoDecode(%v)"},
}
// does a type have fixed wire size and, if yes, what it is?
func typeSizeFixed(typ types.Type) (wireSize int, ok bool) {
// does a type have fixed wire size when encoded and, if yes, what it is?
func typeEncodingSizeFixed(encoding byte, typ types.Type) (wireSize int, ok bool) {
return typeRxTxSizeFixed(encoding, typ, false)
}
// does a type always have fixed wire size when we are decoding it from
// a packet received from outside? if yes, what size it is?
func typeDecodingSizeFixed(encoding byte, typ types.Type) (wireSize int, ok bool) {
return typeRxTxSizeFixed(encoding, typ, true)
}
func typeRxTxSizeFixed(encoding byte, typ types.Type, rx bool) (wireSize int, ok bool) {
switch encoding {
default:
panic("bad encoding")
case 'M':
// pass typ through sizerM and see if encoded size is fixed or not
// XXX make something not fixed when rx=true?
s := &sizerM{}
codegenType("x", typ, nil, s)
if !s.size.IsNumeric() { // no symbolic part
return 0, false
}
return s.size.num, true
case 'N':
// implemented below
// XXX also pass through sizerN ?
}
switch u := typ.Underlying().(type) {
case *types.Basic:
basic, ok := basicTypes[u.Kind()]
basic, ok := basicTypesN[u.Kind()]
if ok {
return basic.wireSize, ok
}
case *types.Struct:
for i := 0; i < u.NumFields(); i++ {
size, ok := typeSizeFixed(u.Field(i).Type())
size, ok := typeEncodingSizeFixed(encoding, u.Field(i).Type())
if !ok {
goto notfixed
}
......@@ -438,7 +515,7 @@ func typeSizeFixed(typ types.Type) (wireSize int, ok bool) {
return wireSize, true
case *types.Array:
elemSize, ok := typeSizeFixed(u.Elem())
elemSize, ok := typeEncodingSizeFixed(encoding, u.Elem())
if ok {
return int(u.Len()) * elemSize, ok
}
......@@ -449,17 +526,13 @@ notfixed:
return 0, false
}
// does a type have fixed wire size == 1 ?
func typeSizeFixed1(typ types.Type) bool {
wireSize, _ := typeSizeFixed(typ)
return wireSize == 1
}
// interface of a codegenerator (for sizer/coder/decoder)
// interface of a codegenerator (for sizer/encoder/decoder)
type CodeGenerator interface {
// codegenerator generates code for this encoding
encoding() byte
// tell codegen it should generate code for which type & receiver name
setFunc(recvName, typeName string, typ types.Type)
setFunc(recvName, typeName string, typ types.Type, encoding byte)
// generate code to process a basic fixed type (not string)
// userType is type actually used in source (for which typ is underlying), or nil
......@@ -479,17 +552,36 @@ type CodeGenerator interface {
genArray1(path string, typ *types.Array)
genSlice1(path string, typ types.Type)
// generate code to process header of struct
genStructHead(path string, typ *types.Struct, userType types.Type)
// mem.Buf
genBuf(path string)
/*
// generate code for a custom type which implements its own
// encoding/decoding via implementing neo.customCodec interface.
genCustom(path string)
// encoding/decoding via implementing neo.customCodecN interface.
// XXX move out of common interface?
genCustomN(path string)
*/
// get generated code.
generatedCode() string
}
// interface for codegenerators to inject themselves into {sizer/encoder/decoder}Common.
type CodeGenCustomize interface {
CodeGenerator
// generate code to process slice or map header
genSliceHead(path string, typ *types.Slice, obj types.Object)
genMapHead(path string, typ *types.Map, obj types.Object)
}
// X reports encoding=X
type N struct{}; func (_ *N) encoding() byte { return 'N' }
type M struct{}; func (_ *M) encoding() byte { return 'M' }
// common part of codegenerators
type commonCodeGen struct {
buf Buffer // code is emitted here
......@@ -497,6 +589,7 @@ type commonCodeGen struct {
recvName string // receiver/type for top-level func
typeName string // or empty
typ types.Type
enc byte // encoding variant
varUsed map[string]bool // whether a variable was used
}
......@@ -505,10 +598,11 @@ func (c *commonCodeGen) emit(format string, a ...interface{}) {
c.buf.emit(format, a...)
}
func (c *commonCodeGen) setFunc(recvName, typeName string, typ types.Type) {
func (c *commonCodeGen) setFunc(recvName, typeName string, typ types.Type, encoding byte) {
c.recvName = recvName
c.typeName = typeName
c.typ = typ
c.enc = encoding
}
// get variable for varname (and automatically mark this var as used)
......@@ -520,6 +614,12 @@ func (c *commonCodeGen) var_(varname string) string {
return varname
}
// pathName returns name representing path or assignto.
func (c *commonCodeGen) pathName(path string) string {
// Type, p.f1.f2 -> Type.f1.f2
return strings.Join(append([]string{c.typeName}, strings.Split(path, ".")[1:]...), ".")
}
// symbolic size
// consists of numeric & symbolic expression parts
// size is num + expr1 + expr2 + ...
......@@ -615,22 +715,24 @@ func (o *OverflowCheck) AddExpr(format string, a ...interface{}) {
}
// sizer generates code to compute encoded size of a message
// sizerX generates code to compute X-encoded size of a message.
//
// when type is recursively walked, for every case symbolic size is added appropriately.
// in case when it was needed to generate loops, runtime accumulator variable is additionally used.
// result is: symbolic size + (optionally) runtime accumulator.
type sizer struct {
type sizerCommon struct {
commonCodeGen
size SymSize // currently accumulated size
}
type sizerN struct { sizerCommon; N }
type sizerM struct { sizerCommon; M }
// encoder generates code to encode a message
// encoderX generates code to X-encode a message.
//
// when type is recursively walked, for every case code to update `data[n:]` is generated.
// no overflow checks are generated as by neo.Msg interface provided data
// buffer should have at least payloadLen length returned by NEOMsgEncodedLen()
// (the size computed by sizer).
// buffer should have at least payloadLen length returned by neoMsgEncodedLenX()
// (the size computed by sizerX).
//
// the code emitted looks like:
//
......@@ -638,14 +740,16 @@ type sizer struct {
// encode<typ2>(data[n2:], path2)
// ...
//
// TODO encode have to care in NEOMsgEncode to emit preamble such that bound
// TODO encode have to care in neoMsgEncodeX to emit preamble such that bound
// checking is performed only once (currently compiler emits many of them)
type encoder struct {
type encoderCommon struct {
commonCodeGen
n int // current write position in data
}
type encoderN struct { encoderCommon; N }
type encoderM struct { encoderCommon; M }
// decoder generates code to decode a message
// decoderX generates code to X-decode a message.
//
// when type is recursively walked, for every case code to decode next item from
// `data[n:]` is generated.
......@@ -662,7 +766,7 @@ type encoder struct {
// <assignto1> = decode<typ1>(data[n1:])
// <assignto2> = decode<typ2>(data[n2:])
// ...
type decoder struct {
type decoderCommon struct {
commonCodeGen
// done buffer for generated code
......@@ -677,16 +781,23 @@ type decoder struct {
// current overflow check point
overflow OverflowCheck
}
type decoderN struct { decoderCommon; N }
type decoderM struct { decoderCommon; M }
var _ CodeGenerator = (*sizerN)(nil)
var _ CodeGenerator = (*encoderN)(nil)
var _ CodeGenerator = (*decoderN)(nil)
var _ CodeGenerator = (*sizer)(nil)
var _ CodeGenerator = (*encoder)(nil)
var _ CodeGenerator = (*decoder)(nil)
var _ CodeGenerator = (*sizerM)(nil)
var _ CodeGenerator = (*encoderM)(nil)
var _ CodeGenerator = (*decoderM)(nil)
func (s *sizer) generatedCode() string {
func (s *sizerCommon) generatedCode() string {
code := Buffer{}
// prologue
code.emit("func (%s *%s) NEOMsgEncodedLen() int {", s.recvName, s.typeName)
code.emit("func (%s *%s) neoMsgEncodedLen%c() int {", s.recvName, s.typeName, s.enc)
if s.varUsed["size"] {
code.emit("var %s int", s.var_("size"))
}
......@@ -704,10 +815,10 @@ func (s *sizer) generatedCode() string {
return code.String()
}
func (e *encoder) generatedCode() string {
func (e *encoderCommon) generatedCode() string {
code := Buffer{}
// prologue
code.emit("func (%s *%s) NEOMsgEncode(data []byte) {", e.recvName, e.typeName)
code.emit("func (%s *%s) neoMsgEncode%c(data []byte) {", e.recvName, e.typeName, e.enc)
code.Write(e.buf.Bytes())
......@@ -719,7 +830,7 @@ func (e *encoder) generatedCode() string {
// data = data[n:]
// n = 0
func (d *decoder) resetPos() {
func (d *decoderCommon) resetPos() {
if d.n != 0 {
d.emit("data = data[%v:]", d.n)
d.n = 0
......@@ -743,7 +854,7 @@ func (d *decoder) resetPos() {
// - before reading a variable sized item
// - in the beginning of a loop inside (via overflowCheckLoopEntry)
// - right after loop exit (via overflowCheckLoopExit)
func (d *decoder) overflowCheck() {
func (d *decoderCommon) overflowCheck() {
// nop if we know overflow was already checked
if d.overflow.checked {
return
......@@ -781,7 +892,7 @@ func (d *decoder) overflowCheck() {
}
// overflowCheck variant that should be inserted at the beginning of a loop inside
func (d *decoder) overflowCheckLoopEntry() {
func (d *decoderCommon) overflowCheckLoopEntry() {
if d.overflow.checked {
return
}
......@@ -795,7 +906,7 @@ func (d *decoder) overflowCheckLoopEntry() {
}
// overflowCheck variant that should be inserted right after loop exit
func (d *decoder) overflowCheckLoopExit(loopLenExpr string) {
func (d *decoderCommon) overflowCheckLoopExit(loopLenExpr string) {
if d.overflow.checked {
return
}
......@@ -813,13 +924,13 @@ func (d *decoder) overflowCheckLoopExit(loopLenExpr string) {
func (d *decoder) generatedCode() string {
func (d *decoderCommon) generatedCode() string {
// flush for last overflow check point
d.overflowCheck()
code := Buffer{}
// prologue
code.emit("func (%s *%s) NEOMsgDecode(data []byte) (int, error) {", d.recvName, d.typeName)
code.emit("func (%s *%s) neoMsgDecode%c(data []byte) (int, error) {", d.recvName, d.typeName, d.enc)
if d.varUsed["nread"] {
code.emit("var %v uint64", d.var_("nread"))
}
......@@ -827,6 +938,7 @@ func (d *decoder) generatedCode() string {
code.Write(d.bufDone.Bytes())
// epilogue
// XXX M: return `n + (len0 - len(data))` without nread updates after every decode
retexpr := fmt.Sprintf("%v", d.nread)
if d.varUsed["nread"] {
// casting nread to int is ok even on 32 bit arches:
......@@ -839,7 +951,7 @@ func (d *decoder) generatedCode() string {
// `goto overflow` is not used only for empty structs
// NOTE for >0 check actual X in StdSizes{X} does not particularly matter
if (&types.StdSizes{8, 8}).Sizeof(d.typ) > 0 {
if (&types.StdSizes{8, 8}).Sizeof(d.typ) > 0 || d.enc != 'N' {
code.emit("\noverflow:")
code.emit("return 0, ErrDecodeOverflow")
}
......@@ -848,14 +960,16 @@ func (d *decoder) generatedCode() string {
return code.String()
}
// emit code to size/encode/decode basic fixed type
func (s *sizer) genBasic(path string, typ *types.Basic, userType types.Type) {
basic := basicTypes[typ.Kind()]
// ---- basic types ----
// N: emit code to size/encode/decode basic fixed type
func (s *sizerN) genBasic(path string, typ *types.Basic, userType types.Type) {
basic := basicTypesN[typ.Kind()]
s.size.Add(basic.wireSize)
}
func (e *encoder) genBasic(path string, typ *types.Basic, userType types.Type) {
basic := basicTypes[typ.Kind()]
func (e *encoderN) genBasic(path string, typ *types.Basic, userType types.Type) {
basic := basicTypesN[typ.Kind()]
dataptr := fmt.Sprintf("data[%v:]", e.n)
if userType != typ && userType != nil {
// userType is a named type over some basic, like
......@@ -867,8 +981,8 @@ func (e *encoder) genBasic(path string, typ *types.Basic, userType types.Type) {
e.n += basic.wireSize
}
func (d *decoder) genBasic(assignto string, typ *types.Basic, userType types.Type) {
basic := basicTypes[typ.Kind()]
func (d *decoderN) genBasic(assignto string, typ *types.Basic, userType types.Type) {
basic := basicTypesN[typ.Kind()]
// XXX specifying :hi is not needed - it is only a workaround to help BCE.
// see https://github.com/golang/go/issues/19126#issuecomment-358743715
......@@ -887,33 +1001,311 @@ func (d *decoder) genBasic(assignto string, typ *types.Basic, userType types.Typ
d.overflow.Add(basic.wireSize)
}
// M: XXX
func (s *sizerM) genBasic(path string, typ *types.Basic, userType types.Type) {
// upath casts path into basic type if needed
// e.g. p.x -> int32(p.x) if p.x is custom type with underlying int32
upath := path
if userType.Underlying() != userType {
upath = fmt.Sprintf("%s(%s)", typ.Name(), upath)
}
// zodb.Tid and zodb.Oid are encoded as [8]bin XXX or nil?
if userType == zodbTid || userType == zodbOid {
s.size.Add(1+1+8) // mbin8 + 8 + [8]data
return
}
// enums are encoded as extensions
if _, isEnum := enumRegistry[userType]; isEnum {
s.size.Add(1+1+1) // fixext1 enumType value
return
}
switch typ.Kind() {
case types.Bool: s.size.Add(1) // mfalse|mtrue
case types.Int8: s.size.AddExpr("msgpack.Int8Size(%s)", upath)
case types.Int16: s.size.AddExpr("msgpack.Int16Size(%s)", upath)
case types.Int32: s.size.AddExpr("msgpack.Int32Size(%s)", upath)
case types.Int64: s.size.AddExpr("msgpack.Int64Size(%s)", upath)
case types.Uint8: s.size.AddExpr("msgpack.Uint8Size(%s)", upath)
case types.Uint16: s.size.AddExpr("msgpack.Uint16Size(%s)", upath)
case types.Uint32: s.size.AddExpr("msgpack.Uint32Size(%s)", upath)
case types.Uint64: s.size.AddExpr("msgpack.Uint64Size(%s)", upath)
case types.Float64: s.size.Add(1+8) // mfloat64 + <value64>
}
}
func (e *encoderM) genBasic(path string, typ *types.Basic, userType types.Type) {
// upath casts path into basic type if needed
// e.g. p.x -> int32(p.x) if p.x is custom type with underlying int32
// XXX dup?
upath := path
if userType.Underlying() != userType {
upath = fmt.Sprintf("%s(%s)", typ.Name(), upath)
}
// zodb.Tid and zodb.Oid are encoded as [8]bin XXX or nil ?
if userType == zodbTid || userType == zodbOid {
e.emit("data[%v] = byte(msgpack.Bin8)", e.n); e.n++
e.emit("data[%v] = 8", e.n); e.n++
e.emit("binary.BigEndian.PutUint64(data[%v:], uint64(%s))", e.n, path)
e.n += 8
return
}
// enums are encoded as `fixext1 enumType fixint<value>`
if enum, ok := enumRegistry[userType]; ok {
e.emit("data[%v] = byte(msgpack.FixExt1)", e.n); e.n++
e.emit("data[%v] = %d", e.n, enum); e.n++
e.emit("if !(0 <= %s && %s <= 0x7f) {", path, path) // mposfixint
e.emit(` panic("%s: invalid %s enum value)")`, path, typeName(userType))
e.emit("}")
e.emit("data[%v] = byte(%s)", e.n, path); e.n++
return
}
// mputint emits mput<kind>int<size>(path)
mputint := func(kind string, size int) {
KI := "I" // I or <Kind>i
if kind != "" {
KI = strings.ToUpper(kind) + "i"
}
e.emit("{")
e.emit("n := msgpack.Put%snt%d(data[%v:], %s)", KI, size, e.n, upath)
e.emit("data = data[%v+n:]", e.n)
e.emit("}")
e.n = 0
}
switch typ.Kind() {
case types.Bool:
e.emit("data[%v] = byte(msgpack.Bool(%s))", e.n, path)
e.n += 1
case types.Int8: mputint("", 8)
case types.Int16: mputint("", 16)
case types.Int32: mputint("", 32)
case types.Int64: mputint("", 64)
case types.Uint8: mputint("u", 8)
case types.Uint16: mputint("u", 16)
case types.Uint32: mputint("u", 32)
case types.Uint64: mputint("u", 64)
case types.Float64:
// mfloat64 f64
e.emit("data[%v] = byte(msgpack.Float64)", e.n); e.n++
e.emit("float64_neoEncode(data[%v:], %s)", e.n, upath); e.n += 8
}
}
// decoder expects <op>
// XXX place
func (d *decoderM) expectOp(assignto string, op string) {
d.emit("if op := msgpack.Op(data[%v]); op != %s {", d.n, op); d.n++
d.emit(" return 0, mdecodeOpErr(%q, op, %s)", d.pathName(assignto), op)
d.emit("}")
d.overflow.Add(1)
}
// decoder expects mbin8 l
func (d *decoderM) expectBin8Fix(assignto string, l int) {
d.expectOp(assignto, "msgpack.Bin8")
d.emit("if l := data[%v]; l != %d {", d.n, l); d.n++
d.emit(" return 0, mdecodeLen8Err(%q, l, %d)", d.pathName(assignto), l)
d.emit("}")
d.overflow.Add(1)
}
// decoder expects mfixext1 <enumType>
func (d *decoderM) expectEnum(assignto string, enumType int) {
d.expectOp(assignto, "msgpack.FixExt1")
d.emit("if enumType := data[%v]; enumType != %d {", d.n, enumType); d.n++
d.emit(" return 0, mdecodeEnumTypeErr(%q, enumType, %d)", d.pathName(assignto), enumType)
d.emit("}")
d.overflow.Add(1)
}
func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Type) {
// zodb.Tid and zodb.Oid are encoded as [8]bin
if userType == zodbTid || userType == zodbOid {
d.expectBin8Fix(assignto, 8)
d.emit("%s= %s(binary.BigEndian.Uint64(data[%v:]))", assignto, typeName(userType), d.n)
d.n += 8
d.overflow.Add(8)
return
}
// enums are encoded as `fixext1 enumType fixint<value>`
if enum, ok := enumRegistry[userType]; ok {
d.expectEnum(assignto, enum)
d.emit("{")
d.emit("v := data[%v]", d.n); d.n++
d.emit("if !(0 <= v && v <= 0x7f) {") // mposfixint
d.emit(" return 0, mdecodeEnumValueErr(%q, v)", d.pathName(assignto))
d.emit("}")
d.emit("%s= %s(v)", assignto, typeName(userType))
d.emit("}")
d.overflow.Add(1)
return
}
// v represents basic decoded value casted to user type if needed
v := "v"
if userType.Underlying() != userType {
v = fmt.Sprintf("%s(v)", typeName(userType))
}
// mgetint emits assignto = mget<kind>int<size>()
mgetint := func(kind string, size int) {
// we are going to go into msgp - flush previously queued
// overflow checks; put place for next overflow check after
// msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
KI := "I" // I or <Kind>i
if kind != "" {
KI = strings.ToUpper(kind) + "i"
}
d.emit("{")
d.emit("v, tail, err := msgp.Read%snt%dBytes(data)", KI, size)
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
}
// mgetfloat emits mgetfloat<size>
mgetfloat := func(size int) {
// delving into msgp - flush/prepare next site for overflow check
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("{")
d.emit("v, tail, err := msgp.ReadFloat%dBytes(data)", size)
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
}
switch typ.Kind() {
case types.Bool:
// XXX move -> mgetbool ?
d.emit("switch op := msgpack.Op(data[%v]); op {", d.n)
// XXX vvv False also ok
d.emit("default: return 0, mdecodeOpErr(%q, op, msgpack.True)", d.pathName(assignto))
d.emit("case msgpack.True: %s = true", assignto)
d.emit("case msgpack.False: %s = false", assignto)
// XXX also support 0|1 ?
d.emit("}")
d.n++
d.overflow.Add(1)
/*
// -> msgp - flush queued overflow checks; put place for next
// overflow checks after msgp is done.
// XXX better directly compare against mtrue|mfalse|0|1 ?
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("{")
d.emit("v, tail, err := msgp.ReadBoolBytes(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
*/
case types.Int8: mgetint("", 8)
case types.Int16: mgetint("", 16)
case types.Int32: mgetint("", 32)
case types.Int64: mgetint("", 64)
case types.Uint8: mgetint("u", 8)
case types.Uint16: mgetint("u", 16)
case types.Uint32: mgetint("u", 32)
case types.Uint64: mgetint("u", 64)
case types.Float64: mgetfloat(64)
}
}
// emit code to size/encode/decode array with sizeof(elem)==1
// [len(A)]byte
func (s *sizer) genArray1(path string, typ *types.Array) {
func (s *sizerN) genArray1(path string, typ *types.Array) {
s.size.Add(int(typ.Len()))
}
func (e *encoder) genArray1(path string, typ *types.Array) {
func (e *encoderN) genArray1(path string, typ *types.Array) {
e.emit("copy(data[%v:], %v[:])", e.n, path)
e.n += int(typ.Len())
}
func (d *decoder) genArray1(assignto string, typ *types.Array) {
func (d *decoderN) genArray1(assignto string, typ *types.Array) {
typLen := int(typ.Len())
d.emit("copy(%v[:], data[%v:%v])", assignto, d.n, d.n+typLen)
d.n += typLen
d.overflow.Add(typLen)
}
// binX+lenX
// [len(A)]byte
func (s *sizerM) genArray1(path string, typ *types.Array) {
l := int(typ.Len())
s.size.Add(msgpack.BinHeadSize(l))
s.size.Add(l)
}
func (e *encoderM) genArray1(path string, typ *types.Array) {
l := int(typ.Len())
if l > 0xff {
panic("TODO: array1 with > 255 elements")
}
e.emit("data[%v] = byte(msgpack.Bin8)", e.n); e.n++
e.emit("data[%v] = %d", e.n, l); e.n++
e.emit("copy(data[%v:], %v[:])", e.n, path)
e.n += l
}
func (d *decoderM) genArray1(assignto string, typ *types.Array) {
l := int(typ.Len())
if l > 0xff {
panic("TODO: array1 with > 255 elements")
}
d.expectBin8Fix(assignto, l)
d.emit("copy(%v[:], data[%v:%v])", assignto, d.n, d.n+l)
d.n += l
d.overflow.Add(l)
}
// emit code to size/encode/decode string or []byte
// len u32
// [len]byte
func (s *sizer) genSlice1(path string, typ types.Type) {
func (s *sizerN) genSlice1(path string, typ types.Type) {
s.size.Add(4)
s.size.AddExpr("len(%s)", path)
}
func (e *encoder) genSlice1(path string, typ types.Type) {
func (e *encoderN) genSlice1(path string, typ types.Type) {
e.emit("{")
e.emit("l := uint32(len(%s))", path)
e.genBasic("l", types.Typ[types.Uint32], nil)
......@@ -924,7 +1316,7 @@ func (e *encoder) genSlice1(path string, typ types.Type) {
e.n = 0
}
func (d *decoder) genSlice1(assignto string, typ types.Type) {
func (d *decoderN) genSlice1(assignto string, typ types.Type) {
d.emit("{")
d.genBasic("l:", types.Typ[types.Uint32], nil)
......@@ -953,17 +1345,75 @@ func (d *decoder) genSlice1(assignto string, typ types.Type) {
d.emit("}")
}
// bin8+len8|bin16+len16|bin32+len32
// [len]byte
func (s *sizerM) genSlice1(path string, typ types.Type) {
// XXX -> mbinsize(len(path)) ?
s.size.AddExpr("msgpack.BinHeadSize(len(%s))", path)
s.size.AddExpr("len(%s)", path)
}
func (e *encoderM) genSlice1(path string, typ types.Type) {
e.emit("{")
e.emit("l := len(%s)", path)
e.emit("n := msgpack.PutBinHead(data[%v:], l)", e.n)
e.emit("data = data[%v+n:]", e.n)
e.emit("copy(data, %v)", path)
e.emit("data = data[l:]")
e.emit("}")
e.n = 0
}
func (d *decoderM) genSlice1(assignto string, typ types.Type) {
// -> msgp: flush queued overflow checks; put place for next overflow
// checks after msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("{")
d.emit("b, tail, err := msgp.ReadBytesZC(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
// XXX dup wrt decoderN ?
switch t := typ.(type) {
case *types.Basic:
if t.Kind() != types.String {
log.Panicf("bad basic type in slice1: %v", t)
}
d.emit("%v= string(b)", assignto)
case *types.Slice:
// TODO eventually do not copy, but reference data from original
d.emit("%v= make(%v, len(b))", assignto, typeName(typ))
d.emit("copy(%v, b)", assignto)
default:
log.Panicf("bad type in slice1: %v", typ)
}
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
}
// emit code to size/encode/decode mem.Buf
// same as slice1 but buffer is allocated via mem.BufAlloc
func (s *sizer) genBuf(path string) {
func (s *sizerN) genBuf(path string) {
s.genSlice1(path+".XData()", nil /* typ unused */)
}
func (s *sizerM) genBuf(path string) {
s.genSlice1(path+".XData()", nil /* typ unused */)
}
func (e *encoder) genBuf(path string) {
func (e *encoderN) genBuf(path string) {
e.genSlice1(path+".XData()", nil /* typ unused */)
}
func (e *encoderM) genBuf(path string) {
e.genSlice1(path+".XData()", nil /* typ unused */)
}
func (d *decoder) genBuf(path string) {
func (d *decoderN) genBuf(assignto string) {
d.emit("{")
d.genBasic("l:", types.Typ[types.Uint32], nil)
......@@ -973,21 +1423,54 @@ func (d *decoder) genBuf(path string) {
d.overflow.AddExpr("uint64(l)")
// TODO eventually do not copy but reference original
d.emit("%v= mem.BufAlloc(int(l))", path)
d.emit("copy(%v.Data, data[:l])", path)
d.emit("%v= mem.BufAlloc(int(l))", assignto)
d.emit("copy(%v.Data, data[:l])", assignto)
d.emit("data = data[l:]")
d.emit("}")
}
func (d *decoderM) genBuf(assignto string) {
// XXX dup wrt decoderM.genSlice1 ?
// -> msgp: flush queued overflow checks; put place for next overflow
// checks after msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("{")
d.emit("b, tail, err := msgp.ReadBytesZC(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
// XXX dup wrt decoderN.genBuf ?
// TODO eventually do not copy but reference original
d.emit("%v= mem.BufAlloc(len(b))", assignto)
d.emit("copy(%v.Data, b)", assignto)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
d.emit("}")
}
// emit code to size/encode/decode slice
// len u32
// [len]item
func (s *sizer) genSlice(path string, typ *types.Slice, obj types.Object) {
func (s *sizerN) genSliceHead(path string, typ *types.Slice, obj types.Object) {
s.size.Add(4)
}
func (s *sizerN) genSlice(path string, typ *types.Slice, obj types.Object) {
s.genSliceCommon(s, path, typ, obj)
}
func (s *sizerCommon) genSliceCommon(xs CodeGenCustomize, path string, typ *types.Slice, obj types.Object) {
xs.genSliceHead(path, typ, obj)
// if size(item)==const - size update in one go
elemSize, ok := typeSizeFixed(typ.Elem())
elemSize, ok := typeEncodingSizeFixed(xs.encoding(), typ.Elem())
if ok {
s.size.AddExpr("len(%v) * %v", path, elemSize)
return
......@@ -999,7 +1482,7 @@ func (s *sizer) genSlice(path string, typ *types.Slice, obj types.Object) {
s.emit("for i := 0; i < len(%v); i++ {", path)
s.emit("a := &%s[i]", path)
codegenType("(*a)", typ.Elem(), obj, s)
codegenType("(*a)", typ.Elem(), obj, xs)
// merge-in size updates
s.emit("%v += %v", s.var_("size"), s.size.ExprString())
......@@ -1010,29 +1493,45 @@ func (s *sizer) genSlice(path string, typ *types.Slice, obj types.Object) {
s.size = curSize
}
func (e *encoder) genSlice(path string, typ *types.Slice, obj types.Object) {
e.emit("{")
e.emit("l := uint32(len(%s))", path)
e.genBasic("l", types.Typ[types.Uint32], nil)
func (e *encoderN) genSliceHead(path string, typ *types.Slice, obj types.Object) {
e.emit("l := len(%s)", path)
e.genBasic("l", types.Typ[types.Uint32], types.Typ[types.Int])
e.emit("data = data[%v:]", e.n)
e.n = 0
e.emit("for i := 0; uint32(i) <l; i++ {")
}
func (e *encoderN) genSlice(path string, typ *types.Slice, obj types.Object) {
e.genSliceCommon(e, path, typ, obj)
}
func (e *encoderCommon) genSliceCommon(xe CodeGenCustomize, path string, typ *types.Slice, obj types.Object) {
e.emit("{")
xe.genSliceHead(path, typ, obj)
e.emit("for i := 0; i <l; i++ {")
e.emit("a := &%s[i]", path)
codegenType("(*a)", typ.Elem(), obj, e)
e.emit("data = data[%v:]", e.n)
codegenType("(*a)", typ.Elem(), obj, xe)
if e.n != 0 {
e.emit("data = data[%v:]", e.n)
e.n = 0
}
e.emit("}")
e.emit("}")
e.n = 0
}
func (d *decoder) genSlice(assignto string, typ *types.Slice, obj types.Object) {
d.emit("{")
func (d *decoderN) genSliceHead(assignto string, typ *types.Slice, obj types.Object) {
d.genBasic("l:", types.Typ[types.Uint32], nil)
}
func (d *decoderN) genSlice(assignto string, typ *types.Slice, obj types.Object) {
d.genSliceCommon(d, assignto, typ, obj)
}
func (d *decoderCommon) genSliceCommon(xd CodeGenCustomize, assignto string, typ *types.Slice, obj types.Object) {
d.emit("{")
xd.genSliceHead(assignto, typ, obj)
d.resetPos()
// if size(item)==const - check overflow in one go
elemSize, elemFixed := typeSizeFixed(typ.Elem())
elemSize, elemFixed := typeDecodingSizeFixed(xd.encoding(), typ.Elem())
if elemFixed {
d.overflowCheck()
d.overflow.AddExpr("uint64(l) * %v", elemSize)
......@@ -1045,7 +1544,7 @@ func (d *decoder) genSlice(assignto string, typ *types.Slice, obj types.Object)
d.emit("a := &%s[i]", assignto)
d.overflowCheckLoopEntry()
codegenType("(*a)", typ.Elem(), obj, d)
codegenType("(*a)", typ.Elem(), obj, xd)
d.resetPos()
d.emit("}")
......@@ -1054,27 +1553,72 @@ func (d *decoder) genSlice(assignto string, typ *types.Slice, obj types.Object)
d.emit("}")
}
// fixarray|array16+YYYY|array32+ZZZZ
// [len]item
func (s *sizerM) genSliceHead(path string, typ *types.Slice, obj types.Object) {
s.size.AddExpr("msgpack.ArrayHeadSize(len(%s))", path)
}
func (s *sizerM) genSlice(path string, typ *types.Slice, obj types.Object) {
s.genSliceCommon(s, path, typ, obj)
}
func (e *encoderM) genSliceHead(path string, typ *types.Slice, obj types.Object) {
e.emit("l := len(%s)", path)
e.emit("n := msgpack.PutArrayHead(data[%v:], l)", e.n)
e.emit("data = data[%v+n:]", e.n)
e.n = 0
}
func (e *encoderM) genSlice(path string, typ *types.Slice, obj types.Object) {
e.genSliceCommon(e, path, typ, obj)
}
func (d *decoderM) genSliceHead(assignto string, typ *types.Slice, obj types.Object) {
// -> msgp: flush queued overflow checks; put place for next overflow
// checks after msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("l, tail, err := msgp.ReadArrayHeaderBytes(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
}
func (d *decoderM) genSlice(assignto string, typ *types.Slice, obj types.Object) {
d.genSliceCommon(d, assignto, typ, obj)
}
// generate code to encode/decode map
// len u32
// [len](key, value)
func (s *sizer) genMap(path string, typ *types.Map, obj types.Object) {
keySize, keyFixed := typeSizeFixed(typ.Key())
elemSize, elemFixed := typeSizeFixed(typ.Elem())
func (s *sizerN) genMapHead(path string, typ *types.Map, obj types.Object) {
s.size.Add(4)
}
func (s *sizerN) genMap(path string, typ *types.Map, obj types.Object) {
s.genMapCommon(s, path, typ, obj)
}
func (s *sizerCommon) genMapCommon(xs CodeGenCustomize, path string, typ *types.Map, obj types.Object) {
xs.genMapHead(path, typ, obj)
keySize, keyFixed := typeEncodingSizeFixed(xs.encoding(), typ.Key())
elemSize, elemFixed := typeEncodingSizeFixed(xs.encoding(), typ.Elem())
if keyFixed && elemFixed {
s.size.Add(4)
s.size.AddExpr("len(%v) * %v", path, keySize+elemSize)
return
}
s.size.Add(4)
curSize := s.size
s.size.Reset()
// FIXME for map of map gives ...[key][key] => key -> different variables
s.emit("for key := range %s {", path)
codegenType("key", typ.Key(), obj, s)
codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, s)
codegenType("key", typ.Key(), obj, xs)
codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, xs)
// merge-in size updates
s.emit("%v += %v", s.var_("size"), s.size.ExprString())
......@@ -1085,12 +1629,19 @@ func (s *sizer) genMap(path string, typ *types.Map, obj types.Object) {
s.size = curSize
}
func (e *encoder) genMap(path string, typ *types.Map, obj types.Object) {
e.emit("{")
e.emit("l := uint32(len(%s))", path)
e.genBasic("l", types.Typ[types.Uint32], nil)
func (e *encoderN) genMapHead(path string, typ *types.Map, obj types.Object) {
e.emit("l := len(%s)", path)
e.genBasic("l", types.Typ[types.Uint32], types.Typ[types.Int])
e.emit("data = data[%v:]", e.n)
e.n = 0
}
func (e *encoderN) genMap(path string, typ *types.Map, obj types.Object) {
e.genMapCommon(e, path, typ, obj)
}
func (e *encoderCommon) genMapCommon(xe CodeGenCustomize, path string, typ *types.Map, obj types.Object) {
e.emit("{")
xe.genMapHead(path, typ, obj)
// output keys in sorted order on the wire
// (easier for debugging & deterministic for testing)
......@@ -1101,23 +1652,32 @@ func (e *encoder) genMap(path string, typ *types.Map, obj types.Object) {
e.emit("sort.Slice(keyv, func (i, j int) bool { return keyv[i] < keyv[j] })")
e.emit("for _, key := range keyv {")
codegenType("key", typ.Key(), obj, e)
codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, e)
e.emit("data = data[%v:]", e.n) // XXX wrt map of map?
codegenType("key", typ.Key(), obj, xe)
codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, xe)
if e.n != 0 {
e.emit("data = data[%v:]", e.n) // XXX wrt map of map?
e.n = 0
}
e.emit("}")
e.emit("}")
e.n = 0
}
func (d *decoder) genMap(assignto string, typ *types.Map, obj types.Object) {
d.emit("{")
func (d *decoderN) genMapHead(assignto string, typ *types.Map, obj types.Object) {
d.genBasic("l:", types.Typ[types.Uint32], nil)
}
func (d *decoderN) genMap(assignto string, typ *types.Map, obj types.Object) {
d.genMapCommon(d, assignto, typ, obj)
}
func (d *decoderCommon) genMapCommon(xd CodeGenCustomize, assignto string, typ *types.Map, obj types.Object) {
d.emit("{")
xd.genMapHead(assignto, typ, obj)
d.resetPos()
// if size(key,item)==const - check overflow in one go
keySize, keyFixed := typeSizeFixed(typ.Key())
elemSize, elemFixed := typeSizeFixed(typ.Elem())
keySize, keyFixed := typeDecodingSizeFixed(xd.encoding(), typ.Key())
elemSize, elemFixed := typeDecodingSizeFixed(xd.encoding(), typ.Elem())
if keyFixed && elemFixed {
d.overflowCheck()
d.overflow.AddExpr("uint64(l) * %v", keySize+elemSize)
......@@ -1130,18 +1690,19 @@ func (d *decoder) genMap(assignto string, typ *types.Map, obj types.Object) {
d.emit("for i := 0; uint32(i) < l; i++ {")
d.overflowCheckLoopEntry()
codegenType("key:", typ.Key(), obj, d)
d.emit("var key %s", typeName(typ.Key()))
codegenType("key", typ.Key(), obj, xd)
switch typ.Elem().Underlying().(type) {
// basic types can be directly assigned to map entry
case *types.Basic:
codegenType("m[key]", typ.Elem(), obj, d)
codegenType("m[key]", typ.Elem(), obj, xd)
// otherwise assign via temporary
default:
d.emit("var v %v", typeName(typ.Elem()))
codegenType("v", typ.Elem(), obj, d)
d.emit("m[key] = v")
d.emit("var mv %v", typeName(typ.Elem()))
codegenType("mv", typ.Elem(), obj, xd)
d.emit("m[key] = mv")
}
d.resetPos()
......@@ -1151,27 +1712,65 @@ func (d *decoder) genMap(assignto string, typ *types.Map, obj types.Object) {
d.emit("}")
}
// fixmap|map16+YYYY|map32+ZZZZ
// [len]key/value
func (s *sizerM) genMapHead(path string, typ *types.Map, obj types.Object) {
s.size.AddExpr("msgpack.MapHeadSize(len(%s))", path)
}
func (s *sizerM) genMap(path string, typ *types.Map, obj types.Object) {
s.genMapCommon(s, path, typ, obj)
}
func (e *encoderM) genMapHead(path string, typ *types.Map, obj types.Object) {
e.emit("l := len(%s)", path)
e.emit("n := msgpack.PutMapHead(data[%v:], l)", e.n)
e.emit("data = data[%v+n:]", e.n)
e.n = 0
}
func (e *encoderM) genMap(path string, typ *types.Map, obj types.Object) {
e.genMapCommon(e, path, typ, obj)
}
func (d *decoderM) genMapHead(assignto string, typ *types.Map, obj types.Object) {
// -> msgp: flush queued overflow checks; put place for next overflow
// checks after msgp is done.
d.overflowCheck()
d.resetPos()
defer d.overflowCheck()
d.emit("l, tail, err := msgp.ReadMapHeaderBytes(data)")
d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit("}")
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
d.emit("data = tail")
}
func (d *decoderM) genMap(assignto string, typ *types.Map, obj types.Object) {
d.genMapCommon(d, assignto, typ, obj)
}
// emit code to size/encode/decode custom type
func (s *sizer) genCustom(path string) {
s.size.AddExpr("%s.neoEncodedLen()", path)
func (s *sizerN) genCustomN(path string) {
s.size.AddExpr("%s.neoEncodedLenN()", path)
}
func (e *encoder) genCustom(path string) {
func (e *encoderN) genCustomN(path string) {
e.emit("{")
e.emit("n := %s.neoEncode(data[%v:])", path, e.n)
e.emit("n := %s.neoEncodeN(data[%v:])", path, e.n)
e.emit("data = data[%v + n:]", e.n)
e.emit("}")
e.n = 0
}
func (d *decoder) genCustom(path string) {
func (d *decoderN) genCustomN(path string) {
d.resetPos()
// make sure we check for overflow previous-code before proceeding to custom decoder.
d.overflowCheck()
d.emit("{")
d.emit("n, ok := %s.neoDecode(data)", path)
d.emit("n, ok := %s.neoDecodeN(data)", path)
d.emit("if !ok { goto overflow }")
d.emit("data = data[n:]")
d.emit("%v += n", d.var_("nread"))
......@@ -1182,15 +1781,53 @@ func (d *decoder) genCustom(path string) {
d.overflowCheck()
}
// ---- struct head ----
// N: nothing
func (s *sizerN) genStructHead(path string, typ *types.Struct, userType types.Type) {}
func (e *encoderN) genStructHead(path string, typ *types.Struct, userType types.Type) {}
func (d *decoderN) genStructHead(path string, typ *types.Struct, userType types.Type) {}
// M: array<nfields>
func (s *sizerM) genStructHead(path string, typ *types.Struct, userType types.Type) {
s.size.Add(1) // mfixarray|marray16|marray32
if typ.NumFields() > 0x0f {
panic("TODO: struct with > 15 elements")
}
}
func (e *encoderM) genStructHead(path string, typ *types.Struct, userType types.Type) {
if typ.NumFields() > 0x0f {
panic("TODO: struct with > 15 elements")
}
e.emit("data[%v] = byte(msgpack.FixArray_4 | %d)", e.n, typ.NumFields())
e.n += 1
}
func (d *decoderM) genStructHead(path string, typ *types.Struct, userType types.Type) {
if typ.NumFields() > 0x0f {
panic("TODO: struct with > 15 elements")
}
d.emit("if op, opOk := msgpack.Op(data[%v]), msgpack.FixArray_4 | %d ; op != opOk {", d.n, typ.NumFields())
d.emit("return 0, &mstructDecodeError{%q, op, opOk}", d.pathName(path))
d.emit("}")
d.n += 1
d.overflow.Add(1)
}
// top-level driver for emitting size/encode/decode code for a type
//
// obj is object that uses this type in source program (so in case of an error
// we can point to source location for where it happened)
func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGenerator) {
// neo.customCodec
if types.Implements(typ, neo_customCodec) ||
types.Implements(types.NewPointer(typ), neo_customCodec) {
codegen.genCustom(path)
// neo.customCodecN
ccCustomN, ok := codegen.(interface { genCustomN(path string) })
if ok && (types.Implements(typ, neo_customCodecN) ||
types.Implements(types.NewPointer(typ), neo_customCodecN)) {
ccCustomN.genCustomN(path)
return
}
......@@ -1208,13 +1845,14 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene
break
}
_, ok := basicTypes[u.Kind()]
_, ok := basicTypesN[u.Kind()] // ok to check N to see if supported for both N and M
if !ok {
log.Fatalf("%v: %v: basic type %v not supported", pos(obj), obj.Name(), u)
}
codegen.genBasic(path, u, typ)
case *types.Struct:
codegen.genStructHead(path, u, typ)
for i := 0; i < u.NumFields(); i++ {
v := u.Field(i)
codegenType(path+"."+v.Name(), v.Type(), v, codegen)
......@@ -1222,7 +1860,7 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene
case *types.Array:
// [...]byte or [...]uint8 - just straight copy
if typeSizeFixed1(u.Elem()) {
if isByte(u.Elem()) {
codegen.genArray1(path, u)
} else {
var i int64
......@@ -1232,7 +1870,7 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene
}
case *types.Slice:
if typeSizeFixed1(u.Elem()) {
if isByte(u.Elem()) {
codegen.genSlice1(path, u)
} else {
codegen.genSlice(path, u, obj)
......@@ -1242,6 +1880,7 @@ func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGene
codegen.genMap(path, u, obj)
case *types.Pointer:
panic("XXX") // XXX what here?
default:
log.Fatalf("%v: %v has unsupported type %v (%v)", pos(obj),
......@@ -1255,8 +1894,14 @@ func generateCodecCode(typespec *ast.TypeSpec, codegen CodeGenerator) string {
typ := typeInfo.Types[typespec.Type].Type
obj := typeInfo.Defs[typespec.Name]
codegen.setFunc("p", typespec.Name.Name, typ)
codegen.setFunc("p", typespec.Name.Name, typ, codegen.encoding())
codegenType("p", typ, obj, codegen)
return codegen.generatedCode()
}
// isByte returns whether typ represents byte.
func isByte(typ types.Type) bool {
t, ok := typ.(*types.Basic)
return ok && t.Kind() == types.Byte
}
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -387,6 +387,8 @@ func Verify(t *testing.T, f func(*tEnv)) {
// TODO verify M=(go|py) x S=(go|py) x ...
// for now we only verify for all combinations of network
// TODO verify enc=(M|N)
// for all networks
for _, network := range []string{"pipenet", "lonet"} {
opt := tClusterOptions{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment