Commit 4d9dd358 authored by Brian Kessler's avatar Brian Kessler Committed by Brad Fitzpatrick

cmd/compile: add signed divisibility rules

"Division by invariant integers using multiplication" paper
by Granlund and Montgomery contains a method for directly computing
divisibility (x%c == 0 for c constant) by means of the modular inverse.
The method is further elaborated in "Hacker's Delight" by Warren Section 10-17

This general rule can compute divisibilty by one multiplication, and add
and a compare for odd divisors and an additional rotate for even divisors.

To apply the divisibility rule, we must take into account
the rules to rewrite x%c = x-((x/c)*c) and (x/c) for c constant on the first
optimization pass "opt".  This complicates the matching as we want to match
only in the cases where the result of (x/c) is not also needed.
So, we must match on the expanded form of (x/c) in the expression x == c*(x/c)
in the "late opt" pass after common subexpresion elimination.

Note, that if there is an intermediate opt pass introduced in the future we
could simplify these rules by delaying the magic division rewrite to "late opt"
and matching directly on (x/c) in the intermediate opt pass.

On amd64, the divisibility check is 30-45% faster.

name                     old time/op  new time/op  delta`
DivisiblePow2constI64-4  0.83ns ± 1%  0.82ns ± 0%     ~     (p=0.079 n=5+4)
DivisibleconstI64-4      2.68ns ± 1%  1.87ns ± 0%  -30.33%  (p=0.000 n=5+4)
DivisibleWDivconstI64-4  2.69ns ± 1%  2.71ns ± 3%     ~     (p=1.000 n=5+5)
DivisiblePow2constI32-4  1.15ns ± 1%  1.15ns ± 0%     ~     (p=0.238 n=5+4)
DivisibleconstI32-4      2.24ns ± 1%  1.20ns ± 0%  -46.48%  (p=0.016 n=5+4)
DivisibleWDivconstI32-4  2.27ns ± 1%  2.27ns ± 1%     ~     (p=0.683 n=5+5)
DivisiblePow2constI16-4  0.81ns ± 1%  0.82ns ± 1%     ~     (p=0.135 n=5+5)
DivisibleconstI16-4      2.11ns ± 2%  1.20ns ± 1%  -42.99%  (p=0.008 n=5+5)
DivisibleWDivconstI16-4  2.23ns ± 0%  2.27ns ± 2%   +1.79%  (p=0.029 n=4+4)
DivisiblePow2constI8-4   0.81ns ± 1%  0.81ns ± 1%     ~     (p=0.286 n=5+5)
DivisibleconstI8-4       2.13ns ± 3%  1.19ns ± 1%  -43.84%  (p=0.008 n=5+5)
DivisibleWDivconstI8-4   2.23ns ± 1%  2.25ns ± 1%     ~     (p=0.183 n=5+5)

Fixes #30282
Fixes #15806

Change-Id: Id20d78263a4fdfe0509229ae4dfa2fede83fc1d0
Reviewed-on: https://go-review.googlesource.com/c/go/+/173998
Run-TryBot: Brian Kessler <brian.m.kessler@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarKeith Randall <khr@golang.org>
parent e7d08b6f
...@@ -1283,24 +1283,65 @@ func div19_uint64(n uint64) bool { ...@@ -1283,24 +1283,65 @@ func div19_uint64(n uint64) bool {
return n%19 == 0 return n%19 == 0
} }
//go:noinline
func div6_int8(n int8) bool {
return n%6 == 0
}
//go:noinline
func div6_int16(n int16) bool {
return n%6 == 0
}
//go:noinline
func div6_int32(n int32) bool {
return n%6 == 0
}
//go:noinline
func div6_int64(n int64) bool {
return n%6 == 0
}
//go:noinline
func div19_int8(n int8) bool {
return n%19 == 0
}
//go:noinline
func div19_int16(n int16) bool {
return n%19 == 0
}
//go:noinline
func div19_int32(n int32) bool {
return n%19 == 0
}
//go:noinline
func div19_int64(n int64) bool {
return n%19 == 0
}
// testDivisibility confirms that rewrite rules x%c ==0 for c constant are correct. // testDivisibility confirms that rewrite rules x%c ==0 for c constant are correct.
func testDivisibility(t *testing.T) { func testDivisibility(t *testing.T) {
// unsigned tests
// test an even and an odd divisor // test an even and an odd divisor
var six, nineteen uint64 = 6, 19 var sixU, nineteenU uint64 = 6, 19
// test all inputs for uint8, uint16 // test all inputs for uint8, uint16
for i := uint64(0); i <= math.MaxUint16; i++ { for i := uint64(0); i <= math.MaxUint16; i++ {
if i <= math.MaxUint8 { if i <= math.MaxUint8 {
if want, got := uint8(i)%uint8(six) == 0, div6_uint8(uint8(i)); got != want { if want, got := uint8(i)%uint8(sixU) == 0, div6_uint8(uint8(i)); got != want {
t.Errorf("div6_uint8(%d) = %v want %v", i, got, want) t.Errorf("div6_uint8(%d) = %v want %v", i, got, want)
} }
if want, got := uint8(i)%uint8(nineteen) == 0, div19_uint8(uint8(i)); got != want { if want, got := uint8(i)%uint8(nineteenU) == 0, div19_uint8(uint8(i)); got != want {
t.Errorf("div6_uint19(%d) = %v want %v", i, got, want) t.Errorf("div6_uint19(%d) = %v want %v", i, got, want)
} }
} }
if want, got := uint16(i)%uint16(six) == 0, div6_uint16(uint16(i)); got != want { if want, got := uint16(i)%uint16(sixU) == 0, div6_uint16(uint16(i)); got != want {
t.Errorf("div6_uint16(%d) = %v want %v", i, got, want) t.Errorf("div6_uint16(%d) = %v want %v", i, got, want)
} }
if want, got := uint16(i)%uint16(nineteen) == 0, div19_uint16(uint16(i)); got != want { if want, got := uint16(i)%uint16(nineteenU) == 0, div19_uint16(uint16(i)); got != want {
t.Errorf("div19_uint16(%d) = %v want %v", i, got, want) t.Errorf("div19_uint16(%d) = %v want %v", i, got, want)
} }
} }
...@@ -1308,35 +1349,106 @@ func testDivisibility(t *testing.T) { ...@@ -1308,35 +1349,106 @@ func testDivisibility(t *testing.T) {
// spot check inputs for uint32 and uint64 // spot check inputs for uint32 and uint64
xu := []uint64{ xu := []uint64{
0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5,
six, 2 * six, 3 * six, 5 * six, 12345 * six, sixU, 2 * sixU, 3 * sixU, 5 * sixU, 12345 * sixU,
six + 1, 2*six - 5, 3*six + 3, 5*six + 4, 12345*six - 2, sixU + 1, 2*sixU - 5, 3*sixU + 3, 5*sixU + 4, 12345*sixU - 2,
nineteen, 2 * nineteen, 3 * nineteen, 5 * nineteen, 12345 * nineteen, nineteenU, 2 * nineteenU, 3 * nineteenU, 5 * nineteenU, 12345 * nineteenU,
nineteen + 1, 2*nineteen - 5, 3*nineteen + 3, 5*nineteen + 4, 12345*nineteen - 2, nineteenU + 1, 2*nineteenU - 5, 3*nineteenU + 3, 5*nineteenU + 4, 12345*nineteenU - 2,
maxU32, maxU32 - 1, maxU32 - 2, maxU32 - 3, maxU32 - 4, maxU32, maxU32 - 1, maxU32 - 2, maxU32 - 3, maxU32 - 4,
maxU32, maxU32 - 5, maxU32 - 6, maxU32 - 7, maxU32 - 8, maxU32 - 5, maxU32 - 6, maxU32 - 7, maxU32 - 8,
maxU32, maxU32 - 9, maxU32 - 10, maxU32 - 11, maxU32 - 12, maxU32 - 9, maxU32 - 10, maxU32 - 11, maxU32 - 12,
maxU32, maxU32 - 13, maxU32 - 14, maxU32 - 15, maxU32 - 16, maxU32 - 13, maxU32 - 14, maxU32 - 15, maxU32 - 16,
maxU32, maxU32 - 17, maxU32 - 18, maxU32 - 19, maxU32 - 20, maxU32 - 17, maxU32 - 18, maxU32 - 19, maxU32 - 20,
maxU64, maxU64 - 1, maxU64 - 2, maxU64 - 3, maxU64 - 4, maxU64, maxU64 - 1, maxU64 - 2, maxU64 - 3, maxU64 - 4,
maxU64, maxU64 - 5, maxU64 - 6, maxU64 - 7, maxU64 - 8, maxU64 - 5, maxU64 - 6, maxU64 - 7, maxU64 - 8,
maxU64, maxU64 - 9, maxU64 - 10, maxU64 - 11, maxU64 - 12, maxU64 - 9, maxU64 - 10, maxU64 - 11, maxU64 - 12,
maxU64, maxU64 - 13, maxU64 - 14, maxU64 - 15, maxU64 - 16, maxU64 - 13, maxU64 - 14, maxU64 - 15, maxU64 - 16,
maxU64, maxU64 - 17, maxU64 - 18, maxU64 - 19, maxU64 - 20, maxU64 - 17, maxU64 - 18, maxU64 - 19, maxU64 - 20,
} }
for _, x := range xu { for _, x := range xu {
if x <= maxU32 { if x <= maxU32 {
if want, got := uint32(x)%uint32(six) == 0, div6_uint32(uint32(x)); got != want { if want, got := uint32(x)%uint32(sixU) == 0, div6_uint32(uint32(x)); got != want {
t.Errorf("div6_uint32(%d) = %v want %v", x, got, want) t.Errorf("div6_uint32(%d) = %v want %v", x, got, want)
} }
if want, got := uint32(x)%uint32(nineteen) == 0, div19_uint32(uint32(x)); got != want { if want, got := uint32(x)%uint32(nineteenU) == 0, div19_uint32(uint32(x)); got != want {
t.Errorf("div19_uint32(%d) = %v want %v", x, got, want) t.Errorf("div19_uint32(%d) = %v want %v", x, got, want)
} }
} }
if want, got := x%six == 0, div6_uint64(x); got != want { if want, got := x%sixU == 0, div6_uint64(x); got != want {
t.Errorf("div6_uint64(%d) = %v want %v", x, got, want) t.Errorf("div6_uint64(%d) = %v want %v", x, got, want)
} }
if want, got := x%nineteen == 0, div19_uint64(x); got != want { if want, got := x%nineteenU == 0, div19_uint64(x); got != want {
t.Errorf("div19_uint64(%d) = %v want %v", x, got, want) t.Errorf("div19_uint64(%d) = %v want %v", x, got, want)
} }
} }
// signed tests
// test an even and an odd divisor
var sixS, nineteenS int64 = 6, 19
// test all inputs for int8, int16
for i := int64(math.MinInt16); i <= math.MaxInt16; i++ {
if math.MinInt8 <= i && i <= math.MaxInt8 {
if want, got := int8(i)%int8(sixS) == 0, div6_int8(int8(i)); got != want {
t.Errorf("div6_int8(%d) = %v want %v", i, got, want)
}
if want, got := int8(i)%int8(nineteenS) == 0, div19_int8(int8(i)); got != want {
t.Errorf("div6_int19(%d) = %v want %v", i, got, want)
}
}
if want, got := int16(i)%int16(sixS) == 0, div6_int16(int16(i)); got != want {
t.Errorf("div6_int16(%d) = %v want %v", i, got, want)
}
if want, got := int16(i)%int16(nineteenS) == 0, div19_int16(int16(i)); got != want {
t.Errorf("div19_int16(%d) = %v want %v", i, got, want)
}
}
var minI32, maxI32, minI64, maxI64 int64 = math.MinInt32, math.MaxInt32, math.MinInt64, math.MaxInt64
// spot check inputs for int32 and int64
xs := []int64{
0, 1, 2, 3, 4, 5,
-1, -2, -3, -4, -5,
sixS, 2 * sixS, 3 * sixS, 5 * sixS, 12345 * sixS,
sixS + 1, 2*sixS - 5, 3*sixS + 3, 5*sixS + 4, 12345*sixS - 2,
-sixS, -2 * sixS, -3 * sixS, -5 * sixS, -12345 * sixS,
-sixS + 1, -2*sixS - 5, -3*sixS + 3, -5*sixS + 4, -12345*sixS - 2,
nineteenS, 2 * nineteenS, 3 * nineteenS, 5 * nineteenS, 12345 * nineteenS,
nineteenS + 1, 2*nineteenS - 5, 3*nineteenS + 3, 5*nineteenS + 4, 12345*nineteenS - 2,
-nineteenS, -2 * nineteenS, -3 * nineteenS, -5 * nineteenS, -12345 * nineteenS,
-nineteenS + 1, -2*nineteenS - 5, -3*nineteenS + 3, -5*nineteenS + 4, -12345*nineteenS - 2,
minI32, minI32 + 1, minI32 + 2, minI32 + 3, minI32 + 4,
minI32 + 5, minI32 + 6, minI32 + 7, minI32 + 8,
minI32 + 9, minI32 + 10, minI32 + 11, minI32 + 12,
minI32 + 13, minI32 + 14, minI32 + 15, minI32 + 16,
minI32 + 17, minI32 + 18, minI32 + 19, minI32 + 20,
maxI32, maxI32 - 1, maxI32 - 2, maxI32 - 3, maxI32 - 4,
maxI32 - 5, maxI32 - 6, maxI32 - 7, maxI32 - 8,
maxI32 - 9, maxI32 - 10, maxI32 - 11, maxI32 - 12,
maxI32 - 13, maxI32 - 14, maxI32 - 15, maxI32 - 16,
maxI32 - 17, maxI32 - 18, maxI32 - 19, maxI32 - 20,
minI64, minI64 + 1, minI64 + 2, minI64 + 3, minI64 + 4,
minI64 + 5, minI64 + 6, minI64 + 7, minI64 + 8,
minI64 + 9, minI64 + 10, minI64 + 11, minI64 + 12,
minI64 + 13, minI64 + 14, minI64 + 15, minI64 + 16,
minI64 + 17, minI64 + 18, minI64 + 19, minI64 + 20,
maxI64, maxI64 - 1, maxI64 - 2, maxI64 - 3, maxI64 - 4,
maxI64 - 5, maxI64 - 6, maxI64 - 7, maxI64 - 8,
maxI64 - 9, maxI64 - 10, maxI64 - 11, maxI64 - 12,
maxI64 - 13, maxI64 - 14, maxI64 - 15, maxI64 - 16,
maxI64 - 17, maxI64 - 18, maxI64 - 19, maxI64 - 20,
}
for _, x := range xs {
if minI32 <= x && x <= maxI32 {
if want, got := int32(x)%int32(sixS) == 0, div6_int32(int32(x)); got != want {
t.Errorf("div6_int32(%d) = %v want %v", x, got, want)
}
if want, got := int32(x)%int32(nineteenS) == 0, div19_int32(int32(x)); got != want {
t.Errorf("div19_int32(%d) = %v want %v", x, got, want)
}
}
if want, got := x%sixS == 0, div6_int64(x); got != want {
t.Errorf("div6_int64(%d) = %v want %v", x, got, want)
}
if want, got := x%nineteenS == 0, div19_int64(x); got != want {
t.Errorf("div19_int64(%d) = %v want %v", x, got, want)
}
}
} }
...@@ -1174,6 +1174,10 @@ ...@@ -1174,6 +1174,10 @@
(Eq32 (Mod32u <typ.UInt32> (ZeroExt8to32 <typ.UInt32> x) (Const32 <typ.UInt32> [c&0xff])) (Const32 <typ.UInt32> [0])) (Eq32 (Mod32u <typ.UInt32> (ZeroExt8to32 <typ.UInt32> x) (Const32 <typ.UInt32> [c&0xff])) (Const32 <typ.UInt32> [0]))
(Eq16 (Mod16u x (Const16 [c])) (Const16 [0])) && x.Op != OpConst16 && udivisibleOK(16,c) && !hasSmallRotate(config) -> (Eq16 (Mod16u x (Const16 [c])) (Const16 [0])) && x.Op != OpConst16 && udivisibleOK(16,c) && !hasSmallRotate(config) ->
(Eq32 (Mod32u <typ.UInt32> (ZeroExt16to32 <typ.UInt32> x) (Const32 <typ.UInt32> [c&0xffff])) (Const32 <typ.UInt32> [0])) (Eq32 (Mod32u <typ.UInt32> (ZeroExt16to32 <typ.UInt32> x) (Const32 <typ.UInt32> [c&0xffff])) (Const32 <typ.UInt32> [0]))
(Eq8 (Mod8 x (Const8 [c])) (Const8 [0])) && x.Op != OpConst8 && sdivisibleOK(8,c) && !hasSmallRotate(config) ->
(Eq32 (Mod32 <typ.Int32> (SignExt8to32 <typ.Int32> x) (Const32 <typ.Int32> [c])) (Const32 <typ.Int32> [0]))
(Eq16 (Mod16 x (Const16 [c])) (Const16 [0])) && x.Op != OpConst16 && sdivisibleOK(16,c) && !hasSmallRotate(config) ->
(Eq32 (Mod32 <typ.Int32> (SignExt16to32 <typ.Int32> x) (Const32 <typ.Int32> [c])) (Const32 <typ.Int32> [0]))
// Divisibility checks x%c == 0 convert to multiply and rotate. // Divisibility checks x%c == 0 convert to multiply and rotate.
// Note, x%c == 0 is rewritten as x == c*(x/c) during the opt pass // Note, x%c == 0 is rewritten as x == c*(x/c) during the opt pass
...@@ -1184,6 +1188,7 @@ ...@@ -1184,6 +1188,7 @@
// Note that if there were an intermediate opt pass, this rule could be applied // Note that if there were an intermediate opt pass, this rule could be applied
// directly on the Div op and magic division rewrites could be delayed to late opt. // directly on the Div op and magic division rewrites could be delayed to late opt.
// Unsigned divisibility checks convert to multiply and rotate.
(Eq8 x (Mul8 (Const8 [c]) (Eq8 x (Mul8 (Const8 [c])
(Trunc32to8 (Trunc32to8
(Rsh32Ux64 (Rsh32Ux64
...@@ -1489,8 +1494,209 @@ ...@@ -1489,8 +1494,209 @@
(Const64 <typ.UInt64> [int64(udivisible(64,c).max)]) (Const64 <typ.UInt64> [int64(udivisible(64,c).max)])
) )
// Signed divisibility checks convert to multiply, add and rotate.
(Eq8 x (Mul8 (Const8 [c])
(Sub8
(Rsh32x64
mul:(Mul32
(Const32 [m])
(SignExt8to32 x))
(Const64 [s]))
(Rsh32x64
(SignExt8to32 x)
(Const64 [31])))
)
)
&& v.Block.Func.pass.name != "opt" && mul.Uses == 1
&& m == int64(smagic(8,c).m) && s == 8+smagic(8,c).s
&& x.Op != OpConst8 && sdivisibleOK(8,c)
-> (Leq8U
(RotateLeft8 <typ.UInt8>
(Add8 <typ.UInt8>
(Mul8 <typ.UInt8>
(Const8 <typ.UInt8> [int64(int8(sdivisible(8,c).m))])
x)
(Const8 <typ.UInt8> [int64(int8(sdivisible(8,c).a))])
)
(Const8 <typ.UInt8> [int64(8-sdivisible(8,c).k)])
)
(Const8 <typ.UInt8> [int64(int8(sdivisible(8,c).max))])
)
(Eq16 x (Mul16 (Const16 [c])
(Sub16
(Rsh32x64
mul:(Mul32
(Const32 [m])
(SignExt16to32 x))
(Const64 [s]))
(Rsh32x64
(SignExt16to32 x)
(Const64 [31])))
)
)
&& v.Block.Func.pass.name != "opt" && mul.Uses == 1
&& m == int64(smagic(16,c).m) && s == 16+smagic(16,c).s
&& x.Op != OpConst16 && sdivisibleOK(16,c)
-> (Leq16U
(RotateLeft16 <typ.UInt16>
(Add16 <typ.UInt16>
(Mul16 <typ.UInt16>
(Const16 <typ.UInt16> [int64(int16(sdivisible(16,c).m))])
x)
(Const16 <typ.UInt16> [int64(int16(sdivisible(16,c).a))])
)
(Const16 <typ.UInt16> [int64(16-sdivisible(16,c).k)])
)
(Const16 <typ.UInt16> [int64(int16(sdivisible(16,c).max))])
)
(Eq32 x (Mul32 (Const32 [c])
(Sub32
(Rsh64x64
mul:(Mul64
(Const64 [m])
(SignExt32to64 x))
(Const64 [s]))
(Rsh64x64
(SignExt32to64 x)
(Const64 [63])))
)
)
&& v.Block.Func.pass.name != "opt" && mul.Uses == 1
&& m == int64(smagic(32,c).m) && s == 32+smagic(32,c).s
&& x.Op != OpConst32 && sdivisibleOK(32,c)
-> (Leq32U
(RotateLeft32 <typ.UInt32>
(Add32 <typ.UInt32>
(Mul32 <typ.UInt32>
(Const32 <typ.UInt32> [int64(int32(sdivisible(32,c).m))])
x)
(Const32 <typ.UInt32> [int64(int32(sdivisible(32,c).a))])
)
(Const32 <typ.UInt32> [int64(32-sdivisible(32,c).k)])
)
(Const32 <typ.UInt32> [int64(int32(sdivisible(32,c).max))])
)
(Eq32 x (Mul32 (Const32 [c])
(Sub32
(Rsh32x64
mul:(Hmul32
(Const32 [m])
x)
(Const64 [s]))
(Rsh32x64
x
(Const64 [31])))
)
)
&& v.Block.Func.pass.name != "opt" && mul.Uses == 1
&& m == int64(int32(smagic(32,c).m/2)) && s == smagic(32,c).s-1
&& x.Op != OpConst32 && sdivisibleOK(32,c)
-> (Leq32U
(RotateLeft32 <typ.UInt32>
(Add32 <typ.UInt32>
(Mul32 <typ.UInt32>
(Const32 <typ.UInt32> [int64(int32(sdivisible(32,c).m))])
x)
(Const32 <typ.UInt32> [int64(int32(sdivisible(32,c).a))])
)
(Const32 <typ.UInt32> [int64(32-sdivisible(32,c).k)])
)
(Const32 <typ.UInt32> [int64(int32(sdivisible(32,c).max))])
)
(Eq32 x (Mul32 (Const32 [c])
(Sub32
(Rsh32x64
(Add32
mul:(Hmul32
(Const32 [m])
x)
x)
(Const64 [s]))
(Rsh32x64
x
(Const64 [31])))
)
)
&& v.Block.Func.pass.name != "opt" && mul.Uses == 1
&& m == int64(int32(smagic(32,c).m)) && s == smagic(32,c).s
&& x.Op != OpConst32 && sdivisibleOK(32,c)
-> (Leq32U
(RotateLeft32 <typ.UInt32>
(Add32 <typ.UInt32>
(Mul32 <typ.UInt32>
(Const32 <typ.UInt32> [int64(int32(sdivisible(32,c).m))])
x)
(Const32 <typ.UInt32> [int64(int32(sdivisible(32,c).a))])
)
(Const32 <typ.UInt32> [int64(32-sdivisible(32,c).k)])
)
(Const32 <typ.UInt32> [int64(int32(sdivisible(32,c).max))])
)
(Eq64 x (Mul64 (Const64 [c])
(Sub64
(Rsh64x64
mul:(Hmul64
(Const64 [m])
x)
(Const64 [s]))
(Rsh64x64
x
(Const64 [63])))
)
)
&& v.Block.Func.pass.name != "opt" && mul.Uses == 1
&& m == int64(smagic(64,c).m/2) && s == smagic(64,c).s-1
&& x.Op != OpConst64 && sdivisibleOK(64,c)
-> (Leq64U
(RotateLeft64 <typ.UInt64>
(Add64 <typ.UInt64>
(Mul64 <typ.UInt64>
(Const64 <typ.UInt64> [int64(sdivisible(64,c).m)])
x)
(Const64 <typ.UInt64> [int64(sdivisible(64,c).a)])
)
(Const64 <typ.UInt64> [int64(64-sdivisible(64,c).k)])
)
(Const64 <typ.UInt64> [int64(sdivisible(64,c).max)])
)
(Eq64 x (Mul64 (Const64 [c])
(Sub64
(Rsh64x64
(Add64
mul:(Hmul64
(Const64 [m])
x)
x)
(Const64 [s]))
(Rsh64x64
x
(Const64 [63])))
)
)
&& v.Block.Func.pass.name != "opt" && mul.Uses == 1
&& m == int64(smagic(64,c).m) && s == smagic(64,c).s
&& x.Op != OpConst64 && sdivisibleOK(64,c)
-> (Leq64U
(RotateLeft64 <typ.UInt64>
(Add64 <typ.UInt64>
(Mul64 <typ.UInt64>
(Const64 <typ.UInt64> [int64(sdivisible(64,c).m)])
x)
(Const64 <typ.UInt64> [int64(sdivisible(64,c).a)])
)
(Const64 <typ.UInt64> [int64(64-sdivisible(64,c).k)])
)
(Const64 <typ.UInt64> [int64(sdivisible(64,c).max)])
)
// Divisibility check for signed integers for power of two constant are simple mask. // Divisibility check for signed integers for power of two constant are simple mask.
// However, we must match against the rewritten n%c == 0 -> n - c*(n/c) == 0 -> n == c *(n/c) // However, we must match against the rewritten n%c == 0 -> n - c*(n/c) == 0 -> n == c*(n/c)
// where n/c contains fixup code to handle signed n. // where n/c contains fixup code to handle signed n.
(Eq8 n (Lsh8x64 (Eq8 n (Lsh8x64
(Rsh8x64 (Rsh8x64
......
...@@ -195,7 +195,7 @@ func smagic(n uint, c int64) smagicData { ...@@ -195,7 +195,7 @@ func smagic(n uint, c int64) smagicData {
// by using the modular inverse with respect to the word size 2^n. // by using the modular inverse with respect to the word size 2^n.
// //
// Given c, compute m such that (c * m) mod 2^n == 1 // Given c, compute m such that (c * m) mod 2^n == 1
// Then if c divides x (x%c ==0), the quotient is given by q = x/c == x*cinv mod 2^n // Then if c divides x (x%c ==0), the quotient is given by q = x/c == x*m mod 2^n
// //
// x can range from 0, c, 2c, 3c, ... ⎣(2^n - 1)/c⎦ * c the maximum multiple // x can range from 0, c, 2c, 3c, ... ⎣(2^n - 1)/c⎦ * c the maximum multiple
// Thus, x*m mod 2^n is 0, 1, 2, 3, ... ⎣(2^n - 1)/c⎦ // Thus, x*m mod 2^n is 0, 1, 2, 3, ... ⎣(2^n - 1)/c⎦
...@@ -285,3 +285,97 @@ func udivisible(n uint, c int64) udivisibleData { ...@@ -285,3 +285,97 @@ func udivisible(n uint, c int64) udivisibleData {
max: max, max: max,
} }
} }
// For signed integers, a similar method follows.
//
// Given c > 1 and odd, compute m such that (c * m) mod 2^n == 1
// Then if c divides x (x%c ==0), the quotient is given by q = x/c == x*m mod 2^n
//
// x can range from ⎡-2^(n-1)/c⎤ * c, ... -c, 0, c, ... ⎣(2^(n-1) - 1)/c⎦ * c
// Thus, x*m mod 2^n is ⎡-2^(n-1)/c⎤, ... -2, -1, 0, 1, 2, ... ⎣(2^(n-1) - 1)/c⎦
//
// So, x is a multiple of c if and only if:
// ⎡-2^(n-1)/c⎤ <= x*m mod 2^n <= ⎣(2^(n-1) - 1)/c⎦
//
// Since c > 1 and odd, this can be simplified by
// ⎡-2^(n-1)/c⎤ == ⎡(-2^(n-1) + 1)/c⎤ == -⎣(2^(n-1) - 1)/c⎦
//
// -⎣(2^(n-1) - 1)/c⎦ <= x*m mod 2^n <= ⎣(2^(n-1) - 1)/c⎦
//
// To extend this to even integers, consider c = d0 * 2^k where d0 is odd.
// We can test whether x is divisible by both d0 and 2^k.
//
// Let m be such that (d0 * m) mod 2^n == 1.
// Let q = x*m mod 2^n. Then c divides x if:
//
// -⎣(2^(n-1) - 1)/d0⎦ <= q <= ⎣(2^(n-1) - 1)/d0⎦ and q ends in at least k 0-bits
//
// To transform this to a single comparison, we use the following theorem (ZRS in Hacker's Delight).
//
// For a >= 0 the following conditions are equivalent:
// 1) -a <= x <= a and x ends in at least k 0-bits
// 2) RotRight(x+a', k) <= ⎣2a'/2^k⎦
//
// Where a' = a & -2^k (a with its right k bits set to zero)
//
// To see that 1 & 2 are equivalent, note that -a <= x <= a is equivalent to
// -a' <= x <= a' if and only if x ends in at least k 0-bits. Adding -a' to each side gives,
// 0 <= x + a' <= 2a' and x + a' ends in at least k 0-bits if and only if x does since a' has
// k 0-bits by definition. We can use theorem ZRU above with x -> x + a' and a -> 2a' giving 1) == 2).
//
// Let m be such that (d0 * m) mod 2^n == 1.
// Let q = x*m mod 2^n.
// Let a' = ⎣(2^(n-1) - 1)/d0⎦ & -2^k
//
// Then the divisibility test is:
//
// RotRight(q+a', k) <= ⎣2a'/2^k⎦
//
// Note that the calculation is performed using unsigned integers.
// Since a' can have n-1 bits, 2a' may have n bits and there is no risk of overflow.
// sdivisibleOK reports whether we should strength reduce a n-bit dividisibilty check by c.
func sdivisibleOK(n uint, c int64) bool {
if c < 0 {
// Doesn't work for negative c.
return false
}
// Doesn't work for 0.
// Don't use it for powers of 2.
return c&(c-1) != 0
}
type sdivisibleData struct {
k int64 // trailingZeros(c)
m uint64 // m * (c>>k) mod 2^n == 1 multiplicative inverse of odd portion modulo 2^n
a uint64 // ⎣(2^(n-1) - 1)/ (c>>k)⎦ & -(1<<k) additive constant
max uint64 // ⎣(2 a) / (1<<k)⎦ max value to for divisibility
}
func sdivisible(n uint, c int64) sdivisibleData {
d := uint64(c)
k := bits.TrailingZeros64(d)
d0 := d >> uint(k) // the odd portion of the divisor
mask := ^uint64(0) >> (64 - n)
// Calculate the multiplicative inverse via Newton's method.
// Quadratic convergence doubles the number of correct bits per iteration.
m := d0 // initial guess correct to 3-bits d0*d0 mod 8 == 1
m = m * (2 - m*d0) // 6-bits
m = m * (2 - m*d0) // 12-bits
m = m * (2 - m*d0) // 24-bits
m = m * (2 - m*d0) // 48-bits
m = m * (2 - m*d0) // 96-bits >= 64-bits
m = m & mask
a := ((mask >> 1) / d0) & -(1 << uint(k))
max := (2 * a) >> uint(k)
return sdivisibleData{
k: int64(k),
m: m,
a: a,
max: max,
}
}
...@@ -184,7 +184,7 @@ func TestMagicSigned(t *testing.T) { ...@@ -184,7 +184,7 @@ func TestMagicSigned(t *testing.T) {
-c - 1, -c, -c + 1, c - 1, c, c + 1, -c - 1, -c, -c + 1, c - 1, c, c + 1,
-2*c - 1, -2 * c, -2*c + 1, 2*c - 1, 2 * c, 2*c + 1, -2*c - 1, -2 * c, -2*c + 1, 2*c - 1, 2 * c, 2*c + 1,
-mul - 1, -mul, -mul + 1, mul - 1, mul, mul + 1, -mul - 1, -mul, -mul + 1, mul - 1, mul, mul + 1,
int64(1)<<n - 1, -int64(1)<<n + 1, int64(1)<<(n-1) - 1, -int64(1) << (n - 1),
} { } {
X := new(big.Int).SetInt64(x) X := new(big.Int).SetInt64(x)
if X.Cmp(Min) < 0 || X.Cmp(Max) > 0 { if X.Cmp(Min) < 0 || X.Cmp(Max) > 0 {
...@@ -303,3 +303,108 @@ func TestDivisibleUnsigned(t *testing.T) { ...@@ -303,3 +303,108 @@ func TestDivisibleUnsigned(t *testing.T) {
} }
} }
} }
func testDivisibleExhaustive(t *testing.T, n uint) {
minI := -int64(1) << (n - 1)
maxI := int64(1) << (n - 1)
for c := int64(1); c < maxI; c++ {
if !sdivisibleOK(n, int64(c)) {
continue
}
k := sdivisible(n, int64(c)).k
m := sdivisible(n, int64(c)).m
a := sdivisible(n, int64(c)).a
max := sdivisible(n, int64(c)).max
mask := ^uint64(0) >> (64 - n)
for i := minI; i < maxI; i++ {
want := i%c == 0
mul := (uint64(i)*m + a) & mask
rot := (mul>>uint(k) | mul<<(n-uint(k))) & mask
got := rot <= max
if want != got {
t.Errorf("signed divisible wrong for %d %% %d == 0: got %v, want %v (k=%d,m=%d,a=%d,max=%d)\n", i, c, got, want, k, m, a, max)
}
}
}
}
func TestDivisibleExhaustive8(t *testing.T) {
testDivisibleExhaustive(t, 8)
}
func TestDivisibleExhaustive16(t *testing.T) {
if testing.Short() {
t.Skip("slow test; skipping")
}
testDivisibleExhaustive(t, 16)
}
func TestDivisibleSigned(t *testing.T) {
One := new(big.Int).SetInt64(1)
for _, n := range [...]uint{8, 16, 32, 64} {
TwoNMinusOne := new(big.Int).Lsh(One, n-1)
Max := new(big.Int).Sub(TwoNMinusOne, One)
Min := new(big.Int).Neg(TwoNMinusOne)
for _, c := range [...]int64{
3,
5,
6,
7,
9,
10,
11,
12,
13,
14,
15,
17,
1<<7 - 1,
1<<7 + 1,
1<<15 - 1,
1<<15 + 1,
1<<31 - 1,
1<<31 + 1,
1<<63 - 1,
} {
if c>>(n-1) != 0 {
continue // not appropriate for the given n.
}
if !sdivisibleOK(n, int64(c)) {
t.Errorf("expected n=%d c=%d to pass\n", n, c)
}
k := sdivisible(n, int64(c)).k
m := sdivisible(n, int64(c)).m
a := sdivisible(n, int64(c)).a
max := sdivisible(n, int64(c)).max
mask := ^uint64(0) >> (64 - n)
C := new(big.Int).SetInt64(c)
// Find largest multiple of c.
Mul := new(big.Int).Div(Max, C)
Mul.Mul(Mul, C)
mul := Mul.Int64()
// Try some input values, mostly around multiples of c.
for _, x := range [...]int64{
-1, 1,
-c - 1, -c, -c + 1, c - 1, c, c + 1,
-2*c - 1, -2 * c, -2*c + 1, 2*c - 1, 2 * c, 2*c + 1,
-mul - 1, -mul, -mul + 1, mul - 1, mul, mul + 1,
int64(1)<<(n-1) - 1, -int64(1) << (n - 1),
} {
X := new(big.Int).SetInt64(x)
if X.Cmp(Min) < 0 || X.Cmp(Max) > 0 {
continue
}
want := x%c == 0
mul := (uint64(x)*m + a) & mask
rot := (mul>>uint(k) | mul<<(n-uint(k))) & mask
got := rot <= max
if want != got {
t.Errorf("signed divisible wrong for %d %% %d == 0: got %v, want %v (k=%d,m=%d,a=%d,max=%d)\n", x, c, got, want, k, m, a, max)
}
}
}
}
}
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -216,14 +216,14 @@ func ConstMods(n1 uint, n2 int) (uint, int) { ...@@ -216,14 +216,14 @@ func ConstMods(n1 uint, n2 int) (uint, int) {
} }
// Check that divisibility checks x%c==0 are converted to MULs and rotates // Check that divisibility checks x%c==0 are converted to MULs and rotates
func Divisible(n uint) (even, odd bool) { func Divisible(n1 uint, n2 int) (bool, bool, bool, bool) {
// amd64:"MOVQ\t[$]-6148914691236517205","IMULQ","ROLQ\t[$]63",-"DIVQ" // amd64:"MOVQ\t[$]-6148914691236517205","IMULQ","ROLQ\t[$]63",-"DIVQ"
// 386:"IMUL3L\t[$]-1431655765","ROLL\t[$]31",-"DIVQ" // 386:"IMUL3L\t[$]-1431655765","ROLL\t[$]31",-"DIVQ"
// arm64:"MOVD\t[$]-6148914691236517205","MUL","ROR",-"DIV" // arm64:"MOVD\t[$]-6148914691236517205","MUL","ROR",-"DIV"
// arm:"MUL","CMP\t[$]715827882",-".*udiv" // arm:"MUL","CMP\t[$]715827882",-".*udiv"
// ppc64:"MULLD","ROTL\t[$]63" // ppc64:"MULLD","ROTL\t[$]63"
// ppc64le:"MULLD","ROTL\t[$]63" // ppc64le:"MULLD","ROTL\t[$]63"
even = n%6 == 0 evenU := n1%6 == 0
// amd64:"MOVQ\t[$]-8737931403336103397","IMULQ",-"ROLQ",-"DIVQ" // amd64:"MOVQ\t[$]-8737931403336103397","IMULQ",-"ROLQ",-"DIVQ"
// 386:"IMUL3L\t[$]678152731",-"ROLL",-"DIVQ" // 386:"IMUL3L\t[$]678152731",-"ROLL",-"DIVQ"
...@@ -231,8 +231,25 @@ func Divisible(n uint) (even, odd bool) { ...@@ -231,8 +231,25 @@ func Divisible(n uint) (even, odd bool) {
// arm:"MUL","CMP\t[$]226050910",-".*udiv" // arm:"MUL","CMP\t[$]226050910",-".*udiv"
// ppc64:"MULLD",-"ROTL" // ppc64:"MULLD",-"ROTL"
// ppc64le:"MULLD",-"ROTL" // ppc64le:"MULLD",-"ROTL"
odd = n%19 == 0 oddU := n1%19 == 0
return
// amd64:"IMULQ","ADD","ROLQ\t[$]63",-"DIVQ"
// 386:"IMUL3L\t[$]-1431655765","ADDL\t[$]715827882","ROLL\t[$]31",-"DIVQ"
// arm64:"MUL","ADD\t[$]3074457345618258602","ROR",-"DIV"
// arm:"MUL","ADD\t[$]715827882",-".*udiv"
// ppc64:"MULLD","ADD","ROTL\t[$]63"
// ppc64le:"MULLD","ADD","ROTL\t[$]63"
evenS := n2%6 == 0
// amd64:"IMULQ","ADD",-"ROLQ",-"DIVQ"
// 386:"IMUL3L\t[$]678152731","ADDL\t[$]113025455",-"ROLL",-"DIVQ"
// arm64:"MUL","ADD\t[$]485440633518672410",-"ROR",-"DIV"
// arm:"MUL","ADD\t[$]113025455",-".*udiv"
// ppc64:"MULLD","ADD",-"ROTL"
// ppc64le:"MULLD","ADD",-"ROTL"
oddS := n2%19 == 0
return evenU, oddU, evenS, oddS
} }
// Check that fix-up code is not generated for divisions where it has been proven that // Check that fix-up code is not generated for divisions where it has been proven that
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment