Commit 3dab1cf0 authored by Kirill Smelkov's avatar Kirill Smelkov

.

parent 8d0a1469
...@@ -43,7 +43,7 @@ import ( ...@@ -43,7 +43,7 @@ import (
// send/receive exchange will be happening in between those 2 connections. // send/receive exchange will be happening in between those 2 connections.
// //
// For a node to be able to accept new incoming connection it has to have // For a node to be able to accept new incoming connection it has to have
// "server" role - see NewNodeLink() for details. // "server" role - see newNodeLink() for details.
// //
// A NodeLink has to be explicitly closed, once it is no longer needed. // A NodeLink has to be explicitly closed, once it is no longer needed.
// //
...@@ -110,7 +110,7 @@ const ( ...@@ -110,7 +110,7 @@ const (
linkFlagsMask LinkRole = (1<<32 - 1) << 16 linkFlagsMask LinkRole = (1<<32 - 1) << 16
) )
// NewNodeLink makes a new NodeLink from already established net.Conn // newNodeLink makes a new NodeLink from already established net.Conn
// //
// Role specifies how to treat our role on the link - either as client or // Role specifies how to treat our role on the link - either as client or
// server. The difference in between client and server roles are in: // server. The difference in between client and server roles are in:
...@@ -124,7 +124,10 @@ const ( ...@@ -124,7 +124,10 @@ const (
// //
// Usually server role should be used for connections created via // Usually server role should be used for connections created via
// net.Listen/net.Accept and client role for connections created via net.Dial. // net.Listen/net.Accept and client role for connections created via net.Dial.
func NewNodeLink(conn net.Conn, role LinkRole) *NodeLink { //
// Though it is possible to wrap just-established raw connection into NodeLink,
// users should always use Handshake which performs protocol handshaking first.
func newNodeLink(conn net.Conn, role LinkRole) *NodeLink {
var nextConnId uint32 var nextConnId uint32
var acceptq chan *Conn var acceptq chan *Conn
switch role&^linkFlagsMask { switch role&^linkFlagsMask {
...@@ -217,6 +220,7 @@ func (nl *NodeLink) shutdown() { ...@@ -217,6 +220,7 @@ func (nl *NodeLink) shutdown() {
// Close closes node-node link. // Close closes node-node link.
// All blocking operations - Accept and IO on associated connections // All blocking operations - Accept and IO on associated connections
// established over node link - are automatically interrupted with an error. // established over node link - are automatically interrupted with an error.
// Underlying raw connection is closed.
func (nl *NodeLink) Close() error { func (nl *NodeLink) Close() error {
atomic.StoreUint32(&nl.closed, 1) atomic.StoreUint32(&nl.closed, 1)
nl.shutdown() nl.shutdown()
...@@ -365,7 +369,7 @@ func (nl *NodeLink) serveRecv() { ...@@ -365,7 +369,7 @@ func (nl *NodeLink) serveRecv() {
// keep connMu locked until here: so that ^^^ `conn.rxq <- pkt` can be // keep connMu locked until here: so that ^^^ `conn.rxq <- pkt` can be
// sure conn stays not down e.g. closed by Conn.Close or NodeLink.shutdown // sure conn stays not down e.g. closed by Conn.Close or NodeLink.shutdown
// //
// XXX try to release connMu eariler - before `rxq <- pkt` // XXX try to release connMu earlier - before `rxq <- pkt`
nl.connMu.Unlock() nl.connMu.Unlock()
if accept { if accept {
...@@ -551,22 +555,37 @@ func (nl *NodeLink) recvPkt() (*PktBuf, error) { ...@@ -551,22 +555,37 @@ func (nl *NodeLink) recvPkt() (*PktBuf, error) {
// ---- Handshake ---- // ---- Handshake ----
// Handshake performs NEO protocol handshake just after 2 nodes are connected // Handshake performs NEO protocol handshake just after raw connection between 2 nodes was established
func Handshake(conn net.Conn) error { // On success raw connection is returned wrapped into NodeLink
return handshake(conn, PROTOCOL_VERSION) // On error raw connection is closed
func Handshake(ctx context.Context, conn net.Conn, role LinkRole) (nl *NodeLink, err error) {
err = handshake(ctx, conn, PROTOCOL_VERSION)
if err != nil {
return nil, err
}
// handshake ok -> NodeLink
return newNodeLink(conn, role), nil
} }
func handshake(conn net.Conn, version uint32) error { // handshake is worker for Handshake
func handshake(ctx context.Context, conn net.Conn, version uint32) (err error) {
errch := make(chan error, 2) errch := make(chan error, 2)
// tx handshake word
txWg := sync.WaitGroup{}
txWg.Add(1)
go func() { go func() {
var b [4]byte var b [4]byte
binary.BigEndian.PutUint32(b[:], version) // XXX -> hton32 ? binary.BigEndian.PutUint32(b[:], version /*+ 33*/) // XXX -> hton32 ?
_, err := conn.Write(b[:]) _, err := conn.Write(b[:])
// XXX EOF -> ErrUnexpectedEOF ? // XXX EOF -> ErrUnexpectedEOF ?
errch <- err errch <- err
txWg.Done()
}() }()
// rx handshake word
go func() { go func() {
var b [4]byte var b [4]byte
_, err := io.ReadFull(conn, b[:]) _, err := io.ReadFull(conn, b[:])
...@@ -582,13 +601,39 @@ func handshake(conn net.Conn, version uint32) error { ...@@ -582,13 +601,39 @@ func handshake(conn net.Conn, version uint32) error {
errch <- err errch <- err
}() }()
connClosed := false
defer func() {
// make sure our version is always sent on the wire, if possible,
// so that peer does not see just closed connection when on rx we see version mismatch
//
// NOTE if cancelled tx goroutine will wake up without delay
txWg.Wait()
// don't forget to close conn if returning with error + add handshake err context
if err != nil {
err = &HandshakeError{conn.LocalAddr(), conn.RemoteAddr(), err}
if !connClosed {
conn.Close()
}
}
}()
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
err := <-errch select {
case <-ctx.Done():
conn.Close() // interrupt IO
connClosed = true
return ctx.Err()
case err = <-errch:
if err != nil { if err != nil {
return &HandshakeError{conn.LocalAddr(), conn.RemoteAddr(), err} return err
}
} }
} }
// handshaked ok
return nil return nil
} }
...@@ -603,7 +648,8 @@ func (e *HandshakeError) Error() string { ...@@ -603,7 +648,8 @@ func (e *HandshakeError) Error() string {
return fmt.Sprintf("%s - %s: handshake: %s", e.LocalAddr, e.RemoteAddr, e.Err.Error()) return fmt.Sprintf("%s - %s: handshake: %s", e.LocalAddr, e.RemoteAddr, e.Err.Error())
} }
// ---- for convenience: Dial/Listen ----
// ---- for convenience: Dial ----
// Dial connects to address on named network and wrap the connection as NodeLink // Dial connects to address on named network and wrap the connection as NodeLink
// TODO +tls.Config // TODO +tls.Config
...@@ -614,33 +660,11 @@ func Dial(ctx context.Context, network, address string) (nl *NodeLink, err error ...@@ -614,33 +660,11 @@ func Dial(ctx context.Context, network, address string) (nl *NodeLink, err error
return nil, err return nil, err
} }
// do the handshake. don't forget to close peerConn if we return with an error return Handshake(ctx, peerConn, LinkClient)
defer func() {
if err != nil {
peerConn.Close()
}
}()
errch := make(chan error)
go func() {
errch <- Handshake(peerConn)
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case err = <-errch:
if err != nil {
return nil, &HandshakeError{peerConn.LocalAddr(), peerConn.RemoteAddr(), err}
}
}
// handshake ok -> NodeLink ready
return NewNodeLink(peerConn, LinkClient), nil
} }
// like net.Listener but Accept returns net.Conn wrapped in NodeLink /* TODO not needed -> goes away
// Listener is like net.Listener but Accept returns net.Conn wrapped in NodeLink and handshaked XXX
type Listener struct { type Listener struct {
net.Listener net.Listener
} }
...@@ -651,7 +675,7 @@ func (l *Listener) Accept() (*NodeLink, error) { ...@@ -651,7 +675,7 @@ func (l *Listener) Accept() (*NodeLink, error) {
return nil, err return nil, err
} }
err = Handshake(peerConn) err = Handshake(peerConn) // FIXME blocking - not good - blocks further Accepts
if err != nil { if err != nil {
peerConn.Close() peerConn.Close()
return nil, err return nil, err
...@@ -669,6 +693,7 @@ func Listen(network, laddr string) (*Listener, error) { ...@@ -669,6 +693,7 @@ func Listen(network, laddr string) (*Listener, error) {
} }
return &Listener{l}, nil return &Listener{l}, nil
} }
*/
......
...@@ -99,8 +99,8 @@ func xwait(w interface { Wait() error }) { ...@@ -99,8 +99,8 @@ func xwait(w interface { Wait() error }) {
exc.Raiseif(err) exc.Raiseif(err)
} }
func xhandshake(c net.Conn, version uint32) { func xhandshake(ctx context.Context, c net.Conn, version uint32) {
err := handshake(c, version) err := handshake(ctx, c, version)
exc.Raiseif(err) exc.Raiseif(err)
} }
...@@ -157,8 +157,8 @@ func tdelay() { ...@@ -157,8 +157,8 @@ func tdelay() {
// create NodeLinks connected via net.Pipe // create NodeLinks connected via net.Pipe
func _nodeLinkPipe(flags1, flags2 LinkRole) (nl1, nl2 *NodeLink) { func _nodeLinkPipe(flags1, flags2 LinkRole) (nl1, nl2 *NodeLink) {
node1, node2 := net.Pipe() node1, node2 := net.Pipe()
nl1 = NewNodeLink(node1, LinkClient | flags1) nl1 = newNodeLink(node1, LinkClient | flags1)
nl2 = NewNodeLink(node2, LinkServer | flags2) nl2 = newNodeLink(node2, LinkServer | flags2)
return nl1, nl2 return nl1, nl2
} }
...@@ -538,14 +538,15 @@ func TestNodeLink(t *testing.T) { ...@@ -538,14 +538,15 @@ func TestNodeLink(t *testing.T) {
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
bg := context.Background()
// handshake ok // handshake ok
p1, p2 := net.Pipe() p1, p2 := net.Pipe()
wg := WorkGroup() wg := WorkGroup()
wg.Gox(func() { wg.Gox(func() {
xhandshake(p1, 1) xhandshake(bg, p1, 1)
}) })
wg.Gox(func() { wg.Gox(func() {
xhandshake(p2, 1) xhandshake(bg, p2, 1)
}) })
xwait(wg) xwait(wg)
xclose(p1) xclose(p1)
...@@ -556,10 +557,10 @@ func TestHandshake(t *testing.T) { ...@@ -556,10 +557,10 @@ func TestHandshake(t *testing.T) {
var err1, err2 error var err1, err2 error
wg = WorkGroup() wg = WorkGroup()
wg.Gox(func() { wg.Gox(func() {
err1 = handshake(p1, 1) err1 = handshake(bg, p1, 1)
}) })
wg.Gox(func() { wg.Gox(func() {
err2 = handshake(p2, 2) err2 = handshake(bg, p2, 2)
}) })
xwait(wg) xwait(wg)
xclose(p1) xclose(p1)
...@@ -580,7 +581,7 @@ func TestHandshake(t *testing.T) { ...@@ -580,7 +581,7 @@ func TestHandshake(t *testing.T) {
err1, err2 = nil, nil err1, err2 = nil, nil
wg = WorkGroup() wg = WorkGroup()
wg.Gox(func() { wg.Gox(func() {
err1 = handshake(p1, 1) err1 = handshake(bg, p1, 1)
}) })
wg.Gox(func() { wg.Gox(func() {
xclose(p2) xclose(p2)
...@@ -593,4 +594,23 @@ func TestHandshake(t *testing.T) { ...@@ -593,4 +594,23 @@ func TestHandshake(t *testing.T) {
if !ok || !(err11.Err == io.ErrClosedPipe /* on Write */ || err11.Err == io.ErrUnexpectedEOF /* on Read */) { if !ok || !(err11.Err == io.ErrClosedPipe /* on Write */ || err11.Err == io.ErrUnexpectedEOF /* on Read */) {
t.Errorf("handshake peer close: unexpected error: %#v", err1) t.Errorf("handshake peer close: unexpected error: %#v", err1)
} }
// ctx cancel
p1, p2 = net.Pipe()
ctx, cancel := context.WithCancel(bg)
wg.Gox(func() {
err1 = handshake(ctx, p1, 1)
})
tdelay()
cancel()
xwait(wg)
xclose(p1)
xclose(p2)
err11, ok = err1.(*HandshakeError)
if !ok || !(err11.Err == context.Canceled) {
t.Errorf("handshake cancel: unexpected error: %#v", err1)
}
} }
...@@ -21,6 +21,7 @@ package neo ...@@ -21,6 +21,7 @@ package neo
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"reflect" "reflect"
) )
...@@ -36,7 +37,7 @@ type Server interface { ...@@ -36,7 +37,7 @@ type Server interface {
// - for every accepted connection spawn srv.ServeLink() in separate goroutine. // - for every accepted connection spawn srv.ServeLink() in separate goroutine.
// //
// the listener is closed when Serve returns. // the listener is closed when Serve returns.
func Serve(ctx context.Context, l *Listener, srv Server) error { func Serve(ctx context.Context, l net.Listener, srv Server) error {
fmt.Printf("xxx: serving on %s ...\n", l.Addr()) // XXX 'xxx' -> ? fmt.Printf("xxx: serving on %s ...\n", l.Addr()) // XXX 'xxx' -> ?
// close listener when either cancelling or returning (e.g. due to an error) // close listener when either cancelling or returning (e.g. due to an error)
...@@ -53,23 +54,30 @@ func Serve(ctx context.Context, l *Listener, srv Server) error { ...@@ -53,23 +54,30 @@ func Serve(ctx context.Context, l *Listener, srv Server) error {
l.Close() // XXX err l.Close() // XXX err
}() }()
// main Accept -> ServeLink loop // main Accept -> Handshake -> ServeLink loop
for { for {
link, err := l.Accept() peerConn, err := l.Accept()
if err != nil { if err != nil {
// TODO err == closed <-> ctx was cancelled // TODO err == closed <-> ctx was cancelled
// TODO err -> net.Error && .Temporary() -> some throttling // TODO err -> net.Error && .Temporary() -> some throttling
return err return err
} }
go srv.ServeLink(ctx, link) go func() {
link, err := Handshake(ctx, peerConn, LinkServer)
if err != nil {
fmt.Printf("xxx: %s\n", err)
return
}
srv.ServeLink(ctx, link)
}()
} }
} }
// ListenAndServe listens on network address and then calls Serve to handle incoming connections // ListenAndServe listens on network address and then calls Serve to handle incoming connections
// XXX split -> separate Listen() & Serve() // XXX split -> separate Listen() & Serve()
func ListenAndServe(ctx context.Context, net_, laddr string, srv Server) error { func ListenAndServe(ctx context.Context, net_, laddr string, srv Server) error {
l, err := Listen(net_, laddr) l, err := net.Listen(net_, laddr)
if err != nil { if err != nil {
return err return err
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment