Commit ce877acf authored by David Jakob Fritz's avatar David Jakob Fritz Committed by Rob Pike

netchan: added drain method to importer.

Fixes #1868.

R=golang-dev, r, rsc
CC=golang-dev
https://golang.org/cl/4550093
parent b4ddef3c
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"os" "os"
"reflect" "reflect"
"sync" "sync"
"time"
) )
// Import // Import
...@@ -31,6 +32,9 @@ type Importer struct { ...@@ -31,6 +32,9 @@ type Importer struct {
chans map[int]*netChan chans map[int]*netChan
errors chan os.Error errors chan os.Error
maxId int maxId int
mu sync.Mutex // protects remaining fields
unacked int64 // number of unacknowledged sends.
seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
} }
// NewImporter creates a new Importer object to import a set of channels // NewImporter creates a new Importer object to import a set of channels
...@@ -42,6 +46,7 @@ func NewImporter(conn io.ReadWriter) *Importer { ...@@ -42,6 +46,7 @@ func NewImporter(conn io.ReadWriter) *Importer {
imp.chans = make(map[int]*netChan) imp.chans = make(map[int]*netChan)
imp.names = make(map[string]*netChan) imp.names = make(map[string]*netChan)
imp.errors = make(chan os.Error, 10) imp.errors = make(chan os.Error, 10)
imp.unacked = 0
go imp.run() go imp.run()
return imp return imp
} }
...@@ -80,8 +85,10 @@ func (imp *Importer) run() { ...@@ -80,8 +85,10 @@ func (imp *Importer) run() {
for { for {
*hdr = header{} *hdr = header{}
if e := imp.decode(hdrValue); e != nil { if e := imp.decode(hdrValue); e != nil {
impLog("header:", e) if e != os.EOF {
imp.shutdown() impLog("header:", e)
imp.shutdown()
}
return return
} }
switch hdr.PayloadType { switch hdr.PayloadType {
...@@ -114,6 +121,9 @@ func (imp *Importer) run() { ...@@ -114,6 +121,9 @@ func (imp *Importer) run() {
nch := imp.getChan(hdr.Id, true) nch := imp.getChan(hdr.Id, true)
if nch != nil { if nch != nil {
nch.acked() nch.acked()
imp.mu.Lock()
imp.unacked--
imp.mu.Unlock()
} }
continue continue
default: default:
...@@ -220,10 +230,17 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size, ...@@ -220,10 +230,17 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size,
} }
return return
} }
// We hold the lock during transmission to guarantee messages are
// sent in order.
imp.mu.Lock()
imp.unacked++
imp.seqLock.Lock()
imp.mu.Unlock()
if err = imp.encode(hdr, payData, val.Interface()); err != nil { if err = imp.encode(hdr, payData, val.Interface()); err != nil {
impLog("error encoding client send:", err) impLog("error encoding client send:", err)
return return
} }
imp.seqLock.Unlock()
} }
}() }()
} }
...@@ -244,3 +261,27 @@ func (imp *Importer) Hangup(name string) os.Error { ...@@ -244,3 +261,27 @@ func (imp *Importer) Hangup(name string) os.Error {
nc.close() nc.close()
return nil return nil
} }
func (imp *Importer) unackedCount() int64 {
imp.mu.Lock()
n := imp.unacked
imp.mu.Unlock()
return n
}
// Drain waits until all messages sent from this exporter/importer, including
// those not yet sent to any server and possibly including those sent while
// Drain was executing, have been received by the exporter. In short, it
// waits until all the importer's messages have been received.
// If the timeout (measured in nanoseconds) is positive and Drain takes
// longer than that to complete, an error is returned.
func (imp *Importer) Drain(timeout int64) os.Error {
startTime := time.Nanoseconds()
for imp.unackedCount() > 0 {
if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
return os.ErrorString("timeout")
}
time.Sleep(100 * 1e6)
}
return nil
}
...@@ -178,6 +178,16 @@ func TestExportDrain(t *testing.T) { ...@@ -178,6 +178,16 @@ func TestExportDrain(t *testing.T) {
<-done <-done
} }
// Not a great test but it does at least invoke Drain.
func TestImportDrain(t *testing.T) {
exp, imp := pair(t)
expDone := make(chan bool)
go exportReceive(exp, t, expDone)
<-expDone
importSend(imp, closeCount, t, nil)
imp.Drain(0)
}
// Not a great test but it does at least invoke Sync. // Not a great test but it does at least invoke Sync.
func TestExportSync(t *testing.T) { func TestExportSync(t *testing.T) {
exp, imp := pair(t) exp, imp := pair(t)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment