Commit 554ac036 authored by Adam Langley's avatar Adam Langley

crypto: allocate less.

The code in hash functions themselves could write directly into the
output buffer for a savings of about 50ns. But it's a little ugly so I
wasted a copy.

R=bradfitz
CC=golang-dev
https://golang.org/cl/5440111
parent bf59f081
...@@ -49,14 +49,13 @@ func (h *hmac) tmpPad(xor byte) { ...@@ -49,14 +49,13 @@ func (h *hmac) tmpPad(xor byte) {
} }
func (h *hmac) Sum(in []byte) []byte { func (h *hmac) Sum(in []byte) []byte {
sum := h.inner.Sum(nil) origLen := len(in)
in = h.inner.Sum(in)
h.tmpPad(0x5c) h.tmpPad(0x5c)
for i, b := range sum { copy(h.tmp[padSize:], in[origLen:])
h.tmp[padSize+i] = b
}
h.outer.Reset() h.outer.Reset()
h.outer.Write(h.tmp) h.outer.Write(h.tmp)
return h.outer.Sum(in) return h.outer.Sum(in[:origLen])
} }
func (h *hmac) Write(p []byte) (n int, err error) { func (h *hmac) Write(p []byte) (n int, err error) {
......
...@@ -79,8 +79,7 @@ func (d *digest) Write(p []byte) (nn int, err error) { ...@@ -79,8 +79,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
func (d0 *digest) Sum(in []byte) []byte { func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing. // Make a copy of d0 so that caller can keep writing and summing.
d := new(digest) d := *d0
*d = *d0
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64. // Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
len := d.len len := d.len
...@@ -103,11 +102,13 @@ func (d0 *digest) Sum(in []byte) []byte { ...@@ -103,11 +102,13 @@ func (d0 *digest) Sum(in []byte) []byte {
panic("d.nx != 0") panic("d.nx != 0")
} }
for _, s := range d.s { var digest [Size]byte
in = append(in, byte(s>>0)) for i, s := range d.s {
in = append(in, byte(s>>8)) digest[i*4] = byte(s)
in = append(in, byte(s>>16)) digest[i*4+1] = byte(s >> 8)
in = append(in, byte(s>>24)) digest[i*4+2] = byte(s >> 16)
digest[i*4+3] = byte(s >> 24)
} }
return in
return append(in, digest[:]...)
} }
...@@ -26,6 +26,7 @@ var zero [1]byte ...@@ -26,6 +26,7 @@ var zero [1]byte
// 4880, section 3.7.1.2) using the given hash, input passphrase and salt. // 4880, section 3.7.1.2) using the given hash, input passphrase and salt.
func Salted(out []byte, h hash.Hash, in []byte, salt []byte) { func Salted(out []byte, h hash.Hash, in []byte, salt []byte) {
done := 0 done := 0
var digest []byte
for i := 0; done < len(out); i++ { for i := 0; done < len(out); i++ {
h.Reset() h.Reset()
...@@ -34,7 +35,8 @@ func Salted(out []byte, h hash.Hash, in []byte, salt []byte) { ...@@ -34,7 +35,8 @@ func Salted(out []byte, h hash.Hash, in []byte, salt []byte) {
} }
h.Write(salt) h.Write(salt)
h.Write(in) h.Write(in)
n := copy(out[done:], h.Sum(nil)) digest = h.Sum(digest[:0])
n := copy(out[done:], digest)
done += n done += n
} }
} }
...@@ -52,6 +54,7 @@ func Iterated(out []byte, h hash.Hash, in []byte, salt []byte, count int) { ...@@ -52,6 +54,7 @@ func Iterated(out []byte, h hash.Hash, in []byte, salt []byte, count int) {
} }
done := 0 done := 0
var digest []byte
for i := 0; done < len(out); i++ { for i := 0; done < len(out); i++ {
h.Reset() h.Reset()
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
...@@ -68,7 +71,8 @@ func Iterated(out []byte, h hash.Hash, in []byte, salt []byte, count int) { ...@@ -68,7 +71,8 @@ func Iterated(out []byte, h hash.Hash, in []byte, salt []byte, count int) {
written += len(combined) written += len(combined)
} }
} }
n := copy(out[done:], h.Sum(nil)) digest = h.Sum(digest[:0])
n := copy(out[done:], digest)
done += n done += n
} }
} }
......
...@@ -83,8 +83,7 @@ func (d *digest) Write(p []byte) (nn int, err error) { ...@@ -83,8 +83,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
func (d0 *digest) Sum(in []byte) []byte { func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing. // Make a copy of d0 so that caller can keep writing and summing.
d := new(digest) d := *d0
*d = *d0
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64. // Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
tc := d.tc tc := d.tc
...@@ -107,11 +106,13 @@ func (d0 *digest) Sum(in []byte) []byte { ...@@ -107,11 +106,13 @@ func (d0 *digest) Sum(in []byte) []byte {
panic("d.nx != 0") panic("d.nx != 0")
} }
for _, s := range d.s { var digest [Size]byte
in = append(in, byte(s)) for i, s := range d.s {
in = append(in, byte(s>>8)) digest[i*4] = byte(s)
in = append(in, byte(s>>16)) digest[i*4+1] = byte(s >> 8)
in = append(in, byte(s>>24)) digest[i*4+2] = byte(s >> 16)
digest[i*4+3] = byte(s >> 24)
} }
return in
return append(in, digest[:]...)
} }
...@@ -189,12 +189,13 @@ func incCounter(c *[4]byte) { ...@@ -189,12 +189,13 @@ func incCounter(c *[4]byte) {
// specified in PKCS#1 v2.1. // specified in PKCS#1 v2.1.
func mgf1XOR(out []byte, hash hash.Hash, seed []byte) { func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
var counter [4]byte var counter [4]byte
var digest []byte
done := 0 done := 0
for done < len(out) { for done < len(out) {
hash.Write(seed) hash.Write(seed)
hash.Write(counter[0:4]) hash.Write(counter[0:4])
digest := hash.Sum(nil) digest = hash.Sum(digest[:0])
hash.Reset() hash.Reset()
for i := 0; i < len(digest) && done < len(out); i++ { for i := 0; i < len(digest) && done < len(out); i++ {
......
...@@ -81,8 +81,7 @@ func (d *digest) Write(p []byte) (nn int, err error) { ...@@ -81,8 +81,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
func (d0 *digest) Sum(in []byte) []byte { func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing. // Make a copy of d0 so that caller can keep writing and summing.
d := new(digest) d := *d0
*d = *d0
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64. // Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
len := d.len len := d.len
...@@ -105,11 +104,13 @@ func (d0 *digest) Sum(in []byte) []byte { ...@@ -105,11 +104,13 @@ func (d0 *digest) Sum(in []byte) []byte {
panic("d.nx != 0") panic("d.nx != 0")
} }
for _, s := range d.h { var digest [Size]byte
in = append(in, byte(s>>24)) for i, s := range d.h {
in = append(in, byte(s>>16)) digest[i*4] = byte(s >> 24)
in = append(in, byte(s>>8)) digest[i*4+1] = byte(s >> 16)
in = append(in, byte(s)) digest[i*4+2] = byte(s >> 8)
digest[i*4+3] = byte(s)
} }
return in
return append(in, digest[:]...)
} }
...@@ -125,8 +125,7 @@ func (d *digest) Write(p []byte) (nn int, err error) { ...@@ -125,8 +125,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
func (d0 *digest) Sum(in []byte) []byte { func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing. // Make a copy of d0 so that caller can keep writing and summing.
d := new(digest) d := *d0
*d = *d0
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64. // Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
len := d.len len := d.len
...@@ -150,14 +149,19 @@ func (d0 *digest) Sum(in []byte) []byte { ...@@ -150,14 +149,19 @@ func (d0 *digest) Sum(in []byte) []byte {
} }
h := d.h[:] h := d.h[:]
size := Size
if d.is224 { if d.is224 {
h = d.h[:7] h = d.h[:7]
size = Size224
} }
for _, s := range h {
in = append(in, byte(s>>24)) var digest [Size]byte
in = append(in, byte(s>>16)) for i, s := range h {
in = append(in, byte(s>>8)) digest[i*4] = byte(s >> 24)
in = append(in, byte(s)) digest[i*4+1] = byte(s >> 16)
digest[i*4+2] = byte(s >> 8)
digest[i*4+3] = byte(s)
} }
return in
return append(in, digest[:size]...)
} }
...@@ -150,18 +150,23 @@ func (d0 *digest) Sum(in []byte) []byte { ...@@ -150,18 +150,23 @@ func (d0 *digest) Sum(in []byte) []byte {
} }
h := d.h[:] h := d.h[:]
size := Size
if d.is384 { if d.is384 {
h = d.h[:6] h = d.h[:6]
size = Size384
} }
for _, s := range h {
in = append(in, byte(s>>56)) var digest [Size]byte
in = append(in, byte(s>>48)) for i, s := range h {
in = append(in, byte(s>>40)) digest[i*8] = byte(s >> 56)
in = append(in, byte(s>>32)) digest[i*8+1] = byte(s >> 48)
in = append(in, byte(s>>24)) digest[i*8+2] = byte(s >> 40)
in = append(in, byte(s>>16)) digest[i*8+3] = byte(s >> 32)
in = append(in, byte(s>>8)) digest[i*8+4] = byte(s >> 24)
in = append(in, byte(s)) digest[i*8+5] = byte(s >> 16)
digest[i*8+6] = byte(s >> 8)
digest[i*8+7] = byte(s)
} }
return in
return append(in, digest[:size]...)
} }
...@@ -96,7 +96,7 @@ func macSHA1(version uint16, key []byte) macFunction { ...@@ -96,7 +96,7 @@ func macSHA1(version uint16, key []byte) macFunction {
type macFunction interface { type macFunction interface {
Size() int Size() int
MAC(seq, data []byte) []byte MAC(digestBuf, seq, data []byte) []byte
} }
// ssl30MAC implements the SSLv3 MAC function, as defined in // ssl30MAC implements the SSLv3 MAC function, as defined in
...@@ -114,7 +114,7 @@ var ssl30Pad1 = [48]byte{0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0 ...@@ -114,7 +114,7 @@ var ssl30Pad1 = [48]byte{0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0
var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c} var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c}
func (s ssl30MAC) MAC(seq, record []byte) []byte { func (s ssl30MAC) MAC(digestBuf, seq, record []byte) []byte {
padLength := 48 padLength := 48
if s.h.Size() == 20 { if s.h.Size() == 20 {
padLength = 40 padLength = 40
...@@ -127,13 +127,13 @@ func (s ssl30MAC) MAC(seq, record []byte) []byte { ...@@ -127,13 +127,13 @@ func (s ssl30MAC) MAC(seq, record []byte) []byte {
s.h.Write(record[:1]) s.h.Write(record[:1])
s.h.Write(record[3:5]) s.h.Write(record[3:5])
s.h.Write(record[recordHeaderLen:]) s.h.Write(record[recordHeaderLen:])
digest := s.h.Sum(nil) digestBuf = s.h.Sum(digestBuf[:0])
s.h.Reset() s.h.Reset()
s.h.Write(s.key) s.h.Write(s.key)
s.h.Write(ssl30Pad2[:padLength]) s.h.Write(ssl30Pad2[:padLength])
s.h.Write(digest) s.h.Write(digestBuf)
return s.h.Sum(nil) return s.h.Sum(digestBuf[:0])
} }
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3. // tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3.
...@@ -145,11 +145,11 @@ func (s tls10MAC) Size() int { ...@@ -145,11 +145,11 @@ func (s tls10MAC) Size() int {
return s.h.Size() return s.h.Size()
} }
func (s tls10MAC) MAC(seq, record []byte) []byte { func (s tls10MAC) MAC(digestBuf, seq, record []byte) []byte {
s.h.Reset() s.h.Reset()
s.h.Write(seq) s.h.Write(seq)
s.h.Write(record) s.h.Write(record)
return s.h.Sum(nil) return s.h.Sum(digestBuf[:0])
} }
func rsaKA() keyAgreement { func rsaKA() keyAgreement {
......
...@@ -118,6 +118,9 @@ type halfConn struct { ...@@ -118,6 +118,9 @@ type halfConn struct {
nextCipher interface{} // next encryption state nextCipher interface{} // next encryption state
nextMac macFunction // next MAC algorithm nextMac macFunction // next MAC algorithm
// used to save allocating a new buffer for each MAC.
inDigestBuf, outDigestBuf []byte
} }
// prepareCipherSpec sets the encryption and MAC states // prepareCipherSpec sets the encryption and MAC states
...@@ -280,12 +283,13 @@ func (hc *halfConn) decrypt(b *block) (bool, alert) { ...@@ -280,12 +283,13 @@ func (hc *halfConn) decrypt(b *block) (bool, alert) {
b.data[4] = byte(n) b.data[4] = byte(n)
b.resize(recordHeaderLen + n) b.resize(recordHeaderLen + n)
remoteMAC := payload[n:] remoteMAC := payload[n:]
localMAC := hc.mac.MAC(hc.seq[0:], b.data) localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data)
hc.incSeq() hc.incSeq()
if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 { if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
return false, alertBadRecordMAC return false, alertBadRecordMAC
} }
hc.inDigestBuf = localMAC
} }
return true, 0 return true, 0
...@@ -312,12 +316,13 @@ func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) { ...@@ -312,12 +316,13 @@ func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
func (hc *halfConn) encrypt(b *block) (bool, alert) { func (hc *halfConn) encrypt(b *block) (bool, alert) {
// mac // mac
if hc.mac != nil { if hc.mac != nil {
mac := hc.mac.MAC(hc.seq[0:], b.data) mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data)
hc.incSeq() hc.incSeq()
n := len(b.data) n := len(b.data)
b.resize(n + len(mac)) b.resize(n + len(mac))
copy(b.data[n:], mac) copy(b.data[n:], mac)
hc.outDigestBuf = mac
} }
payload := b.data[recordHeaderLen:] payload := b.data[recordHeaderLen:]
......
...@@ -231,10 +231,10 @@ func (c *Conn) clientHandshake() error { ...@@ -231,10 +231,10 @@ func (c *Conn) clientHandshake() error {
if cert != nil { if cert != nil {
certVerify := new(certificateVerifyMsg) certVerify := new(certificateVerifyMsg)
var digest [36]byte digest := make([]byte, 0, 36)
copy(digest[0:16], finishedHash.serverMD5.Sum(nil)) digest = finishedHash.serverMD5.Sum(digest)
copy(digest[16:36], finishedHash.serverSHA1.Sum(nil)) digest = finishedHash.serverSHA1.Sum(digest)
signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, crypto.MD5SHA1, digest[0:]) signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, crypto.MD5SHA1, digest)
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
......
...@@ -234,9 +234,9 @@ FindCipherSuite: ...@@ -234,9 +234,9 @@ FindCipherSuite:
return c.sendAlert(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
} }
digest := make([]byte, 36) digest := make([]byte, 0, 36)
copy(digest[0:16], finishedHash.serverMD5.Sum(nil)) digest = finishedHash.serverMD5.Sum(digest)
copy(digest[16:36], finishedHash.serverSHA1.Sum(nil)) digest = finishedHash.serverSHA1.Sum(digest)
err = rsa.VerifyPKCS1v15(pub, crypto.MD5SHA1, digest, certVerify.signature) err = rsa.VerifyPKCS1v15(pub, crypto.MD5SHA1, digest, certVerify.signature)
if err != nil { if err != nil {
c.sendAlert(alertBadCertificate) c.sendAlert(alertBadCertificate)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment