Commit 6fb1bf26 authored by Roger Peppe's avatar Roger Peppe Committed by Rob Pike

netchan: do not block sends; implement flow control.

When data is received for a channel, but that channel
is not ready to receive it, the central run() loop
is currently blocked, but this can lead to deadlock
and interference between independent channels.
This CL adds an explicit buffer size to netchan
channels (an API change) - the sender will not
send values until the buffer is non empty.

The protocol changes to send ids rather than channel names
because acks can still be sent after a channel is hung up,
we we need an identifier that can be ignored.

R=r, rsc
CC=golang-dev
https://golang.org/cl/2447042
parent 89993544
......@@ -38,21 +38,24 @@ const (
payData // user payload follows
payAck // acknowledgement; no payload
payClosed // channel is now closed
payAckSend // payload has been delivered.
)
// A header is sent as a prefix to every transmission. It will be followed by
// a request structure, an error structure, or an arbitrary user payload structure.
type header struct {
Name string
Id int
PayloadType int
SeqNum int64
}
// Sent with a header once per channel from importer to exporter to report
// that it wants to bind to a channel with the specified direction for count
// messages. If count is -1, it means unlimited.
// messages, with space for size buffered values. If count is -1, it means unlimited.
type request struct {
Name string
Count int64
Size int
Dir Dir
}
......@@ -78,7 +81,7 @@ type chanDir struct {
// clients of an exporter and draining outstanding messages.
type clientSet struct {
mu sync.Mutex // protects access to channel and client maps
chans map[string]*chanDir
names map[string]*chanDir
clients map[unackedCounter]bool
}
......@@ -132,7 +135,7 @@ func (cs *clientSet) drain(timeout int64) os.Error {
pending := false
cs.mu.Lock()
// Any messages waiting for a client?
for _, chDir := range cs.chans {
for _, chDir := range cs.names {
if chDir.ch.Len() > 0 {
pending = true
}
......@@ -189,3 +192,134 @@ func (cs *clientSet) sync(timeout int64) os.Error {
}
return nil
}
// A netChan represents a channel imported or exported
// on a single connection. Flow is controlled by the receiving
// side by sending payAckSend messages when values
// are delivered into the local channel.
type netChan struct {
*chanDir
name string
id int
size int // buffer size of channel.
// sender-specific state
ackCh chan bool // buffered with space for all the acks we need
space int // available space.
// receiver-specific state
sendCh chan reflect.Value // buffered channel of values received from other end.
ed *encDec // so that we can send acks.
count int64 // number of values still to receive.
}
// Create a new netChan with the given name (only used for
// messages), id, direction, buffer size, and count.
// The connection to the other side is represented by ed.
func newNetChan(name string, id int, ch *chanDir, ed *encDec, size int, count int64) *netChan {
c := &netChan{chanDir: ch, name: name, id: id, size: size, ed: ed, count: count}
if c.dir == Send {
c.ackCh = make(chan bool, size)
c.space = size
}
return c
}
// Close the channel.
func (nch *netChan) close() {
if nch.dir == Recv {
if nch.sendCh != nil {
// If the sender goroutine is active, close the channel to it.
// It will close nch.ch when it can.
close(nch.sendCh)
} else {
nch.ch.Close()
}
} else {
nch.ch.Close()
close(nch.ackCh)
}
}
// Send message from remote side to local receiver.
func (nch *netChan) send(val reflect.Value) {
if nch.dir != Recv {
panic("send on wrong direction of channel")
}
if nch.sendCh == nil {
// If possible, do local send directly and ack immediately.
if nch.ch.TrySend(val) {
nch.sendAck()
return
}
// Start sender goroutine to manage delayed delivery of values.
nch.sendCh = make(chan reflect.Value, nch.size)
go nch.sender()
}
if ok := nch.sendCh <- val; !ok {
// TODO: should this be more resilient?
panic("netchan: remote sender sent more values than allowed")
}
}
// sendAck sends an acknowledgment that a message has left
// the channel's buffer. If the messages remaining to be sent
// will fit in the channel's buffer, then we don't
// need to send an ack.
func (nch *netChan) sendAck() {
if nch.count < 0 || nch.count > int64(nch.size) {
nch.ed.encode(&header{Id: nch.id}, payAckSend, nil)
}
if nch.count > 0 {
nch.count--
}
}
// The sender process forwards items from the sending queue
// to the destination channel, acknowledging each item.
func (nch *netChan) sender() {
if nch.dir != Recv {
panic("sender on wrong direction of channel")
}
// When Exporter.Hangup is called, the underlying channel is closed,
// and so we may get a "too many operations on closed channel" error
// if there are outstanding messages in sendCh.
// Make sure that this doesn't panic the whole program.
defer func() {
if r := recover(); r != nil {
// TODO check that r is "too many operations", otherwise re-panic.
}
}()
for v := range nch.sendCh {
nch.ch.Send(v)
nch.sendAck()
}
nch.ch.Close()
}
// Receive value from local side for sending to remote side.
func (nch *netChan) recv() (val reflect.Value, closed bool) {
if nch.dir != Send {
panic("recv on wrong direction of channel")
}
if nch.space == 0 {
// Wait for buffer space.
<-nch.ackCh
nch.space++
}
nch.space--
return nch.ch.Recv(), nch.ch.Closed()
}
// acked is called when the remote side indicates that
// a value has been delivered.
func (nch *netChan) acked() {
if nch.dir != Send {
panic("recv on wrong direction of channel")
}
if ok := nch.ackCh <- true; !ok {
panic("netchan: remote receiver sent too many acks")
// TODO: should this be more resilient?
}
}
......@@ -26,6 +26,7 @@ import (
"net"
"os"
"reflect"
"strconv"
"sync"
)
......@@ -48,11 +49,12 @@ type Exporter struct {
type expClient struct {
*encDec
exp *Exporter
mu sync.Mutex // protects remaining fields
errored bool // client has been sent an error
seqNum int64 // sequences messages sent to client; has value of highest sent
ackNum int64 // highest sequence number acknowledged
seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
chans map[int]*netChan // channels in use by client
mu sync.Mutex // protects remaining fields
errored bool // client has been sent an error
seqNum int64 // sequences messages sent to client; has value of highest sent
ackNum int64 // highest sequence number acknowledged
seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
}
func newClient(exp *Exporter, conn net.Conn) *expClient {
......@@ -61,8 +63,8 @@ func newClient(exp *Exporter, conn net.Conn) *expClient {
client.encDec = newEncDec(conn)
client.seqNum = 0
client.ackNum = 0
client.chans = make(map[int]*netChan)
return client
}
func (client *expClient) sendError(hdr *header, err string) {
......@@ -74,20 +76,33 @@ func (client *expClient) sendError(hdr *header, err string) {
client.mu.Unlock()
}
func (client *expClient) getChan(hdr *header, dir Dir) *chanDir {
func (client *expClient) newChan(hdr *header, dir Dir, name string, size int, count int64) *netChan {
exp := client.exp
exp.mu.Lock()
ech, ok := exp.chans[hdr.Name]
ech, ok := exp.names[name]
exp.mu.Unlock()
if !ok {
client.sendError(hdr, "no such channel: "+hdr.Name)
client.sendError(hdr, "no such channel: "+name)
return nil
}
if ech.dir != dir {
client.sendError(hdr, "wrong direction for channel: "+hdr.Name)
client.sendError(hdr, "wrong direction for channel: "+name)
return nil
}
nch := newNetChan(name, hdr.Id, ech, client.encDec, size, count)
client.chans[hdr.Id] = nch
return nch
}
func (client *expClient) getChan(hdr *header, dir Dir) *netChan {
nch := client.chans[hdr.Id]
if nch == nil {
return nil
}
return ech
if nch.dir != dir {
client.sendError(hdr, "wrong direction for channel: "+nch.name)
}
return nch
}
// The function run manages sends and receives for a single client. For each
......@@ -113,12 +128,18 @@ func (client *expClient) run() {
expLog("error decoding client request:", err)
break
}
if req.Size < 1 {
panic("netchan: remote requested " + strconv.Itoa(req.Size) + " values")
}
switch req.Dir {
case Recv:
go client.serveRecv(*hdr, req.Count)
// look up channel before calling serveRecv to
// avoid a lock around client.chans.
if nch := client.newChan(hdr, Send, req.Name, req.Size, req.Count); nch != nil {
go client.serveRecv(nch, *hdr, req.Count)
}
case Send:
// Request to send is clear as a matter of protocol
// but not actually used by the implementation.
client.newChan(hdr, Recv, req.Name, req.Size, req.Count)
// The actual sends will have payload type payData.
// TODO: manage the count?
default:
......@@ -143,6 +164,10 @@ func (client *expClient) run() {
client.ackNum = hdr.SeqNum
}
client.mu.Unlock()
case payAckSend:
if nch := client.getChan(hdr, Send); nch != nil {
nch.acked()
}
default:
log.Exit("netchan export: unknown payload type", hdr.PayloadType)
}
......@@ -152,14 +177,10 @@ func (client *expClient) run() {
// Send all the data on a single channel to a client asking for a Recv.
// The header is passed by value to avoid issues of overwriting.
func (client *expClient) serveRecv(hdr header, count int64) {
ech := client.getChan(&hdr, Send)
if ech == nil {
return
}
func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) {
for {
val := ech.ch.Recv()
if ech.ch.Closed() {
val, closed := nch.recv()
if closed {
if err := client.encode(&hdr, payClosed, nil); err != nil {
expLog("error encoding server closed message:", err)
}
......@@ -167,7 +188,7 @@ func (client *expClient) serveRecv(hdr header, count int64) {
}
// We hold the lock during transmission to guarantee messages are
// sent in sequence number order. Also, we increment first so the
// value of client.seqNum is the value of the highest used sequence
// value of client.SeqNum is the value of the highest used sequence
// number, not one beyond.
client.mu.Lock()
client.seqNum++
......@@ -193,27 +214,27 @@ func (client *expClient) serveRecv(hdr header, count int64) {
// Receive and deliver locally one item from a client asking for a Send
// The header is passed by value to avoid issues of overwriting.
func (client *expClient) serveSend(hdr header) {
ech := client.getChan(&hdr, Recv)
if ech == nil {
nch := client.getChan(&hdr, Recv)
if nch == nil {
return
}
// Create a new value for each received item.
val := reflect.MakeZero(ech.ch.Type().(*reflect.ChanType).Elem())
val := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem())
if err := client.decode(val); err != nil {
expLog("value decode:", err)
expLog("value decode:", err, "; type ", nch.ch.Type())
return
}
ech.ch.Send(val)
nch.send(val)
}
// Report that client has closed the channel that is sending to us.
// The header is passed by value to avoid issues of overwriting.
func (client *expClient) serveClosed(hdr header) {
ech := client.getChan(&hdr, Recv)
if ech == nil {
nch := client.getChan(&hdr, Recv)
if nch == nil {
return
}
ech.ch.Close()
nch.close()
}
func (client *expClient) unackedCount() int64 {
......@@ -260,7 +281,7 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) {
e := &Exporter{
listener: listener,
clientSet: &clientSet{
chans: make(map[string]*chanDir),
names: make(map[string]*chanDir),
clients: make(map[unackedCounter]bool),
},
}
......@@ -343,11 +364,11 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
}
exp.mu.Lock()
defer exp.mu.Unlock()
_, present := exp.chans[name]
_, present := exp.names[name]
if present {
return os.ErrorString("channel name already being exported:" + name)
}
exp.chans[name] = &chanDir{ch, dir}
exp.names[name] = &chanDir{ch, dir}
return nil
}
......@@ -355,10 +376,11 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
// the channel. Messages in flight for the channel may be dropped.
func (exp *Exporter) Hangup(name string) os.Error {
exp.mu.Lock()
chDir, ok := exp.chans[name]
chDir, ok := exp.names[name]
if ok {
exp.chans[name] = nil, false
exp.names[name] = nil, false
}
// TODO drop all instances of channel from client sets
exp.mu.Unlock()
if !ok {
return os.ErrorString("netchan export: hangup: no such channel: " + name)
......
......@@ -27,8 +27,10 @@ type Importer struct {
*encDec
conn net.Conn
chanLock sync.Mutex // protects access to channel map
chans map[string]*chanDir
names map[string]*netChan
chans map[int]*netChan
errors chan os.Error
maxId int
}
// NewImporter creates a new Importer object to import channels
......@@ -43,7 +45,8 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
imp := new(Importer)
imp.encDec = newEncDec(conn)
imp.conn = conn
imp.chans = make(map[string]*chanDir)
imp.chans = make(map[int]*netChan)
imp.names = make(map[string]*netChan)
imp.errors = make(chan os.Error, 10)
go imp.run()
return imp, nil
......@@ -54,7 +57,7 @@ func (imp *Importer) shutdown() {
imp.chanLock.Lock()
for _, ich := range imp.chans {
if ich.dir == Recv {
ich.ch.Close()
ich.close()
}
}
imp.chanLock.Unlock()
......@@ -95,43 +98,53 @@ func (imp *Importer) run() {
continue // errors are not acknowledged.
}
case payClosed:
ich := imp.getChan(hdr.Name)
if ich != nil {
ich.ch.Close()
nch := imp.getChan(hdr.Id, false)
if nch != nil {
nch.close()
}
continue // closes are not acknowledged.
case payAckSend:
// we can receive spurious acks if the channel is
// hung up, so we ask getChan to ignore any errors.
nch := imp.getChan(hdr.Id, true)
if nch != nil {
nch.acked()
}
continue
default:
impLog("unexpected payload type:", hdr.PayloadType)
return
}
ich := imp.getChan(hdr.Name)
if ich == nil {
nch := imp.getChan(hdr.Id, false)
if nch == nil {
continue
}
if ich.dir != Recv {
if nch.dir != Recv {
impLog("cannot happen: receive from non-Recv channel")
return
}
// Acknowledge receipt
ackHdr.Name = hdr.Name
ackHdr.Id = hdr.Id
ackHdr.SeqNum = hdr.SeqNum
imp.encode(ackHdr, payAck, nil)
// Create a new value for each received item.
value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem())
value := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem())
if e := imp.decode(value); e != nil {
impLog("importer value decode:", e)
return
}
ich.ch.Send(value)
nch.send(value)
}
}
func (imp *Importer) getChan(name string) *chanDir {
func (imp *Importer) getChan(id int, errOk bool) *netChan {
imp.chanLock.Lock()
ich := imp.chans[name]
ich := imp.chans[id]
imp.chanLock.Unlock()
if ich == nil {
impLog("unknown name in netchan request:", name)
if !errOk {
impLog("unknown id in netchan request: ", id)
}
return nil
}
return ich
......@@ -145,17 +158,18 @@ func (imp *Importer) Errors() chan os.Error {
return imp.errors
}
// Import imports a channel of the given type and specified direction.
// Import imports a channel of the given type, size and specified direction.
// It is equivalent to ImportNValues with a count of -1, meaning unbounded.
func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error {
return imp.ImportNValues(name, chT, dir, -1)
func (imp *Importer) Import(name string, chT interface{}, dir Dir, size int) os.Error {
return imp.ImportNValues(name, chT, dir, size, -1)
}
// ImportNValues imports a channel of the given type and specified direction
// and then receives or transmits up to n values on that channel. A value of
// n==-1 implies an unbounded number of values. The channel to be bound to
// the remote site's channel is provided in the call and may be of arbitrary
// channel type.
// ImportNValues imports a channel of the given type and specified
// direction and then receives or transmits up to n values on that
// channel. A value of n==-1 implies an unbounded number of values. The
// channel will have buffer space for size values, or 1 value if size < 1.
// The channel to be bound to the remote site's channel is provided
// in the call and may be of arbitrary channel type.
// Despite the literal signature, the effective signature is
// ImportNValues(name string, chT chan T, dir Dir, n int) os.Error
// Example usage:
......@@ -165,21 +179,28 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error {
// err = imp.ImportNValues("name", ch, Recv, 1)
// if err != nil { log.Exit(err) }
// fmt.Printf("%+v\n", <-ch)
func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) os.Error {
func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size, n int) os.Error {
ch, err := checkChan(chT, dir)
if err != nil {
return err
}
imp.chanLock.Lock()
defer imp.chanLock.Unlock()
_, present := imp.chans[name]
_, present := imp.names[name]
if present {
return os.ErrorString("channel name already being imported:" + name)
}
imp.chans[name] = &chanDir{ch, dir}
if size < 1 {
size = 1
}
id := imp.maxId
imp.maxId++
nch := newNetChan(name, id, &chanDir{ch, dir}, imp.encDec, size, int64(n))
imp.names[name] = nch
imp.chans[id] = nch
// Tell the other side about this channel.
hdr := &header{Name: name}
req := &request{Count: int64(n), Dir: dir}
hdr := &header{Id: id}
req := &request{Name: name, Count: int64(n), Dir: dir, Size: size}
if err = imp.encode(hdr, payRequest, req); err != nil {
impLog("request encode:", err)
return err
......@@ -187,8 +208,8 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
if dir == Send {
go func() {
for i := 0; n == -1 || i < n; i++ {
val := ch.Recv()
if ch.Closed() {
val, closed := nch.recv()
if closed {
if err = imp.encode(hdr, payClosed, nil); err != nil {
impLog("error encoding client closed message:", err)
}
......@@ -208,14 +229,15 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
// the channel. Messages in flight for the channel may be dropped.
func (imp *Importer) Hangup(name string) os.Error {
imp.chanLock.Lock()
chDir, ok := imp.chans[name]
nc, ok := imp.names[name]
if ok {
imp.chans[name] = nil, false
imp.names[name] = nil, false
imp.chans[nc.id] = nil, false
}
imp.chanLock.Unlock()
if !ok {
return os.ErrorString("netchan import: hangup: no such channel: " + name)
}
chDir.ch.Close()
nc.close()
return nil
}
......@@ -15,7 +15,7 @@ const closeCount = 5 // number of items when sender closes early
const base = 23
func exportSend(exp *Exporter, n int, t *testing.T) {
func exportSend(exp *Exporter, n int, t *testing.T, done chan bool) {
ch := make(chan int)
err := exp.Export("exportedSend", ch, Send)
if err != nil {
......@@ -26,6 +26,9 @@ func exportSend(exp *Exporter, n int, t *testing.T) {
ch <- base+i
}
close(ch)
if done != nil {
done <- true
}
}()
}
......@@ -50,9 +53,9 @@ func exportReceive(exp *Exporter, t *testing.T, expDone chan bool) {
}
}
func importSend(imp *Importer, n int, t *testing.T) {
func importSend(imp *Importer, n int, t *testing.T, done chan bool) {
ch := make(chan int)
err := imp.ImportNValues("exportedRecv", ch, Send, count)
err := imp.ImportNValues("exportedRecv", ch, Send, 3, -1)
if err != nil {
t.Fatal("importSend:", err)
}
......@@ -61,12 +64,15 @@ func importSend(imp *Importer, n int, t *testing.T) {
ch <- base+i
}
close(ch)
if done != nil {
done <- true
}
}()
}
func importReceive(imp *Importer, t *testing.T, done chan bool) {
ch := make(chan int)
err := imp.ImportNValues("exportedSend", ch, Recv, count)
err := imp.ImportNValues("exportedSend", ch, Recv, 3, count)
if err != nil {
t.Fatal("importReceive:", err)
}
......@@ -78,7 +84,7 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) {
}
break
}
if v != 23+i {
if v != base+i {
t.Errorf("importReceive: bad value: expected %d+%d=%d; got %+d", base, i, base+i, v)
}
}
......@@ -96,7 +102,7 @@ func TestExportSendImportReceive(t *testing.T) {
if err != nil {
t.Fatal("new importer:", err)
}
exportSend(exp, count, t)
exportSend(exp, count, t, nil)
importReceive(imp, t, nil)
}
......@@ -116,7 +122,7 @@ func TestExportReceiveImportSend(t *testing.T) {
done <- true
}()
<-expDone
importSend(imp, count, t)
importSend(imp, count, t, nil)
<-done
}
......@@ -129,7 +135,7 @@ func TestClosingExportSendImportReceive(t *testing.T) {
if err != nil {
t.Fatal("new importer:", err)
}
exportSend(exp, closeCount, t)
exportSend(exp, closeCount, t, nil)
importReceive(imp, t, nil)
}
......@@ -149,7 +155,7 @@ func TestClosingImportSendExportReceive(t *testing.T) {
done <- true
}()
<-expDone
importSend(imp, closeCount, t)
importSend(imp, closeCount, t, nil)
<-done
}
......@@ -172,7 +178,7 @@ func TestErrorForIllegalChannel(t *testing.T) {
close(ch)
// Now try to import a different channel.
ch = make(chan int)
err = imp.Import("notAChannel", ch, Recv)
err = imp.Import("notAChannel", ch, Recv, 1)
if err != nil {
t.Fatal("import:", err)
}
......@@ -204,7 +210,7 @@ func TestExportDrain(t *testing.T) {
}
done := make(chan bool)
go func() {
exportSend(exp, closeCount, t)
exportSend(exp, closeCount, t, nil)
done <- true
}()
<-done
......@@ -224,7 +230,7 @@ func TestExportSync(t *testing.T) {
t.Fatal("new importer:", err)
}
done := make(chan bool)
exportSend(exp, closeCount, t)
exportSend(exp, closeCount, t, nil)
go importReceive(imp, t, done)
exp.Sync(0)
<-done
......@@ -248,7 +254,7 @@ func TestExportHangup(t *testing.T) {
}
// Prepare to receive two values. We'll actually deliver only one.
ich := make(chan int)
err = imp.ImportNValues("exportedSend", ich, Recv, 2)
err = imp.ImportNValues("exportedSend", ich, Recv, 1, 2)
if err != nil {
t.Fatal("import exportedSend:", err)
}
......@@ -285,7 +291,7 @@ func TestImportHangup(t *testing.T) {
}
// Prepare to Send two values. We'll actually deliver only one.
ich := make(chan int)
err = imp.ImportNValues("exportedRecv", ich, Send, 2)
err = imp.ImportNValues("exportedRecv", ich, Send, 1, 2)
if err != nil {
t.Fatal("import exportedRecv:", err)
}
......@@ -304,10 +310,70 @@ func TestImportHangup(t *testing.T) {
}
}
// loop back exportedRecv to exportedSend,
// but receive a value from ctlch before starting the loop.
func exportLoopback(exp *Exporter, t *testing.T) {
inch := make(chan int)
if err := exp.Export("exportedRecv", inch, Recv); err != nil {
t.Fatal("exportRecv")
}
outch := make(chan int)
if err := exp.Export("exportedSend", outch, Send); err != nil {
t.Fatal("exportSend")
}
ctlch := make(chan int)
if err := exp.Export("exportedCtl", ctlch, Recv); err != nil {
t.Fatal("exportRecv")
}
go func() {
<-ctlch
for i := 0; i < count; i++ {
x := <-inch
if x != base+i {
t.Errorf("exportLoopback expected %d; got %d", i, x)
}
outch <- x
}
}()
}
// This test checks that channel operations can proceed
// even when other concurrent operations are blocked.
func TestIndependentSends(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
exportLoopback(exp, t)
importSend(imp, count, t, nil)
done := make(chan bool)
go importReceive(imp, t, done)
// wait for export side to try to deliver some values.
time.Sleep(0.25e9)
ctlch := make(chan int)
if err := imp.ImportNValues("exportedCtl", ctlch, Send, 1, 1); err != nil {
t.Fatal("importSend:", err)
}
ctlch <- 0
<-done
}
// This test cross-connects a pair of exporter/importer pairs.
type value struct {
i int
source string
I int
Source string
}
func TestCrossConnect(t *testing.T) {
......@@ -353,13 +419,13 @@ func crossExport(e1, e2 *Exporter, t *testing.T) {
// Import side of cross-traffic.
func crossImport(i1, i2 *Importer, t *testing.T) {
s := make(chan value)
err := i2.Import("exportedReceive", s, Send)
err := i2.Import("exportedReceive", s, Send, 2)
if err != nil {
t.Fatal("import of exportedReceive:", err)
}
r := make(chan value)
err = i1.Import("exportedSend", r, Recv)
err = i1.Import("exportedSend", r, Recv, 2)
if err != nil {
t.Fatal("import of exported Send:", err)
}
......@@ -374,10 +440,76 @@ func crossLoop(name string, s, r chan value, t *testing.T) {
case s <- value{si, name}:
si++
case v := <-r:
if v.i != ri {
if v.I != ri {
t.Errorf("loop: bad value: expected %d, hello; got %+v", ri, v)
}
ri++
}
}
}
const flowCount = 100
// test flow control from exporter to importer.
func TestExportFlowControl(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
sendDone := make(chan bool, 1)
exportSend(exp, flowCount, t, sendDone)
ch := make(chan int)
err = imp.ImportNValues("exportedSend", ch, Recv, 20, -1)
if err != nil {
t.Fatal("importReceive:", err)
}
testFlow(sendDone, ch, flowCount, t)
}
// test flow control from importer to exporter.
func TestImportFlowControl(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
ch := make(chan int)
err = exp.Export("exportedRecv", ch, Recv)
if err != nil {
t.Fatal("importReceive:", err)
}
sendDone := make(chan bool, 1)
importSend(imp, flowCount, t, sendDone)
testFlow(sendDone, ch, flowCount, t)
}
func testFlow(sendDone chan bool, ch <-chan int, N int, t *testing.T) {
go func() {
time.Sleep(1e9)
sendDone <- false
}()
if <-sendDone {
t.Fatal("send did not block")
}
n := 0
for i := range ch {
t.Log("after blocking, got value ", i)
n++
}
if n != N {
t.Fatalf("expected %d values; got %d", N, n)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment