Commit 360ab50a authored by Adam Langley's avatar Adam Langley

crypto/rsa: add support for precomputing CRT values.

This speeds up private key operations by 3.5x (for a 2048-bit
modulus).

R=golang-dev, r, rsc1
CC=golang-dev
https://golang.org/cl/4348053
parent 9f1394d2
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"hash" "hash"
"io" "io"
"os" "os"
"sync"
) )
var bigZero = big.NewInt(0) var bigZero = big.NewInt(0)
...@@ -92,12 +93,16 @@ type PrivateKey struct { ...@@ -92,12 +93,16 @@ type PrivateKey struct {
PublicKey // public part. PublicKey // public part.
D *big.Int // private exponent D *big.Int // private exponent
P, Q *big.Int // prime factors of N P, Q *big.Int // prime factors of N
rwMutex sync.RWMutex // protects the following
dP, dQ *big.Int // D mod (P-1) (or mod Q-1)
pInv *big.Int // p^-1 mod q
} }
// Validate performs basic sanity checks on the key. // Validate performs basic sanity checks on the key.
// It returns nil if the key is valid, or else an os.Error describing a problem. // It returns nil if the key is valid, or else an os.Error describing a problem.
func (priv PrivateKey) Validate() os.Error { func (priv *PrivateKey) Validate() os.Error {
// Check that p and q are prime. Note that this is just a sanity // Check that p and q are prime. Note that this is just a sanity
// check. Since the random witnesses chosen by ProbablyPrime are // check. Since the random witnesses chosen by ProbablyPrime are
// deterministic, given the candidate number, it's easy for an attack // deterministic, given the candidate number, it's easy for an attack
...@@ -321,6 +326,18 @@ func modInverse(a, n *big.Int) (ia *big.Int, ok bool) { ...@@ -321,6 +326,18 @@ func modInverse(a, n *big.Int) (ia *big.Int, ok bool) {
return x, true return x, true
} }
// precompute performs some calculations that speed up private key operations
// in the future.
func (priv *PrivateKey) precompute() {
priv.dP = new(big.Int).Sub(priv.P, bigOne)
priv.dP.Mod(priv.D, priv.dP)
priv.dQ = new(big.Int).Sub(priv.Q, bigOne)
priv.dQ.Mod(priv.D, priv.dQ)
priv.pInv = new(big.Int).ModInverse(priv.P, priv.Q)
}
// decrypt performs an RSA decryption, resulting in a plaintext integer. If a // decrypt performs an RSA decryption, resulting in a plaintext integer. If a
// random source is given, RSA blinding is used. // random source is given, RSA blinding is used.
func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.Error) { func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.Error) {
...@@ -359,7 +376,35 @@ func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.E ...@@ -359,7 +376,35 @@ func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.E
c.Mod(c, priv.N) c.Mod(c, priv.N)
} }
priv.rwMutex.RLock()
if priv.dP == nil && priv.P != nil {
priv.rwMutex.RUnlock()
priv.rwMutex.Lock()
if priv.dP == nil && priv.P != nil {
priv.precompute()
}
priv.rwMutex.Unlock()
priv.rwMutex.RLock()
}
if priv.dP == nil {
m = new(big.Int).Exp(c, priv.D, priv.N) m = new(big.Int).Exp(c, priv.D, priv.N)
} else {
// We have the precalculated values needed for the CRT.
m = new(big.Int).Exp(c, priv.dP, priv.P)
m2 := new(big.Int).Exp(c, priv.dQ, priv.Q)
m2.Sub(m2, m)
if m2.Sign() < 0 {
m2.Add(m2, priv.Q)
}
m2.Mul(m2, priv.pInv)
m2.Mod(m2, priv.Q)
m2.Mul(m2, priv.P)
m.Add(m, m2)
}
priv.rwMutex.RUnlock()
if ir != nil { if ir != nil {
// Unblind. // Unblind.
......
...@@ -13,13 +13,11 @@ import ( ...@@ -13,13 +13,11 @@ import (
) )
func TestKeyGeneration(t *testing.T) { func TestKeyGeneration(t *testing.T) {
random := rand.Reader
size := 1024 size := 1024
if testing.Short() { if testing.Short() {
size = 128 size = 128
} }
priv, err := GenerateKey(random, size) priv, err := GenerateKey(rand.Reader, size)
if err != nil { if err != nil {
t.Errorf("failed to generate key") t.Errorf("failed to generate key")
} }
...@@ -34,7 +32,7 @@ func TestKeyGeneration(t *testing.T) { ...@@ -34,7 +32,7 @@ func TestKeyGeneration(t *testing.T) {
t.Errorf("got:%v, want:%v (%s)", m2, m, priv) t.Errorf("got:%v, want:%v (%s)", m2, m, priv)
} }
m3, err := decrypt(random, priv, c) m3, err := decrypt(rand.Reader, priv, c)
if err != nil { if err != nil {
t.Errorf("error while decrypting (blind): %s", err) t.Errorf("error while decrypting (blind): %s", err)
} }
...@@ -43,6 +41,36 @@ func TestKeyGeneration(t *testing.T) { ...@@ -43,6 +41,36 @@ func TestKeyGeneration(t *testing.T) {
} }
} }
func fromBase10(base10 string) *big.Int {
i := new(big.Int)
i.SetString(base10, 10)
return i
}
func BenchmarkRSA2048Decrypt(b *testing.B) {
b.StopTimer()
priv := &PrivateKey{
PublicKey: PublicKey{
N: fromBase10("14314132931241006650998084889274020608918049032671858325988396851334124245188214251956198731333464217832226406088020736932173064754214329009979944037640912127943488972644697423190955557435910767690712778463524983667852819010259499695177313115447116110358524558307947613422897787329221478860907963827160223559690523660574329011927531289655711860504630573766609239332569210831325633840174683944553667352219670930408593321661375473885147973879086994006440025257225431977751512374815915392249179976902953721486040787792801849818254465486633791826766873076617116727073077821584676715609985777563958286637185868165868520557"),
E: 3,
},
D: fromBase10("9542755287494004433998723259516013739278699355114572217325597900889416163458809501304132487555642811888150937392013824621448709836142886006653296025093941418628992648429798282127303704957273845127141852309016655778568546006839666463451542076964744073572349705538631742281931858219480985907271975884773482372966847639853897890615456605598071088189838676728836833012254065983259638538107719766738032720239892094196108713378822882383694456030043492571063441943847195939549773271694647657549658603365629458610273821292232646334717612674519997533901052790334279661754176490593041941863932308687197618671528035670452762731"),
P: fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"),
Q: fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"),
dP: fromBase10("87268836788664481617847742404051836863722219423378226768631448578943872685024917537462119749532582810405513868036231739891184315953381607972502029708143024675916069201337579756507382918142663989251192478689233435812947796979136184463322021762944620586062609951431098934759487528690647618558712630030012716955"),
dQ: fromBase10("72899297073656969051716511768351648616192819692191099774160625992013578003650485325502292431707305715670772966359195693648553210602632895460727878519493400703035165776845049514109078463093675699795894225213735151064504234004099679220621507497799177570231932105698976933783237347664793620527883246410113534019"),
pInv: fromBase10("74869409553139788560900845468611147033712996668881056834763135832685363742570238895177002569942885113085732953539155658733820506625547963252830054212438299203610450637505048191657603373418647673681301519272938658040214027296599301464801590350649336869828810772124696732917293401353635425510209977859621865087"),
}
c := fromBase10("1000")
b.StartTimer()
for i := 0; i < b.N; i++ {
decrypt(nil, priv, c)
}
}
type testEncryptOAEPMessage struct { type testEncryptOAEPMessage struct {
in []byte in []byte
seed []byte seed []byte
...@@ -85,10 +113,12 @@ func TestDecryptOAEP(t *testing.T) { ...@@ -85,10 +113,12 @@ func TestDecryptOAEP(t *testing.T) {
for i, test := range testEncryptOAEPData { for i, test := range testEncryptOAEPData {
n.SetString(test.modulus, 16) n.SetString(test.modulus, 16)
d.SetString(test.d, 16) d.SetString(test.d, 16)
private := PrivateKey{PublicKey{n, test.e}, d, nil, nil} private := new(PrivateKey)
private.PublicKey = PublicKey{n, test.e}
private.D = d
for j, message := range test.msgs { for j, message := range test.msgs {
out, err := DecryptOAEP(sha1, nil, &private, message.out, nil) out, err := DecryptOAEP(sha1, nil, private, message.out, nil)
if err != nil { if err != nil {
t.Errorf("#%d,%d error: %s", i, j, err) t.Errorf("#%d,%d error: %s", i, j, err)
} else if bytes.Compare(out, message.in) != 0 { } else if bytes.Compare(out, message.in) != 0 {
...@@ -96,7 +126,7 @@ func TestDecryptOAEP(t *testing.T) { ...@@ -96,7 +126,7 @@ func TestDecryptOAEP(t *testing.T) {
} }
// Decrypt with blinding. // Decrypt with blinding.
out, err = DecryptOAEP(sha1, random, &private, message.out, nil) out, err = DecryptOAEP(sha1, random, private, message.out, nil)
if err != nil { if err != nil {
t.Errorf("#%d,%d (blind) error: %s", i, j, err) t.Errorf("#%d,%d (blind) error: %s", i, j, err)
} else if bytes.Compare(out, message.in) != 0 { } else if bytes.Compare(out, message.in) != 0 {
......
...@@ -54,20 +54,21 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { ...@@ -54,20 +54,21 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) {
return return
} }
key = &rsa.PrivateKey{ key = new(rsa.PrivateKey)
PublicKey: rsa.PublicKey{ key.PublicKey = rsa.PublicKey{
E: priv.E, E: priv.E,
N: new(big.Int).SetBytes(priv.N.Bytes), N: new(big.Int).SetBytes(priv.N.Bytes),
},
D: new(big.Int).SetBytes(priv.D.Bytes),
P: new(big.Int).SetBytes(priv.P.Bytes),
Q: new(big.Int).SetBytes(priv.Q.Bytes),
} }
key.D = new(big.Int).SetBytes(priv.D.Bytes)
key.P = new(big.Int).SetBytes(priv.P.Bytes)
key.Q = new(big.Int).SetBytes(priv.Q.Bytes)
err = key.Validate() err = key.Validate()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return return
} }
......
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"crypto/rsa" "crypto/rsa"
"encoding/hex" "encoding/hex"
"encoding/pem" "encoding/pem"
"reflect"
"testing" "testing"
"time" "time"
) )
...@@ -22,7 +21,11 @@ func TestParsePKCS1PrivateKey(t *testing.T) { ...@@ -22,7 +21,11 @@ func TestParsePKCS1PrivateKey(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Failed to parse private key: %s", err) t.Errorf("Failed to parse private key: %s", err)
} }
if !reflect.DeepEqual(priv, rsaPrivateKey) { if priv.PublicKey.N.Cmp(rsaPrivateKey.PublicKey.N) != 0 ||
priv.PublicKey.E != rsaPrivateKey.PublicKey.E ||
priv.D.Cmp(rsaPrivateKey.D) != 0 ||
priv.P.Cmp(rsaPrivateKey.P) != 0 ||
priv.Q.Cmp(rsaPrivateKey.Q) != 0 {
t.Errorf("got:%+v want:%+v", priv, rsaPrivateKey) t.Errorf("got:%+v want:%+v", priv, rsaPrivateKey)
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment