Commit c4807f6a authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: don't ignore errors in Request.Write

LGTM=josharian, adg
R=golang-codereviews, josharian, adg
CC=golang-codereviews
https://golang.org/cl/119110043
parent 8cb04077
......@@ -390,10 +390,16 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
w = bw
}
fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
_, err := fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
if err != nil {
return err
}
// Header lines
fmt.Fprintf(w, "Host: %s\r\n", host)
_, err = fmt.Fprintf(w, "Host: %s\r\n", host)
if err != nil {
return err
}
// Use the defaultUserAgent unless the Header contains one, which
// may be blank to not send the header.
......@@ -404,7 +410,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
}
}
if userAgent != "" {
fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
_, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
if err != nil {
return err
}
}
// Process Body,ContentLength,Close,Trailer
......@@ -429,7 +438,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
}
}
io.WriteString(w, "\r\n")
_, err = io.WriteString(w, "\r\n")
if err != nil {
return err
}
// Write body and trailer
err = tw.WriteBody(w)
......
......@@ -563,3 +563,61 @@ func mustParseURL(s string) *url.URL {
}
return u
}
type writerFunc func([]byte) (int, error)
func (f writerFunc) Write(p []byte) (int, error) { return f(p) }
// TestRequestWriteError tests the Write err != nil checks in (*Request).write.
func TestRequestWriteError(t *testing.T) {
failAfter, writeCount := 0, 0
errFail := errors.New("fake write failure")
// w is the buffered io.Writer to write the request to. It
// fails exactly once on its Nth Write call, as controlled by
// failAfter. It also tracks the number of calls in
// writeCount.
w := struct {
io.ByteWriter // to avoid being wrapped by a bufio.Writer
io.Writer
}{
nil,
writerFunc(func(p []byte) (n int, err error) {
writeCount++
if failAfter == 0 {
err = errFail
}
failAfter--
return len(p), err
}),
}
req, _ := NewRequest("GET", "http://example.com/", nil)
const writeCalls = 4 // number of Write calls in current implementation
sawGood := false
for n := 0; n <= writeCalls+2; n++ {
failAfter = n
writeCount = 0
err := req.Write(w)
var wantErr error
if n < writeCalls {
wantErr = errFail
}
if err != wantErr {
t.Errorf("for fail-after %d Writes, err = %v; want %v", n, err, wantErr)
continue
}
if err == nil {
sawGood = true
if writeCount != writeCalls {
t.Fatalf("writeCalls constant is outdated in test")
}
}
if writeCount > writeCalls || writeCount > n+1 {
t.Errorf("for fail-after %d, saw unexpectedly high (%d) write calls", n, writeCount)
}
}
if !sawGood {
t.Fatalf("writeCalls constant is outdated in test")
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment