Commit 25b040c2 authored by Brian Kessler's avatar Brian Kessler Committed by Robert Griesemer

math/big: recognize z.Mul(x, x) as squaring of x

updates #13745

Multiprecision squaring can be done in a straightforward manner
with about half the multiplications of a basic multiplication
due to the symmetry of the operands.  This change implements
basic squaring for nat types and uses it for Int multiplication
when the same variable is supplied to both arguments of
z.Mul(x, x). This has some overhead to allocate a temporary
variable to hold the cross products, shift them to double and
add them to the diagonal terms.  There is a speed benefit in
the intermediate range when the overhead is neglible and the
asymptotic performance of karatsuba multiplication has not been
reached.

basicSqrThreshold = 20
karatsubaSqrThreshold = 400

Were set by running calibrate_test.go to measure timing differences
between the algorithms.  Benchmarks for squaring:

name           old time/op  new time/op  delta
IntSqr/1-4     51.5ns ±25%  25.1ns ± 7%  -51.38%  (p=0.008 n=5+5)
IntSqr/2-4     79.1ns ± 4%  72.4ns ± 2%   -8.47%  (p=0.008 n=5+5)
IntSqr/3-4      102ns ± 4%    97ns ± 5%     ~     (p=0.056 n=5+5)
IntSqr/5-4      161ns ± 4%   163ns ± 7%     ~     (p=0.952 n=5+5)
IntSqr/8-4      277ns ± 5%   267ns ± 6%     ~     (p=0.087 n=5+5)
IntSqr/10-4     358ns ± 3%   360ns ± 4%     ~     (p=0.730 n=5+5)
IntSqr/20-4    1.07µs ± 3%  1.01µs ± 6%     ~     (p=0.056 n=5+5)
IntSqr/30-4    2.36µs ± 4%  1.72µs ± 2%  -27.03%  (p=0.008 n=5+5)
IntSqr/50-4    5.19µs ± 3%  3.88µs ± 4%  -25.37%  (p=0.008 n=5+5)
IntSqr/80-4    11.3µs ± 4%   8.6µs ± 3%  -23.78%  (p=0.008 n=5+5)
IntSqr/100-4   16.2µs ± 4%  12.8µs ± 3%  -21.49%  (p=0.008 n=5+5)
IntSqr/200-4   50.1µs ± 5%  44.7µs ± 3%  -10.65%  (p=0.008 n=5+5)
IntSqr/300-4    105µs ±11%    95µs ± 3%   -9.50%  (p=0.008 n=5+5)
IntSqr/500-4    231µs ± 5%   227µs ± 2%     ~     (p=0.310 n=5+5)
IntSqr/800-4    496µs ± 9%   459µs ± 3%   -7.40%  (p=0.016 n=5+5)
IntSqr/1000-4   700µs ± 3%   710µs ± 5%     ~     (p=0.841 n=5+5)

Show a speed up of 10-25% in the range where basicSqr is optimal,
improved single word squaring and no significant difference when
the fallback to standard multiplication is used.

Change-Id: Iae2c82ca91cf890823f91e5c83bbe9a2c534b72b
Reviewed-on: https://go-review.googlesource.com/53638Reviewed-by: default avatarRobert Griesemer <gri@golang.org>
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 259f78f0
...@@ -2,13 +2,20 @@ ...@@ -2,13 +2,20 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Calibration used to determine thresholds for using
// different algorithms. Ideally, this would be converted
// to go generate to create thresholds.go
// This file prints execution times for the Mul benchmark // This file prints execution times for the Mul benchmark
// given different Karatsuba thresholds. The result may be // given different Karatsuba thresholds. The result may be
// used to manually fine-tune the threshold constant. The // used to manually fine-tune the threshold constant. The
// results are somewhat fragile; use repeated runs to get // results are somewhat fragile; use repeated runs to get
// a clear picture. // a clear picture.
// Usage: go test -run=TestCalibrate -calibrate // Calculates lower and upper thresholds for when basicSqr
// is faster than standard multiplication.
// Usage: go test -run=TestCalibrate -v -calibrate
package big package big
...@@ -21,6 +28,27 @@ import ( ...@@ -21,6 +28,27 @@ import (
var calibrate = flag.Bool("calibrate", false, "run calibration test") var calibrate = flag.Bool("calibrate", false, "run calibration test")
func TestCalibrate(t *testing.T) {
if *calibrate {
computeKaratsubaThresholds()
// compute basicSqrThreshold where overhead becomes neglible
minSqr := computeSqrThreshold(10, 30, 1, 3)
// compute karatsubaSqrThreshold where karatsuba is faster
maxSqr := computeSqrThreshold(300, 500, 10, 3)
if minSqr != 0 {
fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
} else {
fmt.Println("no basicSqrThreshold found")
}
if maxSqr != 0 {
fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
} else {
fmt.Println("no karatsubaSqrThreshold found")
}
}
}
func karatsubaLoad(b *testing.B) { func karatsubaLoad(b *testing.B) {
BenchmarkMul(b) BenchmarkMul(b)
} }
...@@ -34,7 +62,7 @@ func measureKaratsuba(th int) time.Duration { ...@@ -34,7 +62,7 @@ func measureKaratsuba(th int) time.Duration {
return time.Duration(res.NsPerOp()) return time.Duration(res.NsPerOp())
} }
func computeThresholds() { func computeKaratsubaThresholds() {
fmt.Printf("Multiplication times for varying Karatsuba thresholds\n") fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
fmt.Printf("(run repeatedly for good results)\n") fmt.Printf("(run repeatedly for good results)\n")
...@@ -81,8 +109,56 @@ func computeThresholds() { ...@@ -81,8 +109,56 @@ func computeThresholds() {
} }
} }
func TestCalibrate(t *testing.T) { func measureBasicSqr(words, nruns int, enable bool) time.Duration {
if *calibrate { // more runs for better statistics
computeThresholds() initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
if enable {
// set thresholds to use basicSqr at this number of words
basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
} else {
// set thresholds to disable basicSqr for any number of words
basicSqrThreshold, karatsubaSqrThreshold = -1, -1
}
var testval int64
for i := 0; i < nruns; i++ {
res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) })
testval += res.NsPerOp()
}
testval /= int64(nruns)
basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr
return time.Duration(testval)
}
func computeSqrThreshold(from, to, step, nruns int) int {
fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
var initPos bool
var threshold int
for i := from; i <= to; i += step {
baseline := measureBasicSqr(i, nruns, false)
testval := measureBasicSqr(i, nruns, true)
pos := baseline > testval
delta := baseline - testval
percent := delta * 100 / baseline
fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
if i == from {
initPos = pos
}
if threshold == 0 && pos != initPos {
threshold = i
fmt.Printf(" threshold found")
}
fmt.Println()
}
if threshold != 0 {
fmt.Printf("Found threshold = %d between %d - %d\n", threshold, from, to)
} else {
fmt.Printf("Found NO threshold between %d - %d\n", from, to)
} }
return threshold
} }
...@@ -153,6 +153,11 @@ func (z *Int) Mul(x, y *Int) *Int { ...@@ -153,6 +153,11 @@ func (z *Int) Mul(x, y *Int) *Int {
// x * (-y) == -(x * y) // x * (-y) == -(x * y)
// (-x) * y == -(x * y) // (-x) * y == -(x * y)
// (-x) * (-y) == x * y // (-x) * (-y) == x * y
if x == y {
z.abs = z.abs.sqr(x.abs)
z.neg = false
return z
}
z.abs = z.abs.mul(x.abs, y.abs) z.abs = z.abs.mul(x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
return z return z
......
...@@ -7,6 +7,7 @@ package big ...@@ -7,6 +7,7 @@ package big
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"fmt"
"math/rand" "math/rand"
"strconv" "strconv"
"strings" "strings"
...@@ -1544,3 +1545,24 @@ func BenchmarkSqrt(b *testing.B) { ...@@ -1544,3 +1545,24 @@ func BenchmarkSqrt(b *testing.B) {
t.Sqrt(n) t.Sqrt(n)
} }
} }
func benchmarkIntSqr(b *testing.B, nwords int) {
x := new(Int)
x.abs = rndNat(nwords)
t := new(Int)
b.ResetTimer()
for i := 0; i < b.N; i++ {
t.Mul(x, x)
}
}
func BenchmarkIntSqr(b *testing.B) {
for _, n := range sqrBenchSizes {
if isRaceBuilder && n > 1e3 {
continue
}
b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
benchmarkIntSqr(b, n)
})
}
}
...@@ -249,7 +249,7 @@ func karatsubaSub(z, x nat, n int) { ...@@ -249,7 +249,7 @@ func karatsubaSub(z, x nat, n int) {
// Operands that are shorter than karatsubaThreshold are multiplied using // Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm // "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used. // is used.
var karatsubaThreshold int = 40 // computed by calibrate.go var karatsubaThreshold = 40 // computed by calibrate_test.go
// karatsuba multiplies x and y and leaves the result in z. // karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a // Both x and y must have the same length n and n must be a
...@@ -473,6 +473,61 @@ func (z nat) mul(x, y nat) nat { ...@@ -473,6 +473,61 @@ func (z nat) mul(x, y nat) nat {
return z.norm() return z.norm()
} }
// basicSqr sets z = x*x and is asymptotically faster than basicMul
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) >= 2*len(x)
// The (non-normalized) result is placed in z[0 : 2 * len(x)].
func basicSqr(z, x nat) {
n := len(x)
t := make(nat, 2*n) // temporary variable to hold the products
z[1], z[0] = mulWW(x[0], x[0]) // the initial square
for i := 1; i < n; i++ {
d := x[i]
// z collects the squares x[i] * x[i]
z[2*i+1], z[2*i] = mulWW(d, d)
// t collects the products x[i] * x[j] where j < i
t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
}
t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
addVV(z, z, t) // combine the result
}
// Operands that are shorter than basicSqrThreshold are squared using
// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
// the Karatsuba algorithm is used.
var basicSqrThreshold = 20 // computed by calibrate_test.go
var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
// z = x*x
func (z nat) sqr(x nat) nat {
n := len(x)
switch {
case n == 0:
return z[:0]
case n == 1:
d := x[0]
z = z.make(2)
z[1], z[0] = mulWW(d, d)
return z.norm()
}
if alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}
z = z.make(2 * n)
if n < basicSqrThreshold {
basicMul(z, x, x)
return z.norm()
}
if n < karatsubaSqrThreshold {
basicSqr(z, x)
return z.norm()
}
return z.mul(x, x)
}
// mulRange computes the product of all the unsigned integers in the // mulRange computes the product of all the unsigned integers in the
// range [a, b] inclusively. If a > b (empty range), the result is 1. // range [a, b] inclusively. If a > b (empty range), the result is 1.
func (z nat) mulRange(a, b uint64) nat { func (z nat) mulRange(a, b uint64) nat {
......
...@@ -619,3 +619,49 @@ func TestSticky(t *testing.T) { ...@@ -619,3 +619,49 @@ func TestSticky(t *testing.T) {
} }
} }
} }
func testBasicSqr(t *testing.T, x nat) {
got := make(nat, 2*len(x))
want := make(nat, 2*len(x))
basicSqr(got, x)
basicMul(want, x, x)
if got.cmp(want) != 0 {
t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
}
}
func TestBasicSqr(t *testing.T) {
for _, a := range prodNN {
if a.x != nil {
testBasicSqr(t, a.x)
}
if a.y != nil {
testBasicSqr(t, a.y)
}
if a.z != nil {
testBasicSqr(t, a.z)
}
}
}
func benchmarkNatSqr(b *testing.B, nwords int) {
x := rndNat(nwords)
var z nat
b.ResetTimer()
for i := 0; i < b.N; i++ {
z.sqr(x)
}
}
var sqrBenchSizes = []int{1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100, 200, 300, 500, 800, 1000}
func BenchmarkNatSqr(b *testing.B) {
for _, n := range sqrBenchSizes {
if isRaceBuilder && n > 1e3 {
continue
}
b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
benchmarkNatSqr(b, n)
})
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment