Commit 8a2553e3 authored by Wade Simmons's avatar Wade Simmons Committed by Brad Fitzpatrick

crypto/rand: only read necessary bytes for Int

We only need to read the number of bytes required to store the value
"max - 1" to generate a random number in the range [0, max).

Before, there was an off-by-one error where an extra byte was read from
the io.Reader for inputs like "256" (right at the boundary for a byte).
There was a similar off-by-one error in the logic for clearing bits and
thus for any input that was a power of 2, there was a 50% chance the
read would continue to be retried as the mask failed to remove a bit.

Fixes #18165.

Change-Id: I548c1368990e23e365591e77980e9086fafb6518
Reviewed-on: https://go-review.googlesource.com/43891Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 9f03e895
...@@ -107,16 +107,23 @@ func Int(rand io.Reader, max *big.Int) (n *big.Int, err error) { ...@@ -107,16 +107,23 @@ func Int(rand io.Reader, max *big.Int) (n *big.Int, err error) {
if max.Sign() <= 0 { if max.Sign() <= 0 {
panic("crypto/rand: argument to Int is <= 0") panic("crypto/rand: argument to Int is <= 0")
} }
k := (max.BitLen() + 7) / 8 n = new(big.Int)
n.Sub(max, n.SetUint64(1))
// b is the number of bits in the most significant byte of max. // bitLen is the maximum bit length needed to encode a value < max.
b := uint(max.BitLen() % 8) bitLen := n.BitLen()
if bitLen == 0 {
// the only valid result is 0
return
}
// k is the maximum byte length needed to encode a value < max.
k := (bitLen + 7) / 8
// b is the number of bits in the most significant byte of max-1.
b := uint(bitLen % 8)
if b == 0 { if b == 0 {
b = 8 b = 8
} }
bytes := make([]byte, k) bytes := make([]byte, k)
n = new(big.Int)
for { for {
_, err = io.ReadFull(rand, bytes) _, err = io.ReadFull(rand, bytes)
......
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
package rand_test package rand_test
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"fmt"
"io"
"math/big" "math/big"
mathrand "math/rand" mathrand "math/rand"
"testing" "testing"
...@@ -45,6 +48,56 @@ func TestInt(t *testing.T) { ...@@ -45,6 +48,56 @@ func TestInt(t *testing.T) {
} }
} }
type countingReader struct {
r io.Reader
n int
}
func (r *countingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.n += n
return n, err
}
// Test that Int reads only the necessary number of bytes from the reader for
// max at each bit length
func TestIntReads(t *testing.T) {
for i := 0; i < 32; i++ {
max := int64(1 << uint64(i))
t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) {
reader := &countingReader{r: rand.Reader}
_, err := rand.Int(reader, big.NewInt(max))
if err != nil {
t.Fatalf("Can't generate random value: %d, %v", max, err)
}
expected := (i + 7) / 8
if reader.n != expected {
t.Errorf("Int(reader, %d) should read %d bytes, but it read: %d", max, expected, reader.n)
}
})
}
}
// Test that Int does not mask out valid return values
func TestIntMask(t *testing.T) {
for max := 1; max <= 256; max++ {
t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) {
for i := 0; i < max; i++ {
var b bytes.Buffer
b.WriteByte(byte(i))
n, err := rand.Int(&b, big.NewInt(int64(max)))
if err != nil {
t.Fatalf("Can't generate random value: %d, %v", max, err)
}
if n.Int64() != int64(i) {
t.Errorf("Int(reader, %d) should have returned value of %d, but it returned: %v", max, i, n)
}
}
})
}
}
func testIntPanics(t *testing.T, b *big.Int) { func testIntPanics(t *testing.T, b *big.Int) {
defer func() { defer func() {
if err := recover(); err == nil { if err := recover(); err == nil {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment