From 2c783dc038250fe602be6436ae9f11ab6ceb8deb Mon Sep 17 00:00:00 2001
From: Alberto Donizetti <alb.donizetti@gmail.com>
Date: Fri, 27 Oct 2017 15:52:14 +0200
Subject: [PATCH] math/big: save one subtraction per iteration in Float.Sqrt
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The Sqrt Newton method computes g(t) = f(t)/f'(t) and then iterates

  t2 = t1 - g(t1)

We can save one operation by including the final subtraction in g(t)
and evaluating the resulting expression symbolically.

For example, for the direct method,

  g(t) = 陆(t虏 - x)/t

and we use 2 multiplications, 1 division and 1 subtraction in g(),
plus 1 final subtraction; but if we compute

  t - g(t) = t - 陆(t虏 - x)/t = 陆(t虏 + x)/t

we only use 2 multiplications, 1 division and 1 addition.

A similar simplification can be done for the inverse method.

name                 old time/op    new time/op    delta
FloatSqrt/64-4          889ns 卤 4%     790ns 卤 1%  -11.19%  (p=0.000 n=8+7)
FloatSqrt/128-4        1.82碌s 卤 0%    1.64碌s 卤 1%  -10.07%  (p=0.001 n=6+8)
FloatSqrt/256-4        3.56碌s 卤 4%    3.10碌s 卤 3%  -12.96%  (p=0.000 n=7+8)
FloatSqrt/1000-4       9.06碌s 卤 3%    8.86碌s 卤 1%   -2.20%  (p=0.001 n=7+7)
FloatSqrt/10000-4       109碌s 卤 1%     107碌s 卤 1%   -1.56%  (p=0.000 n=8+8)
FloatSqrt/100000-4     2.91ms 卤 0%    2.89ms 卤 2%   -0.68%  (p=0.026 n=7+7)
FloatSqrt/1000000-4     237ms 卤 1%     239ms 卤 1%   +0.72%  (p=0.021 n=8+8)

name                 old alloc/op   new alloc/op   delta
FloatSqrt/64-4           448B 卤 0%      416B 卤 0%   -7.14%  (p=0.000 n=8+8)
FloatSqrt/128-4          752B 卤 0%      720B 卤 0%   -4.26%  (p=0.000 n=8+8)
FloatSqrt/256-4        2.05kB 卤 0%    1.34kB 卤 0%  -34.38%  (p=0.000 n=8+8)
FloatSqrt/1000-4       6.91kB 卤 0%    5.09kB 卤 0%  -26.39%  (p=0.000 n=8+8)
FloatSqrt/10000-4      60.5kB 卤 0%    45.9kB 卤 0%  -24.17%  (p=0.000 n=8+8)
FloatSqrt/100000-4      617kB 卤 0%     533kB 卤 0%  -13.57%  (p=0.000 n=8+8)
FloatSqrt/1000000-4    10.3MB 卤 0%     9.2MB 卤 0%  -10.85%  (p=0.000 n=8+8)

name                 old allocs/op  new allocs/op  delta
FloatSqrt/64-4           9.00 卤 0%      9.00 卤 0%     ~     (all equal)
FloatSqrt/128-4          13.0 卤 0%      13.0 卤 0%     ~     (all equal)
FloatSqrt/256-4          20.0 卤 0%      15.0 卤 0%  -25.00%  (p=0.000 n=8+8)
FloatSqrt/1000-4         31.0 卤 0%      24.0 卤 0%  -22.58%  (p=0.000 n=8+8)
FloatSqrt/10000-4        50.0 卤 0%      40.0 卤 0%  -20.00%  (p=0.000 n=8+8)
FloatSqrt/100000-4       76.0 卤 0%      66.0 卤 0%  -13.16%  (p=0.000 n=8+8)
FloatSqrt/1000000-4       146 卤 0%       143 卤 0%   -2.05%  (p=0.000 n=8+8)

Change-Id: I271c00de1ca9740e585bf2af7bcd87b18c1fa68e
Reviewed-on: https://go-review.googlesource.com/73879
Run-TryBot: Alberto Donizetti <alb.donizetti@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
---
 src/math/big/sqrt.go | 37 +++++++++++++++++++------------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/src/math/big/sqrt.go b/src/math/big/sqrt.go
index 4f24fdb0f6..d1deb77652 100644
--- a/src/math/big/sqrt.go
+++ b/src/math/big/sqrt.go
@@ -7,10 +7,9 @@ package big
 import "math"
 
 var (
-	nhalf = NewFloat(-0.5)
 	half  = NewFloat(0.5)
-	one   = NewFloat(1.0)
 	two   = NewFloat(2.0)
+	three = NewFloat(3.0)
 )
 
 // Sqrt sets z to the rounded square root of x, and returns it.
@@ -90,14 +89,15 @@ func (z *Float) sqrtDirect(x *Float) {
 	//   f(t) = t虏 - x
 	// then
 	//   g(t) = f(t)/f'(t) = 陆(t虏 - x)/t
+	// and the next guess is given by
+	//   t2 = t - g(t) = 陆(t虏 + x)/t
 	u := new(Float)
-	g := func(t *Float) *Float {
+	ng := func(t *Float) *Float {
 		u.prec = t.prec
-		u.Mul(t, t)    // u = t虏
-		u.Sub(u, x)    //   = t虏 - x
-		u.Mul(half, u) //   = 陆(t虏 - x)
-		u.Quo(u, t)    //   = 陆(t虏 - x)/t
-		return u
+		u.Mul(t, t)        // u = t虏
+		u.Add(u, x)        //   = t虏 + x
+		u.Mul(half, u)     //   = 陆(t虏 + x)
+		return t.Quo(u, t) //   = 陆(t虏 + x)/t
 	}
 
 	xf, _ := x.Float64()
@@ -108,11 +108,11 @@ func (z *Float) sqrtDirect(x *Float) {
 		panic("sqrtDirect: only for z.prec <= 128")
 	case z.prec > 64:
 		sq.prec *= 2
-		sq.Sub(sq, g(sq))
+		sq = ng(sq)
 		fallthrough
 	default:
 		sq.prec *= 2
-		sq.Sub(sq, g(sq))
+		sq = ng(sq)
 	}
 
 	z.Set(sq)
@@ -126,22 +126,23 @@ func (z *Float) sqrtInverse(x *Float) {
 	//   f(t) = 1/t虏 - x
 	// then
 	//   g(t) = f(t)/f'(t) = -陆t(1 - xt虏)
+	// and the next guess is given by
+	//   t2 = t - g(t) = 陆t(3 - xt虏)
 	u := new(Float)
-	g := func(t *Float) *Float {
+	ng := func(t *Float) *Float {
 		u.prec = t.prec
-		u.Mul(t, t)     // u = t虏
-		u.Mul(x, u)     //   = xt虏
-		u.Sub(one, u)   //   = 1 - xt虏
-		u.Mul(nhalf, u) //   = -陆(1 - xt虏)
-		u.Mul(t, u)     //   = -陆t(1 - xt虏)
-		return u
+		u.Mul(t, t)           // u = t虏
+		u.Mul(x, u)           //   = xt虏
+		u.Sub(three, u)       //   = 3 - xt虏
+		u.Mul(t, u)           //   = t(3 - xt虏)
+		return t.Mul(half, u) //   = 陆t(3 - xt虏)
 	}
 
 	xf, _ := x.Float64()
 	sqi := NewFloat(1 / math.Sqrt(xf))
 	for prec := 2 * z.prec; sqi.prec < prec; {
 		sqi.prec *= 2
-		sqi.Sub(sqi, g(sqi))
+		sqi = ng(sqi)
 	}
 	// sqi = 1/鈭歺
 
-- 
2.30.9