Commit 72d93220 authored by Russ Cox's avatar Russ Cox

crypto/tls: simpler implementation of record layer

Depends on CL 957045, 980043, 1004043.
Fixes #715.

R=agl1, agl
CC=golang-dev
https://golang.org/cl/943043
parent 47a05334
...@@ -7,15 +7,13 @@ include ../../../Make.$(GOARCH) ...@@ -7,15 +7,13 @@ include ../../../Make.$(GOARCH)
TARG=crypto/tls TARG=crypto/tls
GOFILES=\ GOFILES=\
alert.go\ alert.go\
ca_set.go\
common.go\ common.go\
conn.go\
handshake_client.go\ handshake_client.go\
handshake_messages.go\ handshake_messages.go\
handshake_server.go\ handshake_server.go\
prf.go\ prf.go\
record_process.go\
record_read.go\
record_write.go\
ca_set.go\
tls.go\ tls.go\
include ../../../Make.pkg include ../../../Make.pkg
...@@ -4,40 +4,70 @@ ...@@ -4,40 +4,70 @@
package tls package tls
type alertLevel int import "strconv"
type alertType int
type alert uint8
const ( const (
alertLevelWarning alertLevel = 1 // alert level
alertLevelError alertLevel = 2 alertLevelWarning = 1
alertLevelError = 2
) )
const ( const (
alertCloseNotify alertType = 0 alertCloseNotify alert = 0
alertUnexpectedMessage alertType = 10 alertUnexpectedMessage alert = 10
alertBadRecordMAC alertType = 20 alertBadRecordMAC alert = 20
alertDecryptionFailed alertType = 21 alertDecryptionFailed alert = 21
alertRecordOverflow alertType = 22 alertRecordOverflow alert = 22
alertDecompressionFailure alertType = 30 alertDecompressionFailure alert = 30
alertHandshakeFailure alertType = 40 alertHandshakeFailure alert = 40
alertBadCertificate alertType = 42 alertBadCertificate alert = 42
alertUnsupportedCertificate alertType = 43 alertUnsupportedCertificate alert = 43
alertCertificateRevoked alertType = 44 alertCertificateRevoked alert = 44
alertCertificateExpired alertType = 45 alertCertificateExpired alert = 45
alertCertificateUnknown alertType = 46 alertCertificateUnknown alert = 46
alertIllegalParameter alertType = 47 alertIllegalParameter alert = 47
alertUnknownCA alertType = 48 alertUnknownCA alert = 48
alertAccessDenied alertType = 49 alertAccessDenied alert = 49
alertDecodeError alertType = 50 alertDecodeError alert = 50
alertDecryptError alertType = 51 alertDecryptError alert = 51
alertProtocolVersion alertType = 70 alertProtocolVersion alert = 70
alertInsufficientSecurity alertType = 71 alertInsufficientSecurity alert = 71
alertInternalError alertType = 80 alertInternalError alert = 80
alertUserCanceled alertType = 90 alertUserCanceled alert = 90
alertNoRenegotiation alertType = 100 alertNoRenegotiation alert = 100
) )
type alert struct { var alertText = map[alert]string{
level alertLevel alertCloseNotify: "close notify",
error alertType alertUnexpectedMessage: "unexpected message",
alertBadRecordMAC: "bad record MAC",
alertDecryptionFailed: "decryption failed",
alertRecordOverflow: "record overflow",
alertDecompressionFailure: "decompression failure",
alertHandshakeFailure: "handshake failure",
alertBadCertificate: "bad certificate",
alertUnsupportedCertificate: "unsupported certificate",
alertCertificateRevoked: "revoked certificate",
alertCertificateExpired: "expired certificate",
alertCertificateUnknown: "unknown certificate",
alertIllegalParameter: "illegal parameter",
alertUnknownCA: "unknown certificate authority",
alertAccessDenied: "access denied",
alertDecodeError: "error decoding message",
alertDecryptError: "error decrypting message",
alertProtocolVersion: "protocol version not supported",
alertInsufficientSecurity: "insufficient security level",
alertInternalError: "internal error",
alertUserCanceled: "user canceled",
alertNoRenegotiation: "no renegotiation",
}
func (e alert) String() string {
s, ok := alertText[e]
if ok {
return s
}
return "alert(" + strconv.Itoa(int(e)) + ")"
} }
...@@ -10,22 +10,18 @@ import ( ...@@ -10,22 +10,18 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"once" "once"
"os"
"time" "time"
) )
const ( const (
// maxTLSCiphertext is the maximum length of a plaintext payload. maxPlaintext = 16384 // maximum plaintext payload length
maxTLSPlaintext = 16384 maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
// maxTLSCiphertext is the maximum length payload after compression and encryption. recordHeaderLen = 5 // record header length
maxTLSCiphertext = 16384 + 2048 maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
// maxHandshakeMsg is the largest single handshake message that we'll buffer.
maxHandshakeMsg = 65536
// defaultMajor and defaultMinor are the maximum TLS version that we support.
defaultMajor = 3
defaultMinor = 2
)
minVersion = 0x0301 // minimum supported version - TLS 1.0
maxVersion = 0x0302 // maximum supported version - TLS 1.1
)
// TLS record types. // TLS record types.
type recordType uint8 type recordType uint8
...@@ -67,7 +63,7 @@ var ( ...@@ -67,7 +63,7 @@ var (
type ConnectionState struct { type ConnectionState struct {
HandshakeComplete bool HandshakeComplete bool
CipherSuite string CipherSuite string
Error alertType Error alert
NegotiatedProtocol string NegotiatedProtocol string
} }
...@@ -99,6 +95,7 @@ type record struct { ...@@ -99,6 +95,7 @@ type record struct {
type handshakeMessage interface { type handshakeMessage interface {
marshal() []byte marshal() []byte
unmarshal([]byte) bool
} }
type encryptor interface { type encryptor interface {
...@@ -108,34 +105,16 @@ type encryptor interface { ...@@ -108,34 +105,16 @@ type encryptor interface {
// mutualVersion returns the protocol version to use given the advertised // mutualVersion returns the protocol version to use given the advertised
// version of the peer. // version of the peer.
func mutualVersion(theirMajor, theirMinor uint8) (major, minor uint8, ok bool) { func mutualVersion(vers uint16) (uint16, bool) {
// We don't deal with peers < TLS 1.0 (aka version 3.1). if vers < minVersion {
if theirMajor < 3 || theirMajor == 3 && theirMinor < 1 { return 0, false
return 0, 0, false
} }
major = 3 if vers > maxVersion {
minor = 2 vers = maxVersion
if theirMinor < minor {
minor = theirMinor
} }
ok = true return vers, true
return
} }
// A nop implements the NULL encryption and MAC algorithms.
type nop struct{}
func (nop) XORKeyStream(buf []byte) {}
func (nop) Write(buf []byte) (int, os.Error) { return len(buf), nil }
func (nop) Sum() []byte { return nil }
func (nop) Reset() {}
func (nop) Size() int { return 0 }
// The defaultConfig is used in place of a nil *Config in the TLS server and client. // The defaultConfig is used in place of a nil *Config in the TLS server and client.
var varDefaultConfig *Config var varDefaultConfig *Config
......
// TLS low level connection and record layer
package tls
import (
"bytes"
"crypto/subtle"
"hash"
"io"
"net"
"os"
"sync"
)
// A Conn represents a secured connection.
// It implements the net.Conn interface.
type Conn struct {
// constant
conn net.Conn
isClient bool
// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
handshakeComplete bool
cipherSuite uint16
clientProtocol string
// first permanent error
errMutex sync.Mutex
err os.Error
// input/output
in, out halfConn // in.Mutex < out.Mutex
rawInput *block // raw input, right off the wire
input *block // application data waiting to be read
hand bytes.Buffer // handshake data waiting to be read
tmp [16]byte
}
func (c *Conn) setError(err os.Error) os.Error {
c.errMutex.Lock()
defer c.errMutex.Unlock()
if c.err == nil {
c.err = err
}
return err
}
func (c *Conn) error() os.Error {
c.errMutex.Lock()
defer c.errMutex.Unlock()
return c.err
}
// Access to net.Conn methods.
// Cannot just embed net.Conn because that would
// export the struct field too.
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetTimeout sets the read deadline associated with the connection.
// There is no write deadline.
func (c *Conn) SetTimeout(nsec int64) os.Error {
return c.conn.SetTimeout(nsec)
}
// SetReadTimeout sets the time (in nanoseconds) that
// Read will wait for data before returning os.EAGAIN.
// Setting nsec == 0 (the default) disables the deadline.
func (c *Conn) SetReadTimeout(nsec int64) os.Error {
return c.conn.SetReadTimeout(nsec)
}
// SetWriteTimeout exists to satisfy the net.Conn interface
// but is not implemented by TLS. It always returns an error.
func (c *Conn) SetWriteTimeout(nsec int64) os.Error {
return os.NewError("TLS does not support SetWriteTimeout")
}
// A halfConn represents one direction of the record layer
// connection, either sending or receiving.
type halfConn struct {
sync.Mutex
crypt encryptor // encryption state
mac hash.Hash // MAC algorithm
seq [8]byte // 64-bit sequence number
bfree *block // list of free blocks
nextCrypt encryptor // next encryption state
nextMac hash.Hash // next MAC algorithm
}
// prepareCipherSpec sets the encryption and MAC states
// that a subsequent changeCipherSpec will use.
func (hc *halfConn) prepareCipherSpec(crypt encryptor, mac hash.Hash) {
hc.nextCrypt = crypt
hc.nextMac = mac
}
// changeCipherSpec changes the encryption and MAC states
// to the ones previously passed to prepareCipherSpec.
func (hc *halfConn) changeCipherSpec() os.Error {
if hc.nextCrypt == nil {
return alertInternalError
}
hc.crypt = hc.nextCrypt
hc.mac = hc.nextMac
hc.nextCrypt = nil
hc.nextMac = nil
return nil
}
// incSeq increments the sequence number.
func (hc *halfConn) incSeq() {
for i := 7; i >= 0; i-- {
hc.seq[i]++
if hc.seq[i] != 0 {
return
}
}
// Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does.
// Not likely enough to bother.
panic("TLS: sequence number wraparound")
}
// resetSeq resets the sequence number to zero.
func (hc *halfConn) resetSeq() {
for i := range hc.seq {
hc.seq[i] = 0
}
}
// decrypt checks and strips the mac and decrypts the data in b.
func (hc *halfConn) decrypt(b *block) (bool, alert) {
// pull out payload
payload := b.data[recordHeaderLen:]
// decrypt
if hc.crypt != nil {
hc.crypt.XORKeyStream(payload)
}
// check, strip mac
if hc.mac != nil {
if len(payload) < hc.mac.Size() {
return false, alertBadRecordMAC
}
// strip mac off payload, b.data
n := len(payload) - hc.mac.Size()
b.data[3] = byte(n >> 8)
b.data[4] = byte(n)
b.data = b.data[0 : recordHeaderLen+n]
remoteMAC := payload[n:]
hc.mac.Reset()
hc.mac.Write(&hc.seq)
hc.incSeq()
hc.mac.Write(b.data)
if subtle.ConstantTimeCompare(hc.mac.Sum(), remoteMAC) != 1 {
return false, alertBadRecordMAC
}
}
return true, 0
}
// encrypt encrypts and macs the data in b.
func (hc *halfConn) encrypt(b *block) (bool, alert) {
// mac
if hc.mac != nil {
hc.mac.Reset()
hc.mac.Write(&hc.seq)
hc.incSeq()
hc.mac.Write(b.data)
mac := hc.mac.Sum()
n := len(b.data)
b.resize(n + len(mac))
copy(b.data[n:], mac)
// update length to include mac
n = len(b.data) - recordHeaderLen
b.data[3] = byte(n >> 8)
b.data[4] = byte(n)
}
// encrypt
if hc.crypt != nil {
hc.crypt.XORKeyStream(b.data[recordHeaderLen:])
}
return true, 0
}
// A block is a simple data buffer.
type block struct {
data []byte
off int // index for Read
link *block
}
// resize resizes block to be n bytes, growing if necessary.
func (b *block) resize(n int) {
if n > cap(b.data) {
b.reserve(n)
}
b.data = b.data[0:n]
}
// reserve makes sure that block contains a capacity of at least n bytes.
func (b *block) reserve(n int) {
if cap(b.data) >= n {
return
}
m := cap(b.data)
if m == 0 {
m = 1024
}
for m < n {
m *= 2
}
data := make([]byte, len(b.data), m)
copy(data, b.data)
b.data = data
}
// readFromUntil reads from r into b until b contains at least n bytes
// or else returns an error.
func (b *block) readFromUntil(r io.Reader, n int) os.Error {
// quick case
if len(b.data) >= n {
return nil
}
// read until have enough.
b.reserve(n)
for {
m, err := r.Read(b.data[len(b.data):cap(b.data)])
b.data = b.data[0 : len(b.data)+m]
if len(b.data) >= n {
break
}
if err != nil {
return err
}
}
return nil
}
func (b *block) Read(p []byte) (n int, err os.Error) {
n = copy(p, b.data[b.off:])
b.off += n
return
}
// newBlock allocates a new block, from hc's free list if possible.
func (hc *halfConn) newBlock() *block {
b := hc.bfree
if b == nil {
return new(block)
}
hc.bfree = b.link
b.link = nil
b.resize(0)
return b
}
// freeBlock returns a block to hc's free list.
// The protocol is such that each side only has a block or two on
// its free list at a time, so there's no need to worry about
// trimming the list, etc.
func (hc *halfConn) freeBlock(b *block) {
b.link = hc.bfree
hc.bfree = b
}
// splitBlock splits a block after the first n bytes,
// returning a block with those n bytes and a
// block with the remaindec. the latter may be nil.
func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
if len(b.data) <= n {
return b, nil
}
bb := hc.newBlock()
bb.resize(len(b.data) - n)
copy(bb.data, b.data[n:])
b.data = b.data[0:n]
return b, bb
}
// readRecord reads the next TLS record from the connection
// and updates the record layer state.
// c.in.Mutex <= L; c.input == nil.
func (c *Conn) readRecord(want recordType) os.Error {
// Caller must be in sync with connection:
// handshake data if handshake not yet completed,
// else application data. (We don't support renegotiation.)
switch want {
default:
return c.sendAlert(alertInternalError)
case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.handshakeComplete {
return c.sendAlert(alertInternalError)
}
case recordTypeApplicationData:
if !c.handshakeComplete {
return c.sendAlert(alertInternalError)
}
}
Again:
if c.rawInput == nil {
c.rawInput = c.in.newBlock()
}
b := c.rawInput
// Read header, payload.
if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
// RFC suggests that EOF without an alertCloseNotify is
// an error, but popular web sites seem to do this,
// so we can't make it an error.
// if err == os.EOF {
// err = io.ErrUnexpectedEOF
// }
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.setError(err)
}
return err
}
typ := recordType(b.data[0])
vers := uint16(b.data[1])<<8 | uint16(b.data[2])
n := int(b.data[3])<<8 | int(b.data[4])
if c.haveVers && vers != c.vers {
return c.sendAlert(alertProtocolVersion)
}
if n > maxCiphertext {
return c.sendAlert(alertRecordOverflow)
}
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
if err == os.EOF {
err = io.ErrUnexpectedEOF
}
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.setError(err)
}
return err
}
// Process message.
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
b.off = recordHeaderLen
if ok, err := c.in.decrypt(b); !ok {
return c.sendAlert(err)
}
data := b.data[b.off:]
if len(data) > maxPlaintext {
c.sendAlert(alertRecordOverflow)
c.in.freeBlock(b)
return c.error()
}
switch typ {
default:
c.sendAlert(alertUnexpectedMessage)
case recordTypeAlert:
if len(data) != 2 {
c.sendAlert(alertUnexpectedMessage)
break
}
if alert(data[1]) == alertCloseNotify {
c.setError(os.EOF)
break
}
switch data[0] {
case alertLevelWarning:
// drop on the floor
c.in.freeBlock(b)
goto Again
case alertLevelError:
c.setError(&net.OpError{Op: "remote error", Error: alert(data[1])})
default:
c.sendAlert(alertUnexpectedMessage)
}
case recordTypeChangeCipherSpec:
if typ != want || len(data) != 1 || data[0] != 1 {
c.sendAlert(alertUnexpectedMessage)
break
}
err := c.in.changeCipherSpec()
if err != nil {
c.sendAlert(err.(alert))
}
case recordTypeApplicationData:
if typ != want {
c.sendAlert(alertUnexpectedMessage)
break
}
c.input = b
b = nil
case recordTypeHandshake:
// TODO(rsc): Should at least pick off connection close.
if typ != want {
return c.sendAlert(alertNoRenegotiation)
}
c.hand.Write(data)
}
if b != nil {
c.in.freeBlock(b)
}
return c.error()
}
// sendAlert sends a TLS alert message.
// c.out.Mutex <= L.
func (c *Conn) sendAlertLocked(err alert) os.Error {
c.tmp[0] = alertLevelError
if err == alertNoRenegotiation {
c.tmp[0] = alertLevelWarning
}
c.tmp[1] = byte(err)
c.writeRecord(recordTypeAlert, c.tmp[0:2])
return c.setError(&net.OpError{Op: "local error", Error: err})
}
// sendAlert sends a TLS alert message.
// L < c.out.Mutex.
func (c *Conn) sendAlert(err alert) os.Error {
c.out.Lock()
defer c.out.Unlock()
return c.sendAlertLocked(err)
}
// writeRecord writes a TLS record with the given type and payload
// to the connection and updates the record layer state.
// c.out.Mutex <= L.
func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) {
b := c.out.newBlock()
for len(data) > 0 {
m := len(data)
if m > maxPlaintext {
m = maxPlaintext
}
b.resize(recordHeaderLen + m)
b.data[0] = byte(typ)
vers := c.vers
if vers == 0 {
vers = maxVersion
}
b.data[1] = byte(vers >> 8)
b.data[2] = byte(vers)
b.data[3] = byte(m >> 8)
b.data[4] = byte(m)
copy(b.data[recordHeaderLen:], data)
c.out.encrypt(b)
_, err = c.conn.Write(b.data)
if err != nil {
break
}
n += m
data = data[m:]
}
c.out.freeBlock(b)
if typ == recordTypeChangeCipherSpec {
err = c.out.changeCipherSpec()
if err != nil {
// Cannot call sendAlert directly,
// because we already hold c.out.Mutex.
c.tmp[0] = alertLevelError
c.tmp[1] = byte(err.(alert))
c.writeRecord(recordTypeAlert, c.tmp[0:2])
c.err = &net.OpError{Op: "local error", Error: err}
return n, c.err
}
}
return
}
// readHandshake reads the next handshake message from
// the record layer.
// c.in.Mutex < L; c.out.Mutex < L.
func (c *Conn) readHandshake() (interface{}, os.Error) {
for c.hand.Len() < 4 {
if c.err != nil {
return nil, c.err
}
c.readRecord(recordTypeHandshake)
}
data := c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
c.sendAlert(alertInternalError)
return nil, c.err
}
for c.hand.Len() < 4+n {
if c.err != nil {
return nil, c.err
}
c.readRecord(recordTypeHandshake)
}
data = c.hand.Next(4 + n)
var m handshakeMessage
switch data[0] {
case typeClientHello:
m = new(clientHelloMsg)
case typeServerHello:
m = new(serverHelloMsg)
case typeCertificate:
m = new(certificateMsg)
case typeServerHelloDone:
m = new(serverHelloDoneMsg)
case typeClientKeyExchange:
m = new(clientKeyExchangeMsg)
case typeNextProtocol:
m = new(nextProtoMsg)
case typeFinished:
m = new(finishedMsg)
default:
c.sendAlert(alertUnexpectedMessage)
return nil, alertUnexpectedMessage
}
// The handshake message unmarshallers
// expect to be able to keep references to data,
// so pass in a fresh copy that won't be overwritten.
data = bytes.Add(nil, data)
if !m.unmarshal(data) {
c.sendAlert(alertUnexpectedMessage)
return nil, alertUnexpectedMessage
}
return m, nil
}
// Write writes data to the connection.
func (c *Conn) Write(b []byte) (n int, err os.Error) {
if err = c.Handshake(); err != nil {
return
}
c.out.Lock()
defer c.out.Unlock()
if !c.handshakeComplete {
return 0, alertInternalError
}
if c.err != nil {
return 0, c.err
}
return c.writeRecord(recordTypeApplicationData, b)
}
// Read can be made to time out and return err == os.EAGAIN
// after a fixed time limit; see SetTimeout and SetReadTimeout.
func (c *Conn) Read(b []byte) (n int, err os.Error) {
if err = c.Handshake(); err != nil {
return
}
c.in.Lock()
defer c.in.Unlock()
for c.input == nil && c.err == nil {
c.readRecord(recordTypeApplicationData)
}
if c.err != nil {
return 0, c.err
}
n, err = c.input.Read(b)
if c.input.off >= len(c.input.data) {
c.in.freeBlock(c.input)
c.input = nil
}
return n, nil
}
// Close closes the connection.
func (c *Conn) Close() os.Error {
if err := c.Handshake(); err != nil {
return err
}
return c.sendAlert(alertCloseNotify)
}
// Handshake runs the client or server handshake
// protocol if it has not yet been run.
// Most uses of this packge need not call Handshake
// explicitly: the first Read or Write will call it automatically.
func (c *Conn) Handshake() os.Error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if err := c.error(); err != nil {
return err
}
if c.handshakeComplete {
return nil
}
if c.isClient {
return c.clientHandshake()
}
return c.serverHandshake()
}
// If c is a TLS server, ClientConnection returns the protocol
// requested by the client during the TLS handshake.
// Handshake must have been called already.
func (c *Conn) ClientConnection() string {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.clientProtocol
}
...@@ -12,74 +12,63 @@ import ( ...@@ -12,74 +12,63 @@ import (
"crypto/subtle" "crypto/subtle"
"crypto/x509" "crypto/x509"
"io" "io"
"os"
) )
// A serverHandshake performs the server side of the TLS 1.1 handshake protocol. func (c *Conn) clientHandshake() os.Error {
type clientHandshake struct {
writeChan chan<- interface{}
controlChan chan<- interface{}
msgChan <-chan interface{}
config *Config
}
func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) {
h.writeChan = writeChan
h.controlChan = controlChan
h.msgChan = msgChan
h.config = config
defer close(writeChan)
defer close(controlChan)
finishedHash := newFinishedHash() finishedHash := newFinishedHash()
config := defaultConfig()
hello := &clientHelloMsg{ hello := &clientHelloMsg{
major: defaultMajor, vers: maxVersion,
minor: defaultMinor,
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
random: make([]byte, 32), random: make([]byte, 32),
} }
currentTime := uint32(config.Time()) t := uint32(config.Time())
hello.random[0] = byte(currentTime >> 24) hello.random[0] = byte(t >> 24)
hello.random[1] = byte(currentTime >> 16) hello.random[1] = byte(t >> 16)
hello.random[2] = byte(currentTime >> 8) hello.random[2] = byte(t >> 8)
hello.random[3] = byte(currentTime) hello.random[3] = byte(t)
_, err := io.ReadFull(config.Rand, hello.random[4:]) _, err := io.ReadFull(config.Rand, hello.random[4:])
if err != nil { if err != nil {
h.error(alertInternalError) return c.sendAlert(alertInternalError)
return
} }
finishedHash.Write(hello.marshal()) finishedHash.Write(hello.marshal())
writeChan <- writerSetVersion{defaultMajor, defaultMinor} c.writeRecord(recordTypeHandshake, hello.marshal())
writeChan <- hello
serverHello, ok := h.readHandshakeMsg().(*serverHelloMsg) msg, err := c.readHandshake()
if err != nil {
return err
}
serverHello, ok := msg.(*serverHelloMsg)
if !ok { if !ok {
h.error(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
return
} }
finishedHash.Write(serverHello.marshal()) finishedHash.Write(serverHello.marshal())
major, minor, ok := mutualVersion(serverHello.major, serverHello.minor)
vers, ok := mutualVersion(serverHello.vers)
if !ok { if !ok {
h.error(alertProtocolVersion) c.sendAlert(alertProtocolVersion)
return
} }
c.vers = vers
writeChan <- writerSetVersion{major, minor} c.haveVers = true
if serverHello.cipherSuite != TLS_RSA_WITH_RC4_128_SHA || if serverHello.cipherSuite != TLS_RSA_WITH_RC4_128_SHA ||
serverHello.compressionMethod != compressionNone { serverHello.compressionMethod != compressionNone {
h.error(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
return
} }
certMsg, ok := h.readHandshakeMsg().(*certificateMsg) msg, err = c.readHandshake()
if err != nil {
return err
}
certMsg, ok := msg.(*certificateMsg)
if !ok || len(certMsg.certificates) == 0 { if !ok || len(certMsg.certificates) == 0 {
h.error(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
return
} }
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
...@@ -87,139 +76,98 @@ func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- ...@@ -87,139 +76,98 @@ func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
for i, asn1Data := range certMsg.certificates { for i, asn1Data := range certMsg.certificates {
cert, err := x509.ParseCertificate(asn1Data) cert, err := x509.ParseCertificate(asn1Data)
if err != nil { if err != nil {
h.error(alertBadCertificate) return c.sendAlert(alertBadCertificate)
return
} }
certs[i] = cert certs[i] = cert
} }
// TODO(agl): do better validation of certs: max path length, name restrictions etc. // TODO(agl): do better validation of certs: max path length, name restrictions etc.
for i := 1; i < len(certs); i++ { for i := 1; i < len(certs); i++ {
if certs[i-1].CheckSignatureFrom(certs[i]) != nil { if err := certs[i-1].CheckSignatureFrom(certs[i]); err != nil {
h.error(alertBadCertificate) return c.sendAlert(alertBadCertificate)
return
} }
} }
if config.RootCAs != nil { // TODO(rsc): Find certificates for OS X 10.6.
if false && config.RootCAs != nil {
root := config.RootCAs.FindParent(certs[len(certs)-1]) root := config.RootCAs.FindParent(certs[len(certs)-1])
if root == nil { if root == nil {
h.error(alertBadCertificate) return c.sendAlert(alertBadCertificate)
return
} }
if certs[len(certs)-1].CheckSignatureFrom(root) != nil { if certs[len(certs)-1].CheckSignatureFrom(root) != nil {
h.error(alertBadCertificate) return c.sendAlert(alertBadCertificate)
return
} }
} }
pub, ok := certs[0].PublicKey.(*rsa.PublicKey) pub, ok := certs[0].PublicKey.(*rsa.PublicKey)
if !ok { if !ok {
h.error(alertUnsupportedCertificate) return c.sendAlert(alertUnsupportedCertificate)
return
} }
shd, ok := h.readHandshakeMsg().(*serverHelloDoneMsg) msg, err = c.readHandshake()
if err != nil {
return err
}
shd, ok := msg.(*serverHelloDoneMsg)
if !ok { if !ok {
h.error(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
return
} }
finishedHash.Write(shd.marshal()) finishedHash.Write(shd.marshal())
ckx := new(clientKeyExchangeMsg) ckx := new(clientKeyExchangeMsg)
preMasterSecret := make([]byte, 48) preMasterSecret := make([]byte, 48)
// Note that the version number in the preMasterSecret must be the preMasterSecret[0] = byte(hello.vers >> 8)
// version offered in the ClientHello. preMasterSecret[1] = byte(hello.vers)
preMasterSecret[0] = defaultMajor
preMasterSecret[1] = defaultMinor
_, err = io.ReadFull(config.Rand, preMasterSecret[2:]) _, err = io.ReadFull(config.Rand, preMasterSecret[2:])
if err != nil { if err != nil {
h.error(alertInternalError) return c.sendAlert(alertInternalError)
return
} }
ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret) ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret)
if err != nil { if err != nil {
h.error(alertInternalError) return c.sendAlert(alertInternalError)
return
} }
finishedHash.Write(ckx.marshal()) finishedHash.Write(ckx.marshal())
writeChan <- ckx c.writeRecord(recordTypeHandshake, ckx.marshal())
suite := cipherSuites[0] suite := cipherSuites[0]
masterSecret, clientMAC, serverMAC, clientKey, serverKey := masterSecret, clientMAC, serverMAC, clientKey, serverKey :=
keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength) keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength)
cipher, _ := rc4.NewCipher(clientKey) cipher, _ := rc4.NewCipher(clientKey)
writeChan <- writerChangeCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}
c.out.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
finished := new(finishedMsg) finished := new(finishedMsg)
finished.verifyData = finishedHash.clientSum(masterSecret) finished.verifyData = finishedHash.clientSum(masterSecret)
finishedHash.Write(finished.marshal()) finishedHash.Write(finished.marshal())
writeChan <- finished c.writeRecord(recordTypeHandshake, finished.marshal())
// TODO(agl): this is cut-through mode which should probably be an option.
writeChan <- writerEnableApplicationData{}
_, ok = h.readHandshakeMsg().(changeCipherSpec)
if !ok {
h.error(alertUnexpectedMessage)
return
}
cipher2, _ := rc4.NewCipher(serverKey) cipher2, _ := rc4.NewCipher(serverKey)
controlChan <- &newCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)} c.in.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
c.readRecord(recordTypeChangeCipherSpec)
if c.err != nil {
return c.err
}
serverFinished, ok := h.readHandshakeMsg().(*finishedMsg) msg, err = c.readHandshake()
if err != nil {
return err
}
serverFinished, ok := msg.(*finishedMsg)
if !ok { if !ok {
h.error(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
return
} }
verify := finishedHash.serverSum(masterSecret) verify := finishedHash.serverSum(masterSecret)
if len(verify) != len(serverFinished.verifyData) || if len(verify) != len(serverFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
h.error(alertHandshakeFailure) return c.sendAlert(alertHandshakeFailure)
return
} }
controlChan <- ConnectionState{HandshakeComplete: true, CipherSuite: "TLS_RSA_WITH_RC4_128_SHA"} c.handshakeComplete = true
c.cipherSuite = TLS_RSA_WITH_RC4_128_SHA
// This should just block forever. return nil
_ = h.readHandshakeMsg()
h.error(alertUnexpectedMessage)
return
}
func (h *clientHandshake) readHandshakeMsg() interface{} {
v := <-h.msgChan
if closed(h.msgChan) {
// If the channel closed then the processor received an error
// from the peer and we don't want to echo it back to them.
h.msgChan = nil
return 0
}
if _, ok := v.(alert); ok {
// We got an alert from the processor. We forward to the writer
// and shutdown.
h.writeChan <- v
h.msgChan = nil
return 0
}
return v
}
func (h *clientHandshake) error(e alertType) {
if h.msgChan != nil {
// If we didn't get an error from the processor, then we need
// to tell it about the error.
go func() {
for _ = range h.msgChan {
}
}()
h.controlChan <- ConnectionState{Error: e}
close(h.controlChan)
h.writeChan <- alert{alertLevelError, e}
}
} }
...@@ -6,7 +6,7 @@ package tls ...@@ -6,7 +6,7 @@ package tls
type clientHelloMsg struct { type clientHelloMsg struct {
raw []byte raw []byte
major, minor uint8 vers uint16
random []byte random []byte
sessionId []byte sessionId []byte
cipherSuites []uint16 cipherSuites []uint16
...@@ -40,8 +40,8 @@ func (m *clientHelloMsg) marshal() []byte { ...@@ -40,8 +40,8 @@ func (m *clientHelloMsg) marshal() []byte {
x[1] = uint8(length >> 16) x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8) x[2] = uint8(length >> 8)
x[3] = uint8(length) x[3] = uint8(length)
x[4] = m.major x[4] = uint8(m.vers >> 8)
x[5] = m.minor x[5] = uint8(m.vers)
copy(x[6:38], m.random) copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId)) x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId) copy(x[39:39+len(m.sessionId)], m.sessionId)
...@@ -108,12 +108,11 @@ func (m *clientHelloMsg) marshal() []byte { ...@@ -108,12 +108,11 @@ func (m *clientHelloMsg) marshal() []byte {
} }
func (m *clientHelloMsg) unmarshal(data []byte) bool { func (m *clientHelloMsg) unmarshal(data []byte) bool {
if len(data) < 43 { if len(data) < 42 {
return false return false
} }
m.raw = data m.raw = data
m.major = data[4] m.vers = uint16(data[4])<<8 | uint16(data[5])
m.minor = data[5]
m.random = data[6:38] m.random = data[6:38]
sessionIdLen := int(data[38]) sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen { if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
...@@ -136,7 +135,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { ...@@ -136,7 +135,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
} }
data = data[2+cipherSuiteLen:] data = data[2+cipherSuiteLen:]
if len(data) < 2 { if len(data) < 1 {
return false return false
} }
compressionMethodsLen := int(data[0]) compressionMethodsLen := int(data[0])
...@@ -212,7 +211,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { ...@@ -212,7 +211,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
type serverHelloMsg struct { type serverHelloMsg struct {
raw []byte raw []byte
major, minor uint8 vers uint16
random []byte random []byte
sessionId []byte sessionId []byte
cipherSuite uint16 cipherSuite uint16
...@@ -249,8 +248,8 @@ func (m *serverHelloMsg) marshal() []byte { ...@@ -249,8 +248,8 @@ func (m *serverHelloMsg) marshal() []byte {
x[1] = uint8(length >> 16) x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8) x[2] = uint8(length >> 8)
x[3] = uint8(length) x[3] = uint8(length)
x[4] = m.major x[4] = uint8(m.vers >> 8)
x[5] = m.minor x[5] = uint8(m.vers)
copy(x[6:38], m.random) copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId)) x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId) copy(x[39:39+len(m.sessionId)], m.sessionId)
...@@ -306,8 +305,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { ...@@ -306,8 +305,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
return false return false
} }
m.raw = data m.raw = data
m.major = data[4] m.vers = uint16(data[4])<<8 | uint16(data[5])
m.minor = data[5]
m.random = data[6:38] m.random = data[6:38]
sessionIdLen := int(data[38]) sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen { if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
......
...@@ -97,8 +97,7 @@ func randomString(n int, rand *rand.Rand) string { ...@@ -97,8 +97,7 @@ func randomString(n int, rand *rand.Rand) string {
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientHelloMsg{} m := &clientHelloMsg{}
m.major = uint8(rand.Intn(256)) m.vers = uint16(rand.Intn(65536))
m.minor = uint8(rand.Intn(256))
m.random = randomBytes(32, rand) m.random = randomBytes(32, rand)
m.sessionId = randomBytes(rand.Intn(32), rand) m.sessionId = randomBytes(rand.Intn(32), rand)
m.cipherSuites = make([]uint16, rand.Intn(63)+1) m.cipherSuites = make([]uint16, rand.Intn(63)+1)
...@@ -118,8 +117,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -118,8 +117,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &serverHelloMsg{} m := &serverHelloMsg{}
m.major = uint8(rand.Intn(256)) m.vers = uint16(rand.Intn(65536))
m.minor = uint8(rand.Intn(256))
m.random = randomBytes(32, rand) m.random = randomBytes(32, rand)
m.sessionId = randomBytes(rand.Intn(32), rand) m.sessionId = randomBytes(rand.Intn(32), rand)
m.cipherSuite = uint16(rand.Int31()) m.cipherSuite = uint16(rand.Int31())
......
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"crypto/sha1" "crypto/sha1"
"crypto/subtle" "crypto/subtle"
"io" "io"
"os"
) )
type cipherSuite struct { type cipherSuite struct {
...@@ -31,33 +32,22 @@ var cipherSuites = []cipherSuite{ ...@@ -31,33 +32,22 @@ var cipherSuites = []cipherSuite{
cipherSuite{TLS_RSA_WITH_RC4_128_SHA, 20, 16}, cipherSuite{TLS_RSA_WITH_RC4_128_SHA, 20, 16},
} }
// A serverHandshake performs the server side of the TLS 1.1 handshake protocol. func (c *Conn) serverHandshake() os.Error {
type serverHandshake struct { config := c.config
writeChan chan<- interface{} msg, err := c.readHandshake()
controlChan chan<- interface{} if err != nil {
msgChan <-chan interface{} return err
config *Config }
} clientHello, ok := msg.(*clientHelloMsg)
func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) {
h.writeChan = writeChan
h.controlChan = controlChan
h.msgChan = msgChan
h.config = config
defer close(writeChan)
defer close(controlChan)
clientHello, ok := h.readHandshakeMsg().(*clientHelloMsg)
if !ok { if !ok {
h.error(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
return
} }
major, minor, ok := mutualVersion(clientHello.major, clientHello.minor) vers, ok := mutualVersion(clientHello.vers)
if !ok { if !ok {
h.error(alertProtocolVersion) return c.sendAlert(alertProtocolVersion)
return
} }
c.vers = vers
c.haveVers = true
finishedHash := newFinishedHash() finishedHash := newFinishedHash()
finishedHash.Write(clientHello.marshal()) finishedHash.Write(clientHello.marshal())
...@@ -89,23 +79,20 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- ...@@ -89,23 +79,20 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
} }
if suite == nil || !foundCompression { if suite == nil || !foundCompression {
h.error(alertHandshakeFailure) return c.sendAlert(alertHandshakeFailure)
return
} }
hello.major = major hello.vers = vers
hello.minor = minor
hello.cipherSuite = suite.id hello.cipherSuite = suite.id
currentTime := uint32(config.Time()) t := uint32(config.Time())
hello.random = make([]byte, 32) hello.random = make([]byte, 32)
hello.random[0] = byte(currentTime >> 24) hello.random[0] = byte(t >> 24)
hello.random[1] = byte(currentTime >> 16) hello.random[1] = byte(t >> 16)
hello.random[2] = byte(currentTime >> 8) hello.random[2] = byte(t >> 8)
hello.random[3] = byte(currentTime) hello.random[3] = byte(t)
_, err := io.ReadFull(config.Rand, hello.random[4:]) _, err = io.ReadFull(config.Rand, hello.random[4:])
if err != nil { if err != nil {
h.error(alertInternalError) return c.sendAlert(alertInternalError)
return
} }
hello.compressionMethod = compressionNone hello.compressionMethod = compressionNone
if clientHello.nextProtoNeg { if clientHello.nextProtoNeg {
...@@ -114,41 +101,40 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- ...@@ -114,41 +101,40 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
} }
finishedHash.Write(hello.marshal()) finishedHash.Write(hello.marshal())
writeChan <- writerSetVersion{major, minor} c.writeRecord(recordTypeHandshake, hello.marshal())
writeChan <- hello
if len(config.Certificates) == 0 { if len(config.Certificates) == 0 {
h.error(alertInternalError) return c.sendAlert(alertInternalError)
return
} }
certMsg := new(certificateMsg) certMsg := new(certificateMsg)
certMsg.certificates = config.Certificates[0].Certificate certMsg.certificates = config.Certificates[0].Certificate
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
writeChan <- certMsg c.writeRecord(recordTypeHandshake, certMsg.marshal())
helloDone := new(serverHelloDoneMsg) helloDone := new(serverHelloDoneMsg)
finishedHash.Write(helloDone.marshal()) finishedHash.Write(helloDone.marshal())
writeChan <- helloDone c.writeRecord(recordTypeHandshake, helloDone.marshal())
ckx, ok := h.readHandshakeMsg().(*clientKeyExchangeMsg) msg, err = c.readHandshake()
if err != nil {
return err
}
ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok { if !ok {
h.error(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
return
} }
finishedHash.Write(ckx.marshal()) finishedHash.Write(ckx.marshal())
preMasterSecret := make([]byte, 48) preMasterSecret := make([]byte, 48)
_, err = io.ReadFull(config.Rand, preMasterSecret[2:]) _, err = io.ReadFull(config.Rand, preMasterSecret[2:])
if err != nil { if err != nil {
h.error(alertInternalError) return c.sendAlert(alertInternalError)
return
} }
err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret) err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret)
if err != nil { if err != nil {
h.error(alertHandshakeFailure) return c.sendAlert(alertHandshakeFailure)
return
} }
// We don't check the version number in the premaster secret. For one, // We don't check the version number in the premaster secret. For one,
// by checking it, we would leak information about the validity of the // by checking it, we would leak information about the validity of the
...@@ -160,91 +146,53 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- ...@@ -160,91 +146,53 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
masterSecret, clientMAC, serverMAC, clientKey, serverKey := masterSecret, clientMAC, serverMAC, clientKey, serverKey :=
keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength) keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength)
_, ok = h.readHandshakeMsg().(changeCipherSpec)
if !ok {
h.error(alertUnexpectedMessage)
return
}
cipher, _ := rc4.NewCipher(clientKey) cipher, _ := rc4.NewCipher(clientKey)
controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)} c.in.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
c.readRecord(recordTypeChangeCipherSpec)
if err := c.error(); err != nil {
return err
}
clientProtocol := ""
if hello.nextProtoNeg { if hello.nextProtoNeg {
nextProto, ok := h.readHandshakeMsg().(*nextProtoMsg) msg, err = c.readHandshake()
if err != nil {
return err
}
nextProto, ok := msg.(*nextProtoMsg)
if !ok { if !ok {
h.error(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
return
} }
finishedHash.Write(nextProto.marshal()) finishedHash.Write(nextProto.marshal())
clientProtocol = nextProto.proto c.clientProtocol = nextProto.proto
} }
clientFinished, ok := h.readHandshakeMsg().(*finishedMsg) msg, err = c.readHandshake()
if err != nil {
return err
}
clientFinished, ok := msg.(*finishedMsg)
if !ok { if !ok {
h.error(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
return
} }
verify := finishedHash.clientSum(masterSecret) verify := finishedHash.clientSum(masterSecret)
if len(verify) != len(clientFinished.verifyData) || if len(verify) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
h.error(alertHandshakeFailure) return c.sendAlert(alertHandshakeFailure)
return
} }
controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, clientProtocol}
finishedHash.Write(clientFinished.marshal()) finishedHash.Write(clientFinished.marshal())
cipher2, _ := rc4.NewCipher(serverKey) cipher2, _ := rc4.NewCipher(serverKey)
writeChan <- writerChangeCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)} c.out.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
finished := new(finishedMsg) finished := new(finishedMsg)
finished.verifyData = finishedHash.serverSum(masterSecret) finished.verifyData = finishedHash.serverSum(masterSecret)
writeChan <- finished c.writeRecord(recordTypeHandshake, finished.marshal())
writeChan <- writerEnableApplicationData{} c.handshakeComplete = true
c.cipherSuite = TLS_RSA_WITH_RC4_128_SHA
for { return nil
_, ok := h.readHandshakeMsg().(*clientHelloMsg)
if !ok {
h.error(alertUnexpectedMessage)
return
}
// We reject all renegotication requests.
writeChan <- alert{alertLevelWarning, alertNoRenegotiation}
}
}
func (h *serverHandshake) readHandshakeMsg() interface{} {
v := <-h.msgChan
if closed(h.msgChan) {
// If the channel closed then the processor received an error
// from the peer and we don't want to echo it back to them.
h.msgChan = nil
return 0
}
if _, ok := v.(alert); ok {
// We got an alert from the processor. We forward to the writer
// and shutdown.
h.writeChan <- v
h.msgChan = nil
return 0
}
return v
}
func (h *serverHandshake) error(e alertType) {
if h.msgChan != nil {
// If we didn't get an error from the processor, then we need
// to tell it about the error.
go func() {
for _ = range h.msgChan {
}
}()
h.controlChan <- ConnectionState{false, "", e, ""}
close(h.controlChan)
h.writeChan <- alert{alertLevelError, e}
}
} }
...@@ -5,12 +5,16 @@ ...@@ -5,12 +5,16 @@
package tls package tls
import ( import (
"bytes" // "bytes"
"big" "big"
"crypto/rsa" "crypto/rsa"
"encoding/hex"
"flag"
"io"
"net"
"os" "os"
"testing" "testing"
"testing/script" // "testing/script"
) )
type zeroSource struct{} type zeroSource struct{}
...@@ -34,29 +38,23 @@ func init() { ...@@ -34,29 +38,23 @@ func init() {
testConfig.Certificates[0].PrivateKey = testPrivateKey testConfig.Certificates[0].PrivateKey = testPrivateKey
} }
func setupServerHandshake() (writeChan chan interface{}, controlChan chan interface{}, msgChan chan interface{}) { func testClientHelloFailure(t *testing.T, m handshakeMessage, expected os.Error) {
sh := new(serverHandshake) // Create in-memory network connection,
writeChan = make(chan interface{}) // send message to server. Should return
controlChan = make(chan interface{}) // expected error.
msgChan = make(chan interface{}) c, s := net.Pipe()
go func() {
go sh.loop(writeChan, controlChan, msgChan, testConfig) cli := Client(c, testConfig)
return if ch, ok := m.(*clientHelloMsg); ok {
} cli.vers = ch.vers
}
func testClientHelloFailure(t *testing.T, clientHello interface{}, expectedAlert alertType) { cli.writeRecord(recordTypeHandshake, m.marshal())
writeChan, controlChan, msgChan := setupServerHandshake() c.Close()
defer close(msgChan) }()
err := Server(s, testConfig).Handshake()
send := script.NewEvent("send", nil, script.Send{msgChan, clientHello}) s.Close()
recvAlert := script.NewEvent("recv alert", []*script.Event{send}, script.Recv{writeChan, alert{alertLevelError, expectedAlert}}) if e, ok := err.(*net.OpError); !ok || e.Error != expected {
close1 := script.NewEvent("msgChan close", []*script.Event{recvAlert}, script.Closed{writeChan}) t.Errorf("Got error: %s; expected: %s", err, expected)
recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert, ""}})
close2 := script.NewEvent("controlChan close", []*script.Event{recvState}, script.Closed{controlChan})
err := script.Perform(0, []*script.Event{send, recvAlert, close1, recvState, close2})
if err != nil {
t.Errorf("Got error: %s", err)
} }
} }
...@@ -64,147 +62,232 @@ func TestSimpleError(t *testing.T) { ...@@ -64,147 +62,232 @@ func TestSimpleError(t *testing.T) {
testClientHelloFailure(t, &serverHelloDoneMsg{}, alertUnexpectedMessage) testClientHelloFailure(t, &serverHelloDoneMsg{}, alertUnexpectedMessage)
} }
var badProtocolVersions = []uint8{0, 0, 0, 5, 1, 0, 1, 5, 2, 0, 2, 5, 3, 0} var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205, 0x0300}
func TestRejectBadProtocolVersion(t *testing.T) { func TestRejectBadProtocolVersion(t *testing.T) {
clientHello := new(clientHelloMsg) for _, v := range badProtocolVersions {
testClientHelloFailure(t, &clientHelloMsg{vers: v}, alertProtocolVersion)
for i := 0; i < len(badProtocolVersions); i += 2 {
clientHello.major = badProtocolVersions[i]
clientHello.minor = badProtocolVersions[i+1]
testClientHelloFailure(t, clientHello, alertProtocolVersion)
} }
} }
func TestNoSuiteOverlap(t *testing.T) { func TestNoSuiteOverlap(t *testing.T) {
clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""} clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""}
testClientHelloFailure(t, clientHello, alertHandshakeFailure) testClientHelloFailure(t, clientHello, alertHandshakeFailure)
} }
func TestNoCompressionOverlap(t *testing.T) { func TestNoCompressionOverlap(t *testing.T) {
clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""} clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""}
testClientHelloFailure(t, clientHello, alertHandshakeFailure) testClientHelloFailure(t, clientHello, alertHandshakeFailure)
} }
func matchServerHello(v interface{}) bool { func TestAlertForwarding(t *testing.T) {
serverHello, ok := v.(*serverHelloMsg) c, s := net.Pipe()
if !ok { go func() {
return false Client(c, testConfig).sendAlert(alertUnknownCA)
c.Close()
}()
err := Server(s, testConfig).Handshake()
s.Close()
if e, ok := err.(*net.OpError); !ok || e.Error != os.Error(alertUnknownCA) {
t.Errorf("Got error: %s; expected: %s", err, alertUnknownCA)
} }
return serverHello.major == 3 &&
serverHello.minor == 2 &&
serverHello.cipherSuite == TLS_RSA_WITH_RC4_128_SHA &&
serverHello.compressionMethod == compressionNone
} }
func TestAlertForwarding(t *testing.T) { func TestClose(t *testing.T) {
writeChan, controlChan, msgChan := setupServerHandshake() c, s := net.Pipe()
defer close(msgChan) go c.Close()
a := alert{alertLevelError, alertNoRenegotiation}
sendAlert := script.NewEvent("send alert", nil, script.Send{msgChan, a})
recvAlert := script.NewEvent("recv alert", []*script.Event{sendAlert}, script.Recv{writeChan, a})
closeWriter := script.NewEvent("close writer", []*script.Event{recvAlert}, script.Closed{writeChan})
closeControl := script.NewEvent("close control", []*script.Event{recvAlert}, script.Closed{controlChan})
err := script.Perform(0, []*script.Event{sendAlert, recvAlert, closeWriter, closeControl}) err := Server(s, testConfig).Handshake()
if err != nil { s.Close()
t.Errorf("Got error: %s", err) if err != os.EOF {
t.Errorf("Got error: %s; expected: %s", err, os.EOF)
} }
} }
func TestClose(t *testing.T) {
writeChan, controlChan, msgChan := setupServerHandshake()
close := script.NewEvent("close", nil, script.Close{msgChan}) func TestHandshakeServer(t *testing.T) {
closed1 := script.NewEvent("closed1", []*script.Event{close}, script.Closed{writeChan}) c, s := net.Pipe()
closed2 := script.NewEvent("closed2", []*script.Event{close}, script.Closed{controlChan}) srv := Server(s, testConfig)
go func() {
srv.Write([]byte("hello, world\n"))
srv.Close()
}()
err := script.Perform(0, []*script.Event{close, closed1, closed2}) defer c.Close()
for i, b := range serverScript {
if i%2 == 0 {
c.Write(b)
continue
}
bb := make([]byte, len(b))
_, err := io.ReadFull(c, bb)
if err != nil { if err != nil {
t.Errorf("Got error: %s", err) t.Fatalf("#%d: %s", i, err)
} }
}
func matchCertificate(v interface{}) bool {
cert, ok := v.(*certificateMsg)
if !ok {
return false
} }
return len(cert.certificates) == 1 &&
bytes.Compare(cert.certificates[0], testCertificate) == 0
}
func matchSetCipher(v interface{}) bool {
_, ok := v.(writerChangeCipherSpec)
return ok
}
func matchDone(v interface{}) bool { if !srv.haveVers || srv.vers != 0x0302 {
_, ok := v.(*serverHelloDoneMsg) t.Errorf("server version incorrect: %v %v", srv.haveVers, srv.vers)
return ok
}
func matchFinished(v interface{}) bool {
finished, ok := v.(*finishedMsg)
if !ok {
return false
} }
return bytes.Compare(finished.verifyData, fromHex("29122ae11453e631487b02ed")) == 0
}
func matchNewCipherSpec(v interface{}) bool { // TODO: check protocol
_, ok := v.(*newCipherSpec)
return ok
} }
func TestFullHandshake(t *testing.T) { var serve = flag.Bool("serve", false, "run a TLS server on :10443")
writeChan, controlChan, msgChan := setupServerHandshake()
defer close(msgChan)
// The values for this test were obtained from running `gnutls-cli --insecure --debug 9` func TestRunServer(t *testing.T) {
clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}, false, ""} if !*serve {
return
sendHello := script.NewEvent("send hello", nil, script.Send{msgChan, clientHello}) }
setVersion := script.NewEvent("set version", []*script.Event{sendHello}, script.Recv{writeChan, writerSetVersion{3, 2}})
recvHello := script.NewEvent("recv hello", []*script.Event{setVersion}, script.RecvMatch{writeChan, matchServerHello})
recvCert := script.NewEvent("recv cert", []*script.Event{recvHello}, script.RecvMatch{writeChan, matchCertificate})
recvDone := script.NewEvent("recv done", []*script.Event{recvCert}, script.RecvMatch{writeChan, matchDone})
ckx := &clientKeyExchangeMsg{nil, fromHex("872e1fee5f37dd86f3215938ac8de20b302b90074e9fb93097e6b7d1286d0f45abf2daf179deb618bb3c70ed0afee6ee24476ee4649e5a23358143c0f1d9c251")}
sendCKX := script.NewEvent("send ckx", []*script.Event{recvDone}, script.Send{msgChan, ckx})
sendCCS := script.NewEvent("send ccs", []*script.Event{sendCKX}, script.Send{msgChan, changeCipherSpec{}})
recvNCS := script.NewEvent("recv done", []*script.Event{sendCCS}, script.RecvMatch{controlChan, matchNewCipherSpec})
finished := &finishedMsg{nil, fromHex("c8faca5d242f4423325c5b1a")} l, err := Listen("tcp", ":10443", testConfig)
sendFinished := script.NewEvent("send finished", []*script.Event{recvNCS}, script.Send{msgChan, finished}) if err != nil {
recvFinished := script.NewEvent("recv finished", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchFinished}) t.Fatal(err)
setCipher := script.NewEvent("set cipher", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchSetCipher}) }
recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, ""}})
err := script.Perform(0, []*script.Event{sendHello, setVersion, recvHello, recvCert, recvDone, sendCKX, sendCCS, recvNCS, sendFinished, setCipher, recvConnectionState, recvFinished}) for {
c, err := l.Accept()
if err != nil { if err != nil {
t.Errorf("Got error: %s", err) break
}
c.Write([]byte("hello, world\n"))
c.Close()
} }
} }
var testCertificate = fromHex("3082025930820203a003020102020900c2ec326b95228959300d06092a864886f70d01010505003054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374301e170d3039313032303232323434355a170d3130313032303232323434355a3054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374305c300d06092a864886f70d0101010500034b003048024100b2990f49c47dfa8cd400ae6a4d1b8a3b6a13642b23f28b003bfb97790ade9a4cc82b8b2a81747ddec08b6296e53a08c331687ef25c4bf4936ba1c0e6041e9d150203010001a381b73081b4301d0603551d0e0416041478a06086837c9293a8c9b70c0bdabdb9d77eeedf3081840603551d23047d307b801478a06086837c9293a8c9b70c0bdabdb9d77eeedfa158a4563054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374820900c2ec326b95228959300c0603551d13040530030101ff300d06092a864886f70d0101050500034100ac23761ae1349d85a439caad4d0b932b09ea96de1917c3e0507c446f4838cb3076fb4d431db8c1987e96f1d7a8a2054dea3a64ec99a3f0eda4d47a163bf1f6ac")
func bigFromString(s string) *big.Int { func bigFromString(s string) *big.Int {
ret := new(big.Int) ret := new(big.Int)
ret.SetString(s, 10) ret.SetString(s, 10)
return ret return ret
} }
func fromHex(s string) []byte {
b, _ := hex.DecodeString(s)
return b
}
var testCertificate = fromHex("308202b030820219a00302010202090085b0bba48a7fb8ca300d06092a864886f70d01010505003045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3130303432343039303933385a170d3131303432343039303933385a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819f300d06092a864886f70d010101050003818d0030818902818100bb79d6f517b5e5bf4610d0dc69bee62b07435ad0032d8a7a4385b71452e7a5654c2c78b8238cb5b482e5de1f953b7e62a52ca533d6fe125c7a56fcf506bffa587b263fb5cd04d3d0c921964ac7f4549f5abfef427100fe1899077f7e887d7df10439c4a22edb51c97ce3c04c3b326601cfafb11db8719a1ddbdb896baeda2d790203010001a381a73081a4301d0603551d0e04160414b1ade2855acfcb28db69ce2369ded3268e18883930750603551d23046e306c8014b1ade2855acfcb28db69ce2369ded3268e188839a149a4473045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746482090085b0bba48a7fb8ca300c0603551d13040530030101ff300d06092a864886f70d010105050003818100086c4524c76bb159ab0c52ccf2b014d7879d7a6475b55a9566e4c52b8eae12661feb4f38b36e60d392fdf74108b52513b1187a24fb301dbaed98b917ece7d73159db95d31d78ea50565cd5825a2d5a5f33c4b6d8c97590968c0f5298b5cd981f89205ff2a01ca31b9694dda9fd57e970e8266d71999b266e3850296c90a7bdd9")
var testPrivateKey = &rsa.PrivateKey{ var testPrivateKey = &rsa.PrivateKey{
PublicKey: rsa.PublicKey{ PublicKey: rsa.PublicKey{
N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), N: bigFromString("131650079503776001033793877885499001334664249354723305978524647182322416328664556247316495448366990052837680518067798333412266673813370895702118944398081598789828837447552603077848001020611640547221687072142537202428102790818451901395596882588063427854225330436740647715202971973145151161964464812406232198521"),
E: 65537, E: 65537,
}, },
D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"),
P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), P: bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"),
Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), Q: bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"),
}
// Script of interaction with gnutls implementation.
// The values for this test are obtained by building a test binary (gotest)
// and then running 6.out -serve to start a server and then
// gnutls-cli --insecure --debug 100 -p 10443 localhost
// to dump a session.
var serverScript = [][]byte{
// Alternate write and read.
[]byte{
0x16, 0x03, 0x02, 0x00, 0x71, 0x01, 0x00, 0x00, 0x6d, 0x03, 0x02, 0x4b, 0xd4, 0xee, 0x6e, 0xab,
0x0b, 0xc3, 0x01, 0xd6, 0x8d, 0xe0, 0x72, 0x7e, 0x6c, 0x04, 0xbe, 0x9a, 0x3c, 0xa3, 0xd8, 0x95,
0x28, 0x00, 0xb2, 0xe8, 0x1f, 0xdd, 0xb0, 0xec, 0xca, 0x46, 0x1f, 0x00, 0x00, 0x28, 0x00, 0x33,
0x00, 0x39, 0x00, 0x16, 0x00, 0x32, 0x00, 0x38, 0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, 0x00, 0x05, 0x00, 0x04, 0x00, 0x8c,
0x00, 0x8d, 0x00, 0x8b, 0x00, 0x8a, 0x01, 0x00, 0x00, 0x1c, 0x00, 0x09, 0x00, 0x03, 0x02, 0x00,
0x01, 0x00, 0x00, 0x00, 0x11, 0x00, 0x0f, 0x00, 0x00, 0x0c, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36,
0x38, 0x2e, 0x30, 0x2e, 0x31, 0x30,
},
[]byte{
0x16, 0x03, 0x02, 0x00, 0x2a,
0x02, 0x00, 0x00, 0x26, 0x03, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00,
0x16, 0x03, 0x02, 0x02, 0xbe,
0x0b, 0x00, 0x02, 0xba, 0x00, 0x02, 0xb7, 0x00, 0x02, 0xb4, 0x30, 0x82, 0x02, 0xb0, 0x30, 0x82,
0x02, 0x19, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x09, 0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f,
0xb8, 0xca, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05,
0x00, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55,
0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d,
0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18,
0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73,
0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x30, 0x30, 0x34,
0x32, 0x34, 0x30, 0x39, 0x30, 0x39, 0x33, 0x38, 0x5a, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x34, 0x32,
0x34, 0x30, 0x39, 0x30, 0x39, 0x33, 0x38, 0x5a, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03,
0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08,
0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f,
0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20,
0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30,
0x81, 0x9f, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05,
0x00, 0x03, 0x81, 0x8d, 0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00, 0xbb, 0x79, 0xd6, 0xf5,
0x17, 0xb5, 0xe5, 0xbf, 0x46, 0x10, 0xd0, 0xdc, 0x69, 0xbe, 0xe6, 0x2b, 0x07, 0x43, 0x5a, 0xd0,
0x03, 0x2d, 0x8a, 0x7a, 0x43, 0x85, 0xb7, 0x14, 0x52, 0xe7, 0xa5, 0x65, 0x4c, 0x2c, 0x78, 0xb8,
0x23, 0x8c, 0xb5, 0xb4, 0x82, 0xe5, 0xde, 0x1f, 0x95, 0x3b, 0x7e, 0x62, 0xa5, 0x2c, 0xa5, 0x33,
0xd6, 0xfe, 0x12, 0x5c, 0x7a, 0x56, 0xfc, 0xf5, 0x06, 0xbf, 0xfa, 0x58, 0x7b, 0x26, 0x3f, 0xb5,
0xcd, 0x04, 0xd3, 0xd0, 0xc9, 0x21, 0x96, 0x4a, 0xc7, 0xf4, 0x54, 0x9f, 0x5a, 0xbf, 0xef, 0x42,
0x71, 0x00, 0xfe, 0x18, 0x99, 0x07, 0x7f, 0x7e, 0x88, 0x7d, 0x7d, 0xf1, 0x04, 0x39, 0xc4, 0xa2,
0x2e, 0xdb, 0x51, 0xc9, 0x7c, 0xe3, 0xc0, 0x4c, 0x3b, 0x32, 0x66, 0x01, 0xcf, 0xaf, 0xb1, 0x1d,
0xb8, 0x71, 0x9a, 0x1d, 0xdb, 0xdb, 0x89, 0x6b, 0xae, 0xda, 0x2d, 0x79, 0x02, 0x03, 0x01, 0x00,
0x01, 0xa3, 0x81, 0xa7, 0x30, 0x81, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16,
0x04, 0x14, 0xb1, 0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69, 0xce, 0x23, 0x69, 0xde,
0xd3, 0x26, 0x8e, 0x18, 0x88, 0x39, 0x30, 0x75, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x6e, 0x30,
0x6c, 0x80, 0x14, 0xb1, 0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69, 0xce, 0x23, 0x69,
0xde, 0xd3, 0x26, 0x8e, 0x18, 0x88, 0x39, 0xa1, 0x49, 0xa4, 0x47, 0x30, 0x45, 0x31, 0x0b, 0x30,
0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03,
0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31,
0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e,
0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c,
0x74, 0x64, 0x82, 0x09, 0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, 0xb8, 0xca, 0x30, 0x0c, 0x06,
0x03, 0x55, 0x1d, 0x13, 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30, 0x0d, 0x06, 0x09, 0x2a,
0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x81, 0x81, 0x00, 0x08, 0x6c,
0x45, 0x24, 0xc7, 0x6b, 0xb1, 0x59, 0xab, 0x0c, 0x52, 0xcc, 0xf2, 0xb0, 0x14, 0xd7, 0x87, 0x9d,
0x7a, 0x64, 0x75, 0xb5, 0x5a, 0x95, 0x66, 0xe4, 0xc5, 0x2b, 0x8e, 0xae, 0x12, 0x66, 0x1f, 0xeb,
0x4f, 0x38, 0xb3, 0x6e, 0x60, 0xd3, 0x92, 0xfd, 0xf7, 0x41, 0x08, 0xb5, 0x25, 0x13, 0xb1, 0x18,
0x7a, 0x24, 0xfb, 0x30, 0x1d, 0xba, 0xed, 0x98, 0xb9, 0x17, 0xec, 0xe7, 0xd7, 0x31, 0x59, 0xdb,
0x95, 0xd3, 0x1d, 0x78, 0xea, 0x50, 0x56, 0x5c, 0xd5, 0x82, 0x5a, 0x2d, 0x5a, 0x5f, 0x33, 0xc4,
0xb6, 0xd8, 0xc9, 0x75, 0x90, 0x96, 0x8c, 0x0f, 0x52, 0x98, 0xb5, 0xcd, 0x98, 0x1f, 0x89, 0x20,
0x5f, 0xf2, 0xa0, 0x1c, 0xa3, 0x1b, 0x96, 0x94, 0xdd, 0xa9, 0xfd, 0x57, 0xe9, 0x70, 0xe8, 0x26,
0x6d, 0x71, 0x99, 0x9b, 0x26, 0x6e, 0x38, 0x50, 0x29, 0x6c, 0x90, 0xa7, 0xbd, 0xd9,
0x16, 0x03, 0x02, 0x00, 0x04,
0x0e, 0x00, 0x00, 0x00,
},
[]byte{
0x16, 0x03, 0x02, 0x00, 0x86, 0x10, 0x00, 0x00, 0x82, 0x00, 0x80, 0x3b, 0x7a, 0x9b, 0x05, 0xfd,
0x1b, 0x0d, 0x81, 0xf0, 0xac, 0x59, 0x57, 0x4e, 0xb6, 0xf5, 0x81, 0xed, 0x52, 0x78, 0xc5, 0xff,
0x36, 0x33, 0x9c, 0x94, 0x31, 0xc3, 0x14, 0x98, 0x5d, 0xa0, 0x49, 0x23, 0x11, 0x67, 0xdf, 0x73,
0x1b, 0x81, 0x0b, 0xdd, 0x10, 0xda, 0xee, 0xb5, 0x68, 0x61, 0xa9, 0xb6, 0x15, 0xae, 0x1a, 0x11,
0x31, 0x42, 0x2e, 0xde, 0x01, 0x4b, 0x81, 0x70, 0x03, 0xc8, 0x5b, 0xca, 0x21, 0x88, 0x25, 0xef,
0x89, 0xf0, 0xb7, 0xff, 0x24, 0x32, 0xd3, 0x14, 0x76, 0xe2, 0x50, 0x5c, 0x2e, 0x75, 0x9d, 0x5c,
0xa9, 0x80, 0x3d, 0x6f, 0xd5, 0x46, 0xd3, 0xdb, 0x42, 0x6e, 0x55, 0x81, 0x88, 0x42, 0x0e, 0x45,
0xfe, 0x9e, 0xe4, 0x41, 0x79, 0xcf, 0x71, 0x0e, 0xed, 0x27, 0xa8, 0x20, 0x05, 0xe9, 0x7a, 0x42,
0x4f, 0x05, 0x10, 0x2e, 0x52, 0x5d, 0x8c, 0x3c, 0x40, 0x49, 0x4c,
0x14, 0x03, 0x02, 0x00, 0x01, 0x01,
0x16, 0x03, 0x02, 0x00, 0x24, 0x8b, 0x12, 0x24, 0x06, 0xaa, 0x92, 0x74, 0xa1, 0x46, 0x6f, 0xc1,
0x4e, 0x4a, 0xf7, 0x16, 0xdd, 0xd6, 0xe1, 0x2d, 0x37, 0x0b, 0x44, 0xba, 0xeb, 0xc4, 0x6c, 0xc7,
0xa0, 0xb7, 0x8c, 0x9d, 0x24, 0xbd, 0x99, 0x33, 0x1e,
},
[]byte{
0x14, 0x03, 0x02, 0x00, 0x01,
0x01,
0x16, 0x03, 0x02, 0x00, 0x24,
0x6e, 0xd1, 0x3e, 0x49, 0x68, 0xc1, 0xa0, 0xa5, 0xb7, 0xaf, 0xb0, 0x7c, 0x52, 0x1f, 0xf7, 0x2d,
0x51, 0xf3, 0xa5, 0xb6, 0xf6, 0xd4, 0x18, 0x4b, 0x7a, 0xd5, 0x24, 0x1d, 0x09, 0xb6, 0x41, 0x1c,
0x1c, 0x98, 0xf6, 0x90,
0x17, 0x03, 0x02, 0x00, 0x21,
0x50, 0xb7, 0x92, 0x4f, 0xd8, 0x78, 0x29, 0xa2, 0xe7, 0xa5, 0xa6, 0xbd, 0x1a, 0x0c, 0xf1, 0x5a,
0x6e, 0x6c, 0xeb, 0x38, 0x99, 0x9b, 0x3c, 0xfd, 0xee, 0x53, 0xe8, 0x4d, 0x7b, 0xa5, 0x5b, 0x00,
0xb9,
0x15, 0x03, 0x02, 0x00, 0x16,
0xc7, 0xc9, 0x5a, 0x72, 0xfb, 0x02, 0xa5, 0x93, 0xdd, 0x69, 0xeb, 0x30, 0x68, 0x5e, 0xbc, 0xe0,
0x44, 0xb9, 0x59, 0x33, 0x68, 0xa9,
},
} }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
// A recordProcessor accepts reassembled records, decrypts and verifies them
// and routes them either to the handshake processor, to up to the application.
// It also accepts requests from the application for the current connection
// state, or for a notification when the state changes.
import (
"container/list"
"crypto/subtle"
"hash"
)
// getConnectionState is a request from the application to get the current
// ConnectionState.
type getConnectionState struct {
reply chan<- ConnectionState
}
// waitConnectionState is a request from the application to be notified when
// the connection state changes.
type waitConnectionState struct {
reply chan<- ConnectionState
}
// connectionStateChange is a message from the handshake processor that the
// connection state has changed.
type connectionStateChange struct {
connState ConnectionState
}
// changeCipherSpec is a message send to the handshake processor to signal that
// the peer is switching ciphers.
type changeCipherSpec struct{}
// newCipherSpec is a message from the handshake processor that future
// records should be processed with a new cipher and MAC function.
type newCipherSpec struct {
encrypt encryptor
mac hash.Hash
}
type recordProcessor struct {
decrypt encryptor
mac hash.Hash
seqNum uint64
handshakeBuf []byte
appDataChan chan<- []byte
requestChan <-chan interface{}
controlChan <-chan interface{}
recordChan <-chan *record
handshakeChan chan<- interface{}
// recordRead is nil when we don't wish to read any more.
recordRead <-chan *record
// appDataSend is nil when len(appData) == 0.
appDataSend chan<- []byte
// appData contains any application data queued for upstream.
appData []byte
// A list of channels waiting for connState to change.
waitQueue *list.List
connState ConnectionState
shutdown bool
header [13]byte
}
// drainRequestChannel processes messages from the request channel until it's closed.
func drainRequestChannel(requestChan <-chan interface{}, c ConnectionState) {
for v := range requestChan {
if closed(requestChan) {
return
}
switch r := v.(type) {
case getConnectionState:
r.reply <- c
case waitConnectionState:
r.reply <- c
}
}
}
func (p *recordProcessor) loop(appDataChan chan<- []byte, requestChan <-chan interface{}, controlChan <-chan interface{}, recordChan <-chan *record, handshakeChan chan<- interface{}) {
noop := nop{}
p.decrypt = noop
p.mac = noop
p.waitQueue = list.New()
p.appDataChan = appDataChan
p.requestChan = requestChan
p.controlChan = controlChan
p.recordChan = recordChan
p.handshakeChan = handshakeChan
p.recordRead = recordChan
for !p.shutdown {
select {
case p.appDataSend <- p.appData:
p.appData = nil
p.appDataSend = nil
p.recordRead = p.recordChan
case c := <-controlChan:
p.processControlMsg(c)
case r := <-requestChan:
p.processRequestMsg(r)
case r := <-p.recordRead:
p.processRecord(r)
}
}
p.wakeWaiters()
go drainRequestChannel(p.requestChan, p.connState)
go func() {
for _ = range controlChan {
}
}()
close(handshakeChan)
if len(p.appData) > 0 {
appDataChan <- p.appData
}
close(appDataChan)
}
func (p *recordProcessor) processRequestMsg(requestMsg interface{}) {
if closed(p.requestChan) {
p.shutdown = true
return
}
switch r := requestMsg.(type) {
case getConnectionState:
r.reply <- p.connState
case waitConnectionState:
if p.connState.HandshakeComplete {
r.reply <- p.connState
}
p.waitQueue.PushBack(r.reply)
}
}
func (p *recordProcessor) processControlMsg(msg interface{}) {
connState, ok := msg.(ConnectionState)
if !ok || closed(p.controlChan) {
p.shutdown = true
return
}
p.connState = connState
p.wakeWaiters()
}
func (p *recordProcessor) wakeWaiters() {
for i := p.waitQueue.Front(); i != nil; i = i.Next() {
i.Value.(chan<- ConnectionState) <- p.connState
}
p.waitQueue.Init()
}
func (p *recordProcessor) processRecord(r *record) {
if closed(p.recordChan) {
p.shutdown = true
return
}
p.decrypt.XORKeyStream(r.payload)
if len(r.payload) < p.mac.Size() {
p.error(alertBadRecordMAC)
return
}
fillMACHeader(&p.header, p.seqNum, len(r.payload)-p.mac.Size(), r)
p.seqNum++
p.mac.Reset()
p.mac.Write(p.header[0:13])
p.mac.Write(r.payload[0 : len(r.payload)-p.mac.Size()])
macBytes := p.mac.Sum()
if subtle.ConstantTimeCompare(macBytes, r.payload[len(r.payload)-p.mac.Size():]) != 1 {
p.error(alertBadRecordMAC)
return
}
switch r.contentType {
case recordTypeHandshake:
p.processHandshakeRecord(r.payload[0 : len(r.payload)-p.mac.Size()])
case recordTypeChangeCipherSpec:
if len(r.payload) != 1 || r.payload[0] != 1 {
p.error(alertUnexpectedMessage)
return
}
p.handshakeChan <- changeCipherSpec{}
newSpec, ok := (<-p.controlChan).(*newCipherSpec)
if !ok {
p.connState.Error = alertUnexpectedMessage
p.shutdown = true
return
}
p.decrypt = newSpec.encrypt
p.mac = newSpec.mac
p.seqNum = 0
case recordTypeApplicationData:
if p.connState.HandshakeComplete == false {
p.error(alertUnexpectedMessage)
return
}
p.recordRead = nil
p.appData = r.payload[0 : len(r.payload)-p.mac.Size()]
p.appDataSend = p.appDataChan
default:
p.error(alertUnexpectedMessage)
return
}
}
func (p *recordProcessor) processHandshakeRecord(data []byte) {
if p.handshakeBuf == nil {
p.handshakeBuf = data
} else {
if len(p.handshakeBuf) > maxHandshakeMsg {
p.error(alertInternalError)
return
}
newBuf := make([]byte, len(p.handshakeBuf)+len(data))
copy(newBuf, p.handshakeBuf)
copy(newBuf[len(p.handshakeBuf):], data)
p.handshakeBuf = newBuf
}
for len(p.handshakeBuf) >= 4 {
handshakeLen := int(p.handshakeBuf[1])<<16 |
int(p.handshakeBuf[2])<<8 |
int(p.handshakeBuf[3])
if handshakeLen+4 > len(p.handshakeBuf) {
break
}
bytes := p.handshakeBuf[0 : handshakeLen+4]
p.handshakeBuf = p.handshakeBuf[handshakeLen+4:]
if bytes[0] == typeFinished {
// Special case because Finished is synchronous: the
// handshake handler has to tell us if it's ok to start
// forwarding application data.
m := new(finishedMsg)
if !m.unmarshal(bytes) {
p.error(alertUnexpectedMessage)
}
p.handshakeChan <- m
var ok bool
p.connState, ok = (<-p.controlChan).(ConnectionState)
if !ok || p.connState.Error != 0 {
p.shutdown = true
return
}
} else {
msg, ok := parseHandshakeMsg(bytes)
if !ok {
p.error(alertUnexpectedMessage)
return
}
p.handshakeChan <- msg
}
}
}
func (p *recordProcessor) error(err alertType) {
close(p.handshakeChan)
p.connState.Error = err
p.wakeWaiters()
p.shutdown = true
}
func parseHandshakeMsg(data []byte) (interface{}, bool) {
var m interface {
unmarshal([]byte) bool
}
switch data[0] {
case typeClientHello:
m = new(clientHelloMsg)
case typeServerHello:
m = new(serverHelloMsg)
case typeCertificate:
m = new(certificateMsg)
case typeServerHelloDone:
m = new(serverHelloDoneMsg)
case typeClientKeyExchange:
m = new(clientKeyExchangeMsg)
case typeNextProtocol:
m = new(nextProtoMsg)
default:
return nil, false
}
ok := m.unmarshal(data)
return m, ok
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"encoding/hex"
"testing"
"testing/script"
)
func setup() (appDataChan chan []byte, requestChan chan interface{}, controlChan chan interface{}, recordChan chan *record, handshakeChan chan interface{}) {
rp := new(recordProcessor)
appDataChan = make(chan []byte)
requestChan = make(chan interface{})
controlChan = make(chan interface{})
recordChan = make(chan *record)
handshakeChan = make(chan interface{})
go rp.loop(appDataChan, requestChan, controlChan, recordChan, handshakeChan)
return
}
func fromHex(s string) []byte {
b, _ := hex.DecodeString(s)
return b
}
func TestNullConnectionState(t *testing.T) {
_, requestChan, controlChan, recordChan, _ := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test a simple request for the connection state.
replyChan := make(chan ConnectionState)
sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}})
getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0, ""}})
err := script.Perform(0, []*script.Event{sendReq, getReply})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestWaitConnectionState(t *testing.T) {
_, requestChan, controlChan, recordChan, _ := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test that waitConnectionState doesn't get a reply until the connection state changes.
replyChan := make(chan ConnectionState)
sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}})
replyChan2 := make(chan ConnectionState)
sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}})
getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0, ""}})
sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1, ""}})
getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1, ""}})
err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestHandshakeAssembly(t *testing.T) {
_, requestChan, controlChan, recordChan, handshakeChan := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test the reassembly of a fragmented handshake message.
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("10000003")}})
send2 := script.NewEvent("send 2", []*script.Event{send1}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("0001")}})
send3 := script.NewEvent("send 3", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("42")}})
recvMsg := script.NewEvent("recv", []*script.Event{send3}, script.Recv{handshakeChan, &clientKeyExchangeMsg{fromHex("10000003000142"), fromHex("42")}})
err := script.Perform(0, []*script.Event{send1, send2, send3, recvMsg})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestEarlyApplicationData(t *testing.T) {
_, requestChan, controlChan, recordChan, handshakeChan := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test that applicaton data received before the handshake has completed results in an error.
send := script.NewEvent("send", nil, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("")}})
recv := script.NewEvent("recv", []*script.Event{send}, script.Closed{handshakeChan})
err := script.Perform(0, []*script.Event{send, recv})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestApplicationData(t *testing.T) {
appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test that the application data is forwarded after a successful Finished message.
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}})
recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}})
send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0, ""}})
send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}})
recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}})
err := script.Perform(0, []*script.Event{send1, recv1, send2, send3, recv2})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestInvalidChangeCipherSpec(t *testing.T) {
appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}})
recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}})
send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42, ""}})
close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan})
close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan})
err := script.Perform(0, []*script.Event{send1, recv1, send2, close, close2})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
// The record reader handles reading from the connection and reassembling TLS
// record structures. It loops forever doing this and writes the TLS records to
// it's outbound channel. On error, it closes its outbound channel.
import (
"io"
"bufio"
)
// recordReader loops, reading TLS records from source and writing them to the
// given channel. The channel is closed on EOF or on error.
func recordReader(c chan<- *record, source io.Reader) {
defer close(c)
buf := bufio.NewReader(source)
for {
var header [5]byte
n, _ := buf.Read(&header)
if n != 5 {
return
}
recordLength := int(header[3])<<8 | int(header[4])
if recordLength > maxTLSCiphertext {
return
}
payload := make([]byte, recordLength)
n, _ = buf.Read(payload)
if n != recordLength {
return
}
c <- &record{recordType(header[0]), header[1], header[2], payload}
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"testing"
"testing/iotest"
)
func matchRecord(r1, r2 *record) bool {
if (r1 == nil) != (r2 == nil) {
return false
}
if r1 == nil {
return true
}
return r1.contentType == r2.contentType &&
r1.major == r2.major &&
r1.minor == r2.minor &&
bytes.Compare(r1.payload, r2.payload) == 0
}
type recordReaderTest struct {
in []byte
out []*record
}
var recordReaderTests = []recordReaderTest{
recordReaderTest{nil, nil},
recordReaderTest{fromHex("01"), nil},
recordReaderTest{fromHex("0102"), nil},
recordReaderTest{fromHex("010203"), nil},
recordReaderTest{fromHex("01020300"), nil},
recordReaderTest{fromHex("0102030000"), []*record{&record{1, 2, 3, nil}}},
recordReaderTest{fromHex("01020300000102030000"), []*record{&record{1, 2, 3, nil}, &record{1, 2, 3, nil}}},
recordReaderTest{fromHex("0102030001fe0102030002feff"), []*record{&record{1, 2, 3, []byte{0xfe}}, &record{1, 2, 3, []byte{0xfe, 0xff}}}},
recordReaderTest{fromHex("010203000001020300"), []*record{&record{1, 2, 3, nil}}},
}
func TestRecordReader(t *testing.T) {
for i, test := range recordReaderTests {
buf := bytes.NewBuffer(test.in)
c := make(chan *record)
go recordReader(c, buf)
matchRecordReaderOutput(t, i, test, c)
buf = bytes.NewBuffer(test.in)
buf2 := iotest.OneByteReader(buf)
c = make(chan *record)
go recordReader(c, buf2)
matchRecordReaderOutput(t, i*2, test, c)
}
}
func matchRecordReaderOutput(t *testing.T, i int, test recordReaderTest, c <-chan *record) {
for j, r1 := range test.out {
r2 := <-c
if r2 == nil {
t.Errorf("#%d truncated after %d values", i, j)
break
}
if !matchRecord(r1, r2) {
t.Errorf("#%d (%d) got:%#v want:%#v", i, j, r2, r1)
}
}
<-c
if !closed(c) {
t.Errorf("#%d: channel didn't close", i)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"fmt"
"hash"
"io"
)
// writerEnableApplicationData is a message which instructs recordWriter to
// start reading and transmitting data from the application data channel.
type writerEnableApplicationData struct{}
// writerChangeCipherSpec updates the encryption and MAC functions and resets
// the sequence count.
type writerChangeCipherSpec struct {
encryptor encryptor
mac hash.Hash
}
// writerSetVersion sets the version number bytes that we included in the
// record header for future records.
type writerSetVersion struct {
major, minor uint8
}
// A recordWriter accepts messages from the handshake processor and
// application data. It writes them to the outgoing connection and blocks on
// writing. It doesn't read from the application data channel until the
// handshake processor has signaled that the handshake is complete.
type recordWriter struct {
writer io.Writer
encryptor encryptor
mac hash.Hash
seqNum uint64
major, minor uint8
shutdown bool
appChan <-chan []byte
controlChan <-chan interface{}
header [13]byte
}
func (w *recordWriter) loop(writer io.Writer, appChan <-chan []byte, controlChan <-chan interface{}) {
w.writer = writer
w.encryptor = nop{}
w.mac = nop{}
w.appChan = appChan
w.controlChan = controlChan
for !w.shutdown {
msg := <-controlChan
if _, ok := msg.(writerEnableApplicationData); ok {
break
}
w.processControlMessage(msg)
}
for !w.shutdown {
// Always process control messages first.
if controlMsg, ok := <-controlChan; ok {
w.processControlMessage(controlMsg)
continue
}
select {
case controlMsg := <-controlChan:
w.processControlMessage(controlMsg)
case appMsg := <-appChan:
w.processAppMessage(appMsg)
}
}
if !closed(appChan) {
go func() {
for _ = range appChan {
}
}()
}
if !closed(controlChan) {
go func() {
for _ = range controlChan {
}
}()
}
}
// fillMACHeader generates a MAC header. See RFC 4346, section 6.2.3.1.
func fillMACHeader(header *[13]byte, seqNum uint64, length int, r *record) {
header[0] = uint8(seqNum >> 56)
header[1] = uint8(seqNum >> 48)
header[2] = uint8(seqNum >> 40)
header[3] = uint8(seqNum >> 32)
header[4] = uint8(seqNum >> 24)
header[5] = uint8(seqNum >> 16)
header[6] = uint8(seqNum >> 8)
header[7] = uint8(seqNum)
header[8] = uint8(r.contentType)
header[9] = r.major
header[10] = r.minor
header[11] = uint8(length >> 8)
header[12] = uint8(length)
}
func (w *recordWriter) writeRecord(r *record) {
w.mac.Reset()
fillMACHeader(&w.header, w.seqNum, len(r.payload), r)
w.mac.Write(w.header[0:13])
w.mac.Write(r.payload)
macBytes := w.mac.Sum()
w.encryptor.XORKeyStream(r.payload)
w.encryptor.XORKeyStream(macBytes)
length := len(r.payload) + len(macBytes)
w.header[11] = uint8(length >> 8)
w.header[12] = uint8(length)
w.writer.Write(w.header[8:13])
w.writer.Write(r.payload)
w.writer.Write(macBytes)
w.seqNum++
}
func (w *recordWriter) processControlMessage(controlMsg interface{}) {
if controlMsg == nil {
w.shutdown = true
return
}
switch msg := controlMsg.(type) {
case writerChangeCipherSpec:
w.writeRecord(&record{recordTypeChangeCipherSpec, w.major, w.minor, []byte{0x01}})
w.encryptor = msg.encryptor
w.mac = msg.mac
w.seqNum = 0
case writerSetVersion:
w.major = msg.major
w.minor = msg.minor
case alert:
w.writeRecord(&record{recordTypeAlert, w.major, w.minor, []byte{byte(msg.level), byte(msg.error)}})
case handshakeMessage:
// TODO(agl): marshal may return a slice too large for a single record.
w.writeRecord(&record{recordTypeHandshake, w.major, w.minor, msg.marshal()})
default:
fmt.Printf("processControlMessage: unknown %#v\n", msg)
}
}
func (w *recordWriter) processAppMessage(appMsg []byte) {
if closed(w.appChan) {
w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, []byte{byte(alertCloseNotify)}})
w.shutdown = true
return
}
var done int
for done < len(appMsg) {
todo := len(appMsg)
if todo > maxTLSPlaintext {
todo = maxTLSPlaintext
}
w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, appMsg[done : done+todo]})
done += todo
}
}
...@@ -6,158 +6,16 @@ ...@@ -6,158 +6,16 @@
package tls package tls
import ( import (
"io"
"os" "os"
"net" "net"
"time"
) )
// A Conn represents a secure connection.
type Conn struct {
net.Conn
writeChan chan<- []byte
readChan <-chan []byte
requestChan chan<- interface{}
readBuf []byte
eof bool
readTimeout, writeTimeout int64
}
func timeout(c chan<- bool, nsecs int64) {
time.Sleep(nsecs)
c <- true
}
func (tls *Conn) Read(p []byte) (int, os.Error) {
if len(tls.readBuf) == 0 {
if tls.eof {
return 0, os.EOF
}
var timeoutChan chan bool
if tls.readTimeout > 0 {
timeoutChan = make(chan bool)
go timeout(timeoutChan, tls.readTimeout)
}
select {
case b := <-tls.readChan:
tls.readBuf = b
case <-timeoutChan:
return 0, os.EAGAIN
}
// TLS distinguishes between orderly closes and truncations. An
// orderly close is represented by a zero length slice.
if closed(tls.readChan) {
return 0, io.ErrUnexpectedEOF
}
if len(tls.readBuf) == 0 {
tls.eof = true
return 0, os.EOF
}
}
n := copy(p, tls.readBuf)
tls.readBuf = tls.readBuf[n:]
return n, nil
}
func (tls *Conn) Write(p []byte) (int, os.Error) {
if tls.eof || closed(tls.readChan) {
return 0, os.EOF
}
var timeoutChan chan bool
if tls.writeTimeout > 0 {
timeoutChan = make(chan bool)
go timeout(timeoutChan, tls.writeTimeout)
}
select {
case tls.writeChan <- p:
case <-timeoutChan:
return 0, os.EAGAIN
}
return len(p), nil
}
func (tls *Conn) Close() os.Error {
close(tls.writeChan)
close(tls.requestChan)
tls.eof = true
return nil
}
func (tls *Conn) SetTimeout(nsec int64) os.Error {
tls.readTimeout = nsec
tls.writeTimeout = nsec
return nil
}
func (tls *Conn) SetReadTimeout(nsec int64) os.Error {
tls.readTimeout = nsec
return nil
}
func (tls *Conn) SetWriteTimeout(nsec int64) os.Error {
tls.writeTimeout = nsec
return nil
}
func (tls *Conn) GetConnectionState() ConnectionState {
replyChan := make(chan ConnectionState)
tls.requestChan <- getConnectionState{replyChan}
return <-replyChan
}
func (tls *Conn) WaitConnectionState() ConnectionState {
replyChan := make(chan ConnectionState)
tls.requestChan <- waitConnectionState{replyChan}
return <-replyChan
}
type handshaker interface {
loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config)
}
// Server establishes a secure connection over the given connection and acts
// as a TLS server.
func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn {
if config == nil {
config = defaultConfig()
}
tls := new(Conn)
tls.Conn = conn
writeChan := make(chan []byte)
readChan := make(chan []byte)
requestChan := make(chan interface{})
tls.writeChan = writeChan
tls.readChan = readChan
tls.requestChan = requestChan
handshakeWriterChan := make(chan interface{})
processorHandshakeChan := make(chan interface{})
handshakeProcessorChan := make(chan interface{})
readerProcessorChan := make(chan *record)
go new(recordWriter).loop(conn, writeChan, handshakeWriterChan)
go recordReader(readerProcessorChan, conn)
go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan)
go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config)
return tls
}
func Server(conn net.Conn, config *Config) *Conn { func Server(conn net.Conn, config *Config) *Conn {
return startTLSGoroutines(conn, new(serverHandshake), config) return &Conn{conn: conn, config: config}
} }
func Client(conn net.Conn, config *Config) *Conn { func Client(conn net.Conn, config *Config) *Conn {
return startTLSGoroutines(conn, new(clientHandshake), config) return &Conn{conn: conn, config: config, isClient: true}
} }
type Listener struct { type Listener struct {
...@@ -180,22 +38,24 @@ func (l *Listener) Addr() net.Addr { return l.listener.Addr() } ...@@ -180,22 +38,24 @@ func (l *Listener) Addr() net.Addr { return l.listener.Addr() }
// NewListener creates a Listener which accepts connections from an inner // NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server. // Listener and wraps each connection with Server.
// The configuration config must be non-nil and must have
// at least one certificate.
func NewListener(listener net.Listener, config *Config) (l *Listener) { func NewListener(listener net.Listener, config *Config) (l *Listener) {
if config == nil {
config = defaultConfig()
}
l = new(Listener) l = new(Listener)
l.listener = listener l.listener = listener
l.config = config l.config = config
return return
} }
func Listen(network, laddr string) (net.Listener, os.Error) { func Listen(network, laddr string, config *Config) (net.Listener, os.Error) {
if config == nil || len(config.Certificates) == 0 {
return nil, os.NewError("tls.Listen: no certificates in configuration")
}
l, err := net.Listen(network, laddr) l, err := net.Listen(network, laddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewListener(l, nil), nil return NewListener(l, config), nil
} }
func Dial(network, laddr, raddr string) (net.Conn, os.Error) { func Dial(network, laddr, raddr string) (net.Conn, os.Error) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment