Commit e5b13401 authored by Minaev Mike's avatar Minaev Mike Committed by Filippo Valsorda

crypto/tls: fix deadlock when Read and Close called concurrently

The existing implementation of TLS connection has a deadlock. It occurs
when client connects to TLS server and doesn't send data for
handshake, so server calls Close on this connection. This is because
server reads data under locked mutex, while Close method tries to
lock the same mutex.

Fixes #23518

Change-Id: I4fb0a2a770f3d911036bfd9a7da7cc41c1b27e19
Reviewed-on: https://go-review.googlesource.com/90155
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarFilippo Valsorda <filippo@golang.org>
parent 7b46867d
...@@ -27,15 +27,16 @@ type Conn struct { ...@@ -27,15 +27,16 @@ type Conn struct {
conn net.Conn conn net.Conn
isClient bool isClient bool
// handshakeStatus is 1 if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
// This field is only to be accessed with sync/atomic.
handshakeStatus uint32
// constant after handshake; protected by handshakeMutex // constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex handshakeMutex sync.Mutex
handshakeErr error // error resulting from handshake handshakeErr error // error resulting from handshake
vers uint16 // TLS version vers uint16 // TLS version
haveVers bool // version has been negotiated haveVers bool // version has been negotiated
config *Config // configuration passed to constructor config *Config // configuration passed to constructor
// handshakeComplete is true if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
handshakeComplete bool
// handshakes counts the number of handshakes performed on the // handshakes counts the number of handshakes performed on the
// connection so far. If renegotiation is disabled then this is either // connection so far. If renegotiation is disabled then this is either
// zero or one. // zero or one.
...@@ -571,12 +572,12 @@ func (c *Conn) readRecord(want recordType) error { ...@@ -571,12 +572,12 @@ func (c *Conn) readRecord(want recordType) error {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: unknown record type requested")) return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec: case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.handshakeComplete { if c.handshakeComplete() {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake")) return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
} }
case recordTypeApplicationData: case recordTypeApplicationData:
if !c.handshakeComplete { if !c.handshakeComplete() {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake")) return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
} }
...@@ -1048,7 +1049,7 @@ func (c *Conn) Write(b []byte) (int, error) { ...@@ -1048,7 +1049,7 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, err return 0, err
} }
if !c.handshakeComplete { if !c.handshakeComplete() {
return 0, alertInternalError return 0, alertInternalError
} }
...@@ -1114,7 +1115,7 @@ func (c *Conn) handleRenegotiation() error { ...@@ -1114,7 +1115,7 @@ func (c *Conn) handleRenegotiation() error {
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
c.handshakeComplete = false atomic.StoreUint32(&c.handshakeStatus, 0)
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
c.handshakes++ c.handshakes++
} }
...@@ -1215,11 +1216,9 @@ func (c *Conn) Close() error { ...@@ -1215,11 +1216,9 @@ func (c *Conn) Close() error {
var alertErr error var alertErr error
c.handshakeMutex.Lock() if c.handshakeComplete() {
if c.handshakeComplete {
alertErr = c.closeNotify() alertErr = c.closeNotify()
} }
c.handshakeMutex.Unlock()
if err := c.conn.Close(); err != nil { if err := c.conn.Close(); err != nil {
return err return err
...@@ -1233,9 +1232,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com ...@@ -1233,9 +1232,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
// called once the handshake has completed and does not call CloseWrite on the // called once the handshake has completed and does not call CloseWrite on the
// underlying connection. Most callers should just use Close. // underlying connection. Most callers should just use Close.
func (c *Conn) CloseWrite() error { func (c *Conn) CloseWrite() error {
c.handshakeMutex.Lock() if !c.handshakeComplete() {
defer c.handshakeMutex.Unlock()
if !c.handshakeComplete {
return errEarlyCloseWrite return errEarlyCloseWrite
} }
...@@ -1264,7 +1261,7 @@ func (c *Conn) Handshake() error { ...@@ -1264,7 +1261,7 @@ func (c *Conn) Handshake() error {
if err := c.handshakeErr; err != nil { if err := c.handshakeErr; err != nil {
return err return err
} }
if c.handshakeComplete { if c.handshakeComplete() {
return nil return nil
} }
...@@ -1284,7 +1281,7 @@ func (c *Conn) Handshake() error { ...@@ -1284,7 +1281,7 @@ func (c *Conn) Handshake() error {
c.flush() c.flush()
} }
if c.handshakeErr == nil && !c.handshakeComplete { if c.handshakeErr == nil && !c.handshakeComplete() {
panic("handshake should have had a result.") panic("handshake should have had a result.")
} }
...@@ -1297,10 +1294,10 @@ func (c *Conn) ConnectionState() ConnectionState { ...@@ -1297,10 +1294,10 @@ func (c *Conn) ConnectionState() ConnectionState {
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
var state ConnectionState var state ConnectionState
state.HandshakeComplete = c.handshakeComplete state.HandshakeComplete = c.handshakeComplete()
state.ServerName = c.serverName state.ServerName = c.serverName
if c.handshakeComplete { if state.HandshakeComplete {
state.Version = c.vers state.Version = c.vers
state.NegotiatedProtocol = c.clientProtocol state.NegotiatedProtocol = c.clientProtocol
state.DidResume = c.didResume state.DidResume = c.didResume
...@@ -1345,7 +1342,7 @@ func (c *Conn) VerifyHostname(host string) error { ...@@ -1345,7 +1342,7 @@ func (c *Conn) VerifyHostname(host string) error {
if !c.isClient { if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection") return errors.New("tls: VerifyHostname called on TLS server connection")
} }
if !c.handshakeComplete { if !c.handshakeComplete() {
return errors.New("tls: handshake has not yet been performed") return errors.New("tls: handshake has not yet been performed")
} }
if len(c.verifiedChains) == 0 { if len(c.verifiedChains) == 0 {
...@@ -1353,3 +1350,7 @@ func (c *Conn) VerifyHostname(host string) error { ...@@ -1353,3 +1350,7 @@ func (c *Conn) VerifyHostname(host string) error {
} }
return c.peerCertificates[0].VerifyHostname(host) return c.peerCertificates[0].VerifyHostname(host)
} }
func (c *Conn) handshakeComplete() bool {
return atomic.LoadUint32(&c.handshakeStatus) == 1
}
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
) )
type clientHandshakeState struct { type clientHandshakeState struct {
...@@ -266,7 +267,7 @@ func (hs *clientHandshakeState) handshake() error { ...@@ -266,7 +267,7 @@ func (hs *clientHandshakeState) handshake() error {
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
c.didResume = isResume c.didResume = isResume
c.handshakeComplete = true atomic.StoreUint32(&c.handshakeStatus, 1)
return nil return nil
} }
......
...@@ -1617,3 +1617,22 @@ RwBA9Xk1KBNF ...@@ -1617,3 +1617,22 @@ RwBA9Xk1KBNF
t.Error("A RSA-PSS certificate was parsed like a PKCS1 one, and it will be mistakenly used with rsa_pss_rsae_xxx signature algorithms") t.Error("A RSA-PSS certificate was parsed like a PKCS1 one, and it will be mistakenly used with rsa_pss_rsae_xxx signature algorithms")
} }
} }
func TestCloseClientConnectionOnIdleServer(t *testing.T) {
clientConn, serverConn := net.Pipe()
client := Client(clientConn, testConfig.Clone())
go func() {
var b [1]byte
serverConn.Read(b[:])
client.Close()
}()
client.SetWriteDeadline(time.Now().Add(time.Second))
err := client.Handshake()
if err != nil {
if !strings.Contains(err.Error(), "read/write on closed pipe") {
t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error())
}
} else {
t.Errorf("Error expected, but no error returned")
}
}
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"sync/atomic"
) )
// serverHandshakeState contains details of a server handshake in progress. // serverHandshakeState contains details of a server handshake in progress.
...@@ -103,7 +104,7 @@ func (c *Conn) serverHandshake() error { ...@@ -103,7 +104,7 @@ func (c *Conn) serverHandshake() error {
} }
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
c.handshakeComplete = true atomic.StoreUint32(&c.handshakeStatus, 1)
return nil return nil
} }
......
...@@ -1403,3 +1403,21 @@ var testECDSAPrivateKey = &ecdsa.PrivateKey{ ...@@ -1403,3 +1403,21 @@ var testECDSAPrivateKey = &ecdsa.PrivateKey{
} }
var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75")) var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75"))
func TestCloseServerConnectionOnIdleClient(t *testing.T) {
clientConn, serverConn := net.Pipe()
server := Server(serverConn, testConfig.Clone())
go func() {
clientConn.Write([]byte{'0'})
server.Close()
}()
server.SetReadDeadline(time.Now().Add(time.Second))
err := server.Handshake()
if err != nil {
if !strings.Contains(err.Error(), "read/write on closed pipe") {
t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error())
}
} else {
t.Errorf("Error expected, but no error returned")
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment