Commit f45eb9ff authored by Jordi Martin's avatar Jordi Martin Committed by Brad Fitzpatrick

io: add error check on pipe close functions to avoid error overwriting

The current implementation allows multiple calls `Close` and `CloseWithError` in every side of the pipe, as a result, the original error can be overwritten.

This CL fixes this behavior adding an error existence check on `atomicError` type
and keeping the first error still available.

Fixes #24283

Change-Id: Iefe8f758aeb775309424365f8177511062514150
GitHub-Last-Rev: b559540d7af3a0dad423816b695525ac2d6bd864
GitHub-Pull-Request: golang/go#33239
Reviewed-on: https://go-review.googlesource.com/c/go/+/187197Reviewed-by: default avatarJoe Tsai <thebrokentoaster@gmail.com>
Run-TryBot: Joe Tsai <thebrokentoaster@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 89865f8b
...@@ -10,19 +10,26 @@ package io ...@@ -10,19 +10,26 @@ package io
import ( import (
"errors" "errors"
"sync" "sync"
"sync/atomic"
) )
// atomicError is a type-safe atomic value for errors. // onceError is an object that will only store an error once.
// We use a struct{ error } to ensure consistent use of a concrete type. type onceError struct {
type atomicError struct{ v atomic.Value } sync.Mutex // guards following
err error
}
func (a *atomicError) Store(err error) { func (a *onceError) Store(err error) {
a.v.Store(struct{ error }{err}) a.Lock()
defer a.Unlock()
if a.err != nil {
return
}
a.err = err
} }
func (a *atomicError) Load() error { func (a *onceError) Load() error {
err, _ := a.v.Load().(struct{ error }) a.Lock()
return err.error defer a.Unlock()
return a.err
} }
// ErrClosedPipe is the error used for read or write operations on a closed pipe. // ErrClosedPipe is the error used for read or write operations on a closed pipe.
...@@ -36,8 +43,8 @@ type pipe struct { ...@@ -36,8 +43,8 @@ type pipe struct {
once sync.Once // Protects closing done once sync.Once // Protects closing done
done chan struct{} done chan struct{}
rerr atomicError rerr onceError
werr atomicError werr onceError
} }
func (p *pipe) Read(b []byte) (n int, err error) { func (p *pipe) Read(b []byte) (n int, err error) {
...@@ -135,6 +142,9 @@ func (r *PipeReader) Close() error { ...@@ -135,6 +142,9 @@ func (r *PipeReader) Close() error {
// CloseWithError closes the reader; subsequent writes // CloseWithError closes the reader; subsequent writes
// to the write half of the pipe will return the error err. // to the write half of the pipe will return the error err.
//
// CloseWithError never overwrites the previous error if it exists
// and always returns nil.
func (r *PipeReader) CloseWithError(err error) error { func (r *PipeReader) CloseWithError(err error) error {
return r.p.CloseRead(err) return r.p.CloseRead(err)
} }
...@@ -163,7 +173,8 @@ func (w *PipeWriter) Close() error { ...@@ -163,7 +173,8 @@ func (w *PipeWriter) Close() error {
// read half of the pipe will return no bytes and the error err, // read half of the pipe will return no bytes and the error err,
// or EOF if err is nil. // or EOF if err is nil.
// //
// CloseWithError always returns nil. // CloseWithError never overwrites the previous error if it exists
// and always returns nil.
func (w *PipeWriter) CloseWithError(err error) error { func (w *PipeWriter) CloseWithError(err error) error {
return w.p.CloseWrite(err) return w.p.CloseWrite(err)
} }
......
...@@ -326,8 +326,8 @@ func TestPipeCloseError(t *testing.T) { ...@@ -326,8 +326,8 @@ func TestPipeCloseError(t *testing.T) {
t.Errorf("Write error: got %T, want testError1", err) t.Errorf("Write error: got %T, want testError1", err)
} }
r.CloseWithError(testError2{}) r.CloseWithError(testError2{})
if _, err := w.Write(nil); err != (testError2{}) { if _, err := w.Write(nil); err != (testError1{}) {
t.Errorf("Write error: got %T, want testError2", err) t.Errorf("Write error: got %T, want testError1", err)
} }
r, w = Pipe() r, w = Pipe()
...@@ -336,8 +336,8 @@ func TestPipeCloseError(t *testing.T) { ...@@ -336,8 +336,8 @@ func TestPipeCloseError(t *testing.T) {
t.Errorf("Read error: got %T, want testError1", err) t.Errorf("Read error: got %T, want testError1", err)
} }
w.CloseWithError(testError2{}) w.CloseWithError(testError2{})
if _, err := r.Read(nil); err != (testError2{}) { if _, err := r.Read(nil); err != (testError1{}) {
t.Errorf("Read error: got %T, want testError2", err) t.Errorf("Read error: got %T, want testError1", err)
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment