Commit 407dbb42 authored by Robert Griesemer's avatar Robert Griesemer

big: improved computation of "karatsuba length" for faster multiplies

This results in an improvement of > 35% for the existing Mul benchmark
using the same karatsuba threshold, and an improvement of > 50% with
a slightly higher threshold (32 instead of 30):

big.BenchmarkMul           500	   6731846 ns/op (old alg.)
big.BenchmarkMul	   500	   4351122 ns/op (new alg.)
big.BenchmarkMul           500	   3133782 ns/op (new alg., new theshold)

Also:
- tweaked calibrate.go, use same benchmark as for Mul benchmark

R=rsc
CC=golang-dev
https://golang.org/cl/1037041
parent f78b09e6
...@@ -2,7 +2,12 @@ ...@@ -2,7 +2,12 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This file computes the Karatsuba threshold as a "test". // This file prints execution times for the Mul benchmark
// given different Karatsuba thresholds. The result may be
// used to manually fine-tune the threshold constant. The
// results are somewhat fragile; use repeated runs to get
// a clear picture.
// Usage: gotest -calibrate // Usage: gotest -calibrate
package big package big
...@@ -12,27 +17,13 @@ import ( ...@@ -12,27 +17,13 @@ import (
"fmt" "fmt"
"testing" "testing"
"time" "time"
"unsafe" // for Sizeof
) )
var calibrate = flag.Bool("calibrate", false, "run calibration test") var calibrate = flag.Bool("calibrate", false, "run calibration test")
// makeNumber creates an n-word number 0xffff...ffff // measure returns the time to run f
func makeNumber(n int) *Int {
var w Word
b := make([]byte, n*unsafe.Sizeof(w))
for i := range b {
b[i] = 0xff
}
var x Int
x.SetBytes(b)
return &x
}
// measure returns the time to compute x*x in nanoseconds
func measure(f func()) int64 { func measure(f func()) int64 {
const N = 100 const N = 100
start := time.Nanoseconds() start := time.Nanoseconds()
...@@ -44,48 +35,58 @@ func measure(f func()) int64 { ...@@ -44,48 +35,58 @@ func measure(f func()) int64 {
} }
func computeThreshold(t *testing.T) int { func computeThresholds() {
// use a mix of numbers as work load fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
x := make([]*Int, 20) fmt.Printf("(run repeatedly for good results)\n")
for i := range x {
x[i] = makeNumber(10 * (i + 1))
}
threshold := -1 // determine Tk, the work load execution time using basic multiplication
for n := 8; threshold < 0 || n <= threshold+20; n += 2 { karatsubaThreshold = 1e9 // disable karatsuba
// set work load Tb := measure(benchmarkMulLoad)
f := func() { fmt.Printf("Tb = %dns\n", Tb)
var t Int
for _, x := range x {
t.Mul(x, x)
}
}
karatsubaThreshold = 1e9 // disable karatsuba // thresholds
t1 := measure(f) n := 8 // any lower values for the threshold lead to very slow multiplies
th1 := -1
th2 := -1
var deltaOld int64
for count := -1; count != 0; count-- {
// determine Tk, the work load execution time using Karatsuba multiplication
karatsubaThreshold = n // enable karatsuba karatsubaThreshold = n // enable karatsuba
t2 := measure(f) Tk := measure(benchmarkMulLoad)
c := '<' // improvement over Tb
mark := "" delta := (Tb - Tk) * 100 / Tb
if t1 > t2 {
c = '>' fmt.Printf("n = %3d Tk = %8dns %4d%%", n, Tk, delta)
if threshold < 0 {
threshold = n // determine break-even point
mark = " *" if Tk < Tb && th1 < 0 {
} th1 = n
fmt.Print(" break-even point")
}
// determine diminishing return
if 0 < delta && delta < deltaOld && th2 < 0 {
th2 = n
fmt.Print(" diminishing return")
}
deltaOld = delta
fmt.Println()
// trigger counter
if th1 >= 0 && th2 >= 0 && count < 0 {
count = 20 // this many extra measurements after we got both thresholds
} }
fmt.Printf("%4d: %8d %c %8d%s\n", n, t1, c, t2, mark) n++
} }
return threshold
} }
func TestCalibrate(t *testing.T) { func TestCalibrate(t *testing.T) {
if *calibrate { if *calibrate {
fmt.Printf("Computing Karatsuba threshold\n") computeThresholds()
fmt.Printf("threshold = %d\n", computeThreshold(t))
} }
} }
...@@ -253,7 +253,7 @@ func karatsubaSub(z, x nat, n int) { ...@@ -253,7 +253,7 @@ func karatsubaSub(z, x nat, n int) {
// Operands that are shorter than karatsubaThreshold are multiplied using // Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm // "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used. // is used.
var karatsubaThreshold int = 30 // modified by calibrate.go var karatsubaThreshold int = 32 // computed by calibrate.go
// karatsuba multiplies x and y and leaves the result in z. // karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a // Both x and y must have the same length n and n must be a
...@@ -384,6 +384,20 @@ func max(x, y int) int { ...@@ -384,6 +384,20 @@ func max(x, y int) int {
} }
// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before
// becoming about the value of karatsubaThreshold.
func karatsubaLen(n int) int {
i := uint(0)
for n > karatsubaThreshold {
n >>= 1
i++
}
return n << i
}
func (z nat) mul(x, y nat) nat { func (z nat) mul(x, y nat) nat {
m := len(x) m := len(x)
n := len(y) n := len(y)
...@@ -411,17 +425,13 @@ func (z nat) mul(x, y nat) nat { ...@@ -411,17 +425,13 @@ func (z nat) mul(x, y nat) nat {
} }
// m >= n && n >= karatsubaThreshold && n >= 2 // m >= n && n >= karatsubaThreshold && n >= 2
// determine largest k such that // determine Karatsuba length k such that
// //
// x = x1*b + x0 // x = x1*b + x0
// y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) // y = y1*b + y0 (and k <= len(y), which implies k <= len(x))
// b = 1<<(_W*k) ("base" of digits xi, yi) // b = 1<<(_W*k) ("base" of digits xi, yi)
// //
// and k is karatsubaThreshold multiplied by a power of 2 k := karatsubaLen(n)
k := max(karatsubaThreshold, 2)
for k*2 <= n {
k *= 2
}
// k <= n // k <= n
// multiply x0 and y0 via Karatsuba // multiply x0 and y0 via Karatsuba
...@@ -972,10 +982,8 @@ func (n nat) probablyPrime(reps int) bool { ...@@ -972,10 +982,8 @@ func (n nat) probablyPrime(reps int) bool {
// We have to exclude these cases because we reject all // We have to exclude these cases because we reject all
// multiples of these numbers below. // multiples of these numbers below.
if n[0] == 3 || n[0] == 5 || n[0] == 7 || n[0] == 11 || switch n[0] {
n[0] == 13 || n[0] == 17 || n[0] == 19 || n[0] == 23 || case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53:
n[0] == 29 || n[0] == 31 || n[0] == 37 || n[0] == 41 ||
n[0] == 43 || n[0] == 47 || n[0] == 53 {
return true return true
} }
} }
......
...@@ -147,7 +147,7 @@ func TestMulRange(t *testing.T) { ...@@ -147,7 +147,7 @@ func TestMulRange(t *testing.T) {
} }
var mulArg nat var mulArg, mulTmp nat
func init() { func init() {
const n = 1000 const n = 1000
...@@ -158,13 +158,17 @@ func init() { ...@@ -158,13 +158,17 @@ func init() {
} }
func benchmarkMulLoad() {
for j := 1; j <= 10; j++ {
x := mulArg[0 : j*100]
mulTmp.mul(x, x)
}
}
func BenchmarkMul(b *testing.B) { func BenchmarkMul(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var t nat benchmarkMulLoad()
for j := 1; j <= 10; j++ {
x := mulArg[0 : j*100]
t.mul(x, x)
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment