Commit 8e69d43b authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net: add Buffers type, do writev on unix

No fast path currently for solaris, windows, nacl, plan9.

Fixes #13451

Change-Id: I24b3233a2e3a57fc6445e276a5c0d7b097884007
Reviewed-on: https://go-review.googlesource.com/29951Reviewed-by: default avatarIan Lance Taylor <iant@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent ffd1c781
......@@ -29,6 +29,9 @@ type netFD struct {
laddr Addr
raddr Addr
// writev cache.
iovecs *[]syscall.Iovec
// wait server
pd pollDesc
}
......
......@@ -604,3 +604,66 @@ func acquireThread() {
func releaseThread() {
<-threadLimit
}
// buffersWriter is the interface implemented by Conns that support a
// "writev"-like batch write optimization.
// writeBuffers should fully consume and write all chunks from the
// provided Buffers, else it should report a non-nil error.
type buffersWriter interface {
writeBuffers(*Buffers) (int64, error)
}
var testHookDidWritev = func(wrote int) {}
// Buffers contains zero or more runs of bytes to write.
//
// On certain machines, for certain types of connections, this is
// optimized into an OS-specific batch write operation (such as
// "writev").
type Buffers [][]byte
var (
_ io.WriterTo = (*Buffers)(nil)
_ io.Reader = (*Buffers)(nil)
)
func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) {
if wv, ok := w.(buffersWriter); ok {
return wv.writeBuffers(v)
}
for _, b := range *v {
nb, err := w.Write(b)
n += int64(nb)
if err != nil {
v.consume(n)
return n, err
}
}
v.consume(n)
return n, nil
}
func (v *Buffers) Read(p []byte) (n int, err error) {
for len(p) > 0 && len(*v) > 0 {
n0 := copy(p, (*v)[0])
v.consume(int64(n0))
p = p[n0:]
n += n0
}
if len(*v) == 0 {
err = io.EOF
}
return
}
func (v *Buffers) consume(n int64) {
for len(*v) > 0 {
ln0 := int64(len((*v)[0]))
if ln0 > n {
(*v)[0] = (*v)[0][n:]
return
}
n -= ln0
*v = (*v)[1:]
}
}
......@@ -414,3 +414,38 @@ func TestZeroByteRead(t *testing.T) {
}
}
}
// withTCPConnPair sets up a TCP connection between two peers, then
// runs peer1 and peer2 concurrently. withTCPConnPair returns when
// both have completed.
func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
ln, err := newLocalListener("tcp")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
errc := make(chan error, 2)
go func() {
c1, err := ln.Accept()
if err != nil {
errc <- err
return
}
defer c1.Close()
errc <- peer1(c1.(*TCPConn))
}()
go func() {
c2, err := Dial("tcp", ln.Addr().String())
if err != nil {
errc <- err
return
}
defer c2.Close()
errc <- peer2(c2.(*TCPConn))
}()
for i := 0; i < 2; i++ {
if err := <-errc; err != nil {
t.Fatal(err)
}
}
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"reflect"
"runtime"
"sync"
"testing"
)
func TestBuffers_read(t *testing.T) {
const story = "once upon a time in Gopherland ... "
buffers := Buffers{
[]byte("once "),
[]byte("upon "),
[]byte("a "),
[]byte("time "),
[]byte("in "),
[]byte("Gopherland ... "),
}
got, err := ioutil.ReadAll(&buffers)
if err != nil {
t.Fatal(err)
}
if string(got) != story {
t.Errorf("read %q; want %q", got, story)
}
if len(buffers) != 0 {
t.Errorf("len(buffers) = %d; want 0", len(buffers))
}
}
func TestBuffers_consume(t *testing.T) {
tests := []struct {
in Buffers
consume int64
want Buffers
}{
{
in: Buffers{[]byte("foo"), []byte("bar")},
consume: 0,
want: Buffers{[]byte("foo"), []byte("bar")},
},
{
in: Buffers{[]byte("foo"), []byte("bar")},
consume: 2,
want: Buffers{[]byte("o"), []byte("bar")},
},
{
in: Buffers{[]byte("foo"), []byte("bar")},
consume: 3,
want: Buffers{[]byte("bar")},
},
{
in: Buffers{[]byte("foo"), []byte("bar")},
consume: 4,
want: Buffers{[]byte("ar")},
},
{
in: Buffers{nil, nil, nil, []byte("bar")},
consume: 1,
want: Buffers{[]byte("ar")},
},
{
in: Buffers{nil, nil, nil, []byte("foo")},
consume: 0,
want: Buffers{[]byte("foo")},
},
{
in: Buffers{nil, nil, nil},
consume: 0,
want: Buffers{},
},
}
for i, tt := range tests {
in := tt.in
in.consume(tt.consume)
if !reflect.DeepEqual(in, tt.want) {
t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
}
}
}
func TestBuffers_WriteTo(t *testing.T) {
for _, name := range []string{"WriteTo", "Copy"} {
for _, size := range []int{0, 10, 1023, 1024, 1025} {
t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) {
testBuffer_writeTo(t, size, name == "Copy")
})
}
}
}
func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
oldHook := testHookDidWritev
defer func() { testHookDidWritev = oldHook }()
var writeLog struct {
sync.Mutex
log []int
}
testHookDidWritev = func(size int) {
writeLog.Lock()
writeLog.log = append(writeLog.log, size)
writeLog.Unlock()
}
var want bytes.Buffer
for i := 0; i < chunks; i++ {
want.WriteByte(byte(i))
}
withTCPConnPair(t, func(c *TCPConn) error {
buffers := make(Buffers, chunks)
for i := range buffers {
buffers[i] = want.Bytes()[i : i+1]
}
var n int64
var err error
if useCopy {
n, err = io.Copy(c, &buffers)
} else {
n, err = buffers.WriteTo(c)
}
if err != nil {
return err
}
if len(buffers) != 0 {
return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
}
if n != int64(want.Len()) {
return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
}
return nil
}, func(c *TCPConn) error {
all, err := ioutil.ReadAll(c)
if !bytes.Equal(all, want.Bytes()) || err != nil {
return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
}
writeLog.Lock() // no need to unlock
var gotSum int
for _, v := range writeLog.log {
gotSum += v
}
var wantSum int
var wantMinCalls int
switch runtime.GOOS {
case "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd":
wantSum = want.Len()
v := chunks
for v > 0 {
wantMinCalls++
v -= 1024
}
}
if len(writeLog.log) < wantMinCalls {
t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
}
if gotSum != wantSum {
t.Errorf("writev call sum = %v; want %v", gotSum, wantSum)
}
return nil
})
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd
package net
import (
"io"
"os"
"syscall"
"unsafe"
)
func (c *conn) writeBuffers(v *Buffers) (int64, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.fd.writeBuffers(v)
if err != nil {
return n, &OpError{Op: "writev", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, nil
}
func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) {
if err := fd.writeLock(); err != nil {
return 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(); err != nil {
return 0, err
}
var iovecs []syscall.Iovec
if fd.iovecs != nil {
iovecs = *fd.iovecs
}
// TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is
// 1024 and this seems conservative enough for now. Darwin's
// UIO_MAXIOV also seems to be 1024.
maxVec := 1024
for len(*v) > 0 {
iovecs = iovecs[:0]
for _, chunk := range *v {
if len(chunk) == 0 {
continue
}
iovecs = append(iovecs, syscall.Iovec{Base: &chunk[0]})
iovecs[len(iovecs)-1].SetLen(len(chunk))
if len(iovecs) == maxVec {
break
}
}
if len(iovecs) == 0 {
break
}
fd.iovecs = &iovecs // cache
wrote, _, e0 := syscall.Syscall(syscall.SYS_WRITEV,
uintptr(fd.sysfd),
uintptr(unsafe.Pointer(&iovecs[0])),
uintptr(len(iovecs)))
if wrote < 0 {
wrote = 0
}
testHookDidWritev(int(wrote))
n += int64(wrote)
v.consume(int64(wrote))
if e0 == syscall.EAGAIN {
if err = fd.pd.waitWrite(); err == nil {
continue
}
} else if e0 != 0 {
err = syscall.Errno(e0)
}
if err != nil {
break
}
if n == 0 {
err = io.ErrUnexpectedEOF
break
}
}
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("writev", err)
}
return n, err
}
......@@ -296,3 +296,5 @@ func RouteRIB(facility, param int) ([]byte, error) { return nil,
func ParseRoutingMessage(b []byte) ([]RoutingMessage, error) { return nil, ENOSYS }
func ParseRoutingSockaddr(msg RoutingMessage) ([]Sockaddr, error) { return nil, ENOSYS }
func SysctlUint32(name string) (value uint32, err error) { return 0, ENOSYS }
type Iovec struct{} // dummy
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment