Commit 37c28759 authored by Tamir Duberstein's avatar Tamir Duberstein Committed by Adam Langley

crypto/tls: check errors from (*Conn).writeRecord

This promotes a connection hang during TLS handshake to a proper error.
This doesn't fully address #14539 because the error reported in that
case is a write-on-socket-not-connected error, which implies that an
earlier error during connection setup is not being checked, but it is
an improvement over the current behaviour.

Updates #14539.

Change-Id: I0571a752d32d5303db48149ab448226868b19495
Reviewed-on: https://go-review.googlesource.com/19990Reviewed-by: default avatarAdam Langley <agl@golang.org>
parent 1012892f
...@@ -694,12 +694,14 @@ func (c *Conn) sendAlertLocked(err alert) error { ...@@ -694,12 +694,14 @@ func (c *Conn) sendAlertLocked(err alert) error {
c.tmp[0] = alertLevelError c.tmp[0] = alertLevelError
} }
c.tmp[1] = byte(err) c.tmp[1] = byte(err)
c.writeRecord(recordTypeAlert, c.tmp[0:2])
// closeNotify is a special case in that it isn't an error: _, writeErr := c.writeRecord(recordTypeAlert, c.tmp[0:2])
if err != alertCloseNotify { if err == alertCloseNotify {
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) // closeNotify is a special case in that it isn't an error.
return writeErr
} }
return nil
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
} }
// sendAlert sends a TLS alert message. // sendAlert sends a TLS alert message.
...@@ -713,8 +715,11 @@ func (c *Conn) sendAlert(err alert) error { ...@@ -713,8 +715,11 @@ func (c *Conn) sendAlert(err alert) error {
// writeRecord writes a TLS record with the given type and payload // writeRecord writes a TLS record with the given type and payload
// to the connection and updates the record layer state. // to the connection and updates the record layer state.
// c.out.Mutex <= L. // c.out.Mutex <= L.
func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
b := c.out.newBlock() b := c.out.newBlock()
defer c.out.freeBlock(b)
var n int
for len(data) > 0 { for len(data) > 0 {
m := len(data) m := len(data)
if m > maxPlaintext { if m > maxPlaintext {
...@@ -759,34 +764,27 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { ...@@ -759,34 +764,27 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
if explicitIVIsSeq { if explicitIVIsSeq {
copy(explicitIV, c.out.seq[:]) copy(explicitIV, c.out.seq[:])
} else { } else {
if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil { if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
break return n, err
} }
} }
} }
copy(b.data[recordHeaderLen+explicitIVLen:], data) copy(b.data[recordHeaderLen+explicitIVLen:], data)
c.out.encrypt(b, explicitIVLen) c.out.encrypt(b, explicitIVLen)
_, err = c.conn.Write(b.data) if _, err := c.conn.Write(b.data); err != nil {
if err != nil { return n, err
break
} }
n += m n += m
data = data[m:] data = data[m:]
} }
c.out.freeBlock(b)
if typ == recordTypeChangeCipherSpec { if typ == recordTypeChangeCipherSpec {
err = c.out.changeCipherSpec() if err := c.out.changeCipherSpec(); err != nil {
if err != nil { return n, c.sendAlertLocked(err.(alert))
// Cannot call sendAlert directly,
// because we already hold c.out.Mutex.
c.tmp[0] = alertLevelError
c.tmp[1] = byte(err.(alert))
c.writeRecord(recordTypeAlert, c.tmp[0:2])
return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
} }
} }
return
return n, nil
} }
// readHandshake reads the next handshake message from // readHandshake reads the next handshake message from
......
...@@ -138,7 +138,9 @@ NextCipherSuite: ...@@ -138,7 +138,9 @@ NextCipherSuite:
} }
} }
c.writeRecord(recordTypeHandshake, hello.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
return err
}
msg, err := c.readHandshake() msg, err := c.readHandshake()
if err != nil { if err != nil {
...@@ -419,7 +421,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { ...@@ -419,7 +421,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
certMsg.certificates = chainToSend.Certificate certMsg.certificates = chainToSend.Certificate
} }
hs.finishedHash.Write(certMsg.marshal()) hs.finishedHash.Write(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}
} }
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0]) preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0])
...@@ -429,7 +433,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { ...@@ -429,7 +433,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
} }
if ckx != nil { if ckx != nil {
hs.finishedHash.Write(ckx.marshal()) hs.finishedHash.Write(ckx.marshal())
c.writeRecord(recordTypeHandshake, ckx.marshal()) if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil {
return err
}
} }
if chainToSend != nil { if chainToSend != nil {
...@@ -471,7 +477,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { ...@@ -471,7 +477,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
} }
hs.finishedHash.Write(certVerify.marshal()) hs.finishedHash.Write(certVerify.marshal())
c.writeRecord(recordTypeHandshake, certVerify.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil {
return err
}
} }
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random)
...@@ -615,7 +623,9 @@ func (hs *clientHandshakeState) readSessionTicket() error { ...@@ -615,7 +623,9 @@ func (hs *clientHandshakeState) readSessionTicket() error {
func (hs *clientHandshakeState) sendFinished(out []byte) error { func (hs *clientHandshakeState) sendFinished(out []byte) error {
c := hs.c c := hs.c
c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
return err
}
if hs.serverHello.nextProtoNeg { if hs.serverHello.nextProtoNeg {
nextProto := new(nextProtoMsg) nextProto := new(nextProtoMsg)
proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos) proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
...@@ -624,13 +634,17 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error { ...@@ -624,13 +634,17 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error {
c.clientProtocolFallback = fallback c.clientProtocolFallback = fallback
hs.finishedHash.Write(nextProto.marshal()) hs.finishedHash.Write(nextProto.marshal())
c.writeRecord(recordTypeHandshake, nextProto.marshal()) if _, err := c.writeRecord(recordTypeHandshake, nextProto.marshal()); err != nil {
return err
}
} }
finished := new(finishedMsg) finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
hs.finishedHash.Write(finished.marshal()) hs.finishedHash.Write(finished.marshal())
c.writeRecord(recordTypeHandshake, finished.marshal()) if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
return err
}
copy(out, finished.verifyData) copy(out, finished.verifyData)
return nil return nil
} }
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
...@@ -725,3 +726,51 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { ...@@ -725,3 +726,51 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
} }
} }
// brokenConn wraps a net.Conn and causes all Writes after a certain number to
// fail with brokenConnErr.
type brokenConn struct {
net.Conn
// breakAfter is the number of successful writes that will be allowed
// before all subsequent writes fail.
breakAfter int
// numWrites is the number of writes that have been done.
numWrites int
}
// brokenConnErr is the error that brokenConn returns once exhausted.
var brokenConnErr = errors.New("too many writes to brokenConn")
func (b *brokenConn) Write(data []byte) (int, error) {
if b.numWrites >= b.breakAfter {
return 0, brokenConnErr
}
b.numWrites++
return b.Conn.Write(data)
}
func TestFailedWrite(t *testing.T) {
// Test that a write error during the handshake is returned.
for _, breakAfter := range []int{0, 1, 2, 3} {
c, s := net.Pipe()
done := make(chan bool)
go func() {
Server(s, testConfig).Handshake()
s.Close()
done <- true
}()
brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
err := Client(brokenC, testConfig).Handshake()
if err != brokenConnErr {
t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
}
brokenC.Close()
<-done
}
}
...@@ -322,7 +322,9 @@ func (hs *serverHandshakeState) doResumeHandshake() error { ...@@ -322,7 +322,9 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
hs.finishedHash.discardHandshakeBuffer() hs.finishedHash.discardHandshakeBuffer()
hs.finishedHash.Write(hs.clientHello.marshal()) hs.finishedHash.Write(hs.clientHello.marshal())
hs.finishedHash.Write(hs.hello.marshal()) hs.finishedHash.Write(hs.hello.marshal())
c.writeRecord(recordTypeHandshake, hs.hello.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
return err
}
if len(hs.sessionState.certificates) > 0 { if len(hs.sessionState.certificates) > 0 {
if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil { if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
...@@ -354,19 +356,25 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -354,19 +356,25 @@ func (hs *serverHandshakeState) doFullHandshake() error {
} }
hs.finishedHash.Write(hs.clientHello.marshal()) hs.finishedHash.Write(hs.clientHello.marshal())
hs.finishedHash.Write(hs.hello.marshal()) hs.finishedHash.Write(hs.hello.marshal())
c.writeRecord(recordTypeHandshake, hs.hello.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
return err
}
certMsg := new(certificateMsg) certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate certMsg.certificates = hs.cert.Certificate
hs.finishedHash.Write(certMsg.marshal()) hs.finishedHash.Write(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}
if hs.hello.ocspStapling { if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg) certStatus := new(certificateStatusMsg)
certStatus.statusType = statusTypeOCSP certStatus.statusType = statusTypeOCSP
certStatus.response = hs.cert.OCSPStaple certStatus.response = hs.cert.OCSPStaple
hs.finishedHash.Write(certStatus.marshal()) hs.finishedHash.Write(certStatus.marshal())
c.writeRecord(recordTypeHandshake, certStatus.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
return err
}
} }
keyAgreement := hs.suite.ka(c.vers) keyAgreement := hs.suite.ka(c.vers)
...@@ -377,7 +385,9 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -377,7 +385,9 @@ func (hs *serverHandshakeState) doFullHandshake() error {
} }
if skx != nil { if skx != nil {
hs.finishedHash.Write(skx.marshal()) hs.finishedHash.Write(skx.marshal())
c.writeRecord(recordTypeHandshake, skx.marshal()) if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil {
return err
}
} }
if config.ClientAuth >= RequestClientCert { if config.ClientAuth >= RequestClientCert {
...@@ -401,12 +411,16 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -401,12 +411,16 @@ func (hs *serverHandshakeState) doFullHandshake() error {
certReq.certificateAuthorities = config.ClientCAs.Subjects() certReq.certificateAuthorities = config.ClientCAs.Subjects()
} }
hs.finishedHash.Write(certReq.marshal()) hs.finishedHash.Write(certReq.marshal())
c.writeRecord(recordTypeHandshake, certReq.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
return err
}
} }
helloDone := new(serverHelloDoneMsg) helloDone := new(serverHelloDoneMsg)
hs.finishedHash.Write(helloDone.marshal()) hs.finishedHash.Write(helloDone.marshal())
c.writeRecord(recordTypeHandshake, helloDone.marshal()) if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil {
return err
}
var pub crypto.PublicKey // public key for client auth, if any var pub crypto.PublicKey // public key for client auth, if any
...@@ -632,7 +646,9 @@ func (hs *serverHandshakeState) sendSessionTicket() error { ...@@ -632,7 +646,9 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
} }
hs.finishedHash.Write(m.marshal()) hs.finishedHash.Write(m.marshal())
c.writeRecord(recordTypeHandshake, m.marshal()) if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
return err
}
return nil return nil
} }
...@@ -640,12 +656,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error { ...@@ -640,12 +656,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
func (hs *serverHandshakeState) sendFinished(out []byte) error { func (hs *serverHandshakeState) sendFinished(out []byte) error {
c := hs.c c := hs.c
c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
return err
}
finished := new(finishedMsg) finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
hs.finishedHash.Write(finished.marshal()) hs.finishedHash.Write(finished.marshal())
c.writeRecord(recordTypeHandshake, finished.marshal()) if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
return err
}
c.cipherSuite = hs.suite.id c.cipherSuite = hs.suite.id
copy(out, finished.verifyData) copy(out, finished.verifyData)
......
...@@ -80,7 +80,10 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa ...@@ -80,7 +80,10 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
cli.writeRecord(recordTypeHandshake, m.marshal()) cli.writeRecord(recordTypeHandshake, m.marshal())
c.Close() c.Close()
}() }()
err := Server(s, serverConfig).Handshake() hs := serverHandshakeState{
c: Server(s, serverConfig),
}
_, err := hs.readClientHello()
s.Close() s.Close()
if len(expectedSubStr) == 0 { if len(expectedSubStr) == 0 {
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment