diff --git a/src/io/pipe.go b/src/io/pipe.go index 544481e1b97a97a15cd671948cf666ea1c045fce..4efaf2f8e481381360eff85756be70baaf3a545c 100644 --- a/src/io/pipe.go +++ b/src/io/pipe.go @@ -13,6 +13,18 @@ import ( "sync/atomic" ) +// atomicError is a type-safe atomic value for errors. +// We use a struct{ error } to ensure consistent use of a concrete type. +type atomicError struct{ v atomic.Value } + +func (a *atomicError) Store(err error) { + a.v.Store(struct{ error }{err}) +} +func (a *atomicError) Load() error { + err, _ := a.v.Load().(struct{ error }) + return err.error +} + // ErrClosedPipe is the error used for read or write operations on a closed pipe. var ErrClosedPipe = errors.New("io: read/write on closed pipe") @@ -24,8 +36,8 @@ type pipe struct { once sync.Once // Protects closing done done chan struct{} - rerr atomic.Value - werr atomic.Value + rerr atomicError + werr atomicError } func (p *pipe) Read(b []byte) (n int, err error) { @@ -46,8 +58,8 @@ func (p *pipe) Read(b []byte) (n int, err error) { } func (p *pipe) readCloseError() error { - _, rok := p.rerr.Load().(error) - if werr, wok := p.werr.Load().(error); !rok && wok { + rerr := p.rerr.Load() + if werr := p.werr.Load(); rerr == nil && werr != nil { return werr } return ErrClosedPipe @@ -85,8 +97,8 @@ func (p *pipe) Write(b []byte) (n int, err error) { } func (p *pipe) writeCloseError() error { - _, wok := p.werr.Load().(error) - if rerr, rok := p.rerr.Load().(error); !wok && rok { + werr := p.werr.Load() + if rerr := p.rerr.Load(); werr == nil && rerr != nil { return rerr } return ErrClosedPipe diff --git a/src/io/pipe_test.go b/src/io/pipe_test.go index 2bf95f03e30d9f259941cf0086172c337b9d7f47..f18b1c45f8b5c553df74176402b9fe0604baeb4a 100644 --- a/src/io/pipe_test.go +++ b/src/io/pipe_test.go @@ -316,6 +316,31 @@ func TestWriteAfterWriterClose(t *testing.T) { } } +func TestPipeCloseError(t *testing.T) { + type testError1 struct{ error } + type testError2 struct{ error } + + r, w := Pipe() + r.CloseWithError(testError1{}) + if _, err := w.Write(nil); err != (testError1{}) { + t.Errorf("Write error: got %T, want testError1", err) + } + r.CloseWithError(testError2{}) + if _, err := w.Write(nil); err != (testError2{}) { + t.Errorf("Write error: got %T, want testError2", err) + } + + r, w = Pipe() + w.CloseWithError(testError1{}) + if _, err := r.Read(nil); err != (testError1{}) { + t.Errorf("Read error: got %T, want testError1", err) + } + w.CloseWithError(testError2{}) + if _, err := r.Read(nil); err != (testError2{}) { + t.Errorf("Read error: got %T, want testError2", err) + } +} + func TestPipeConcurrent(t *testing.T) { const ( input = "0123456789abcdef"