Commit 93a0b0f3 authored by Todd Neal's avatar Todd Neal

[dev.ssa] cmd/compile: rewrites for constant shifts

Add rewrite rules to optimize constant shifts.

Fixes #10637

Change-Id: I74b724d3e81aeb7098c696d02c050f7fdfd5b523
Reviewed-on: https://go-review.googlesource.com/19106Reviewed-by: default avatarKeith Randall <khr@golang.org>
Run-TryBot: Todd Neal <todd@tneal.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 606a11f4
......@@ -8,6 +8,102 @@
package main
import "fmt"
// testArithRshConst ensures that "const >> const" right shifts correctly perform
// sign extension on the lhs constant
func testArithRshConst() {
wantu := uint64(0x4000000000000000)
if got := arithRshuConst_ssa(); got != wantu {
println("arithRshuConst failed, wanted", wantu, "got", got)
failed = true
}
wants := int64(-0x4000000000000000)
if got := arithRshConst_ssa(); got != wants {
println("arithRshuConst failed, wanted", wants, "got", got)
failed = true
}
}
//go:noinline
func arithRshuConst_ssa() uint64 {
y := uint64(0x8000000000000001)
z := uint64(1)
return uint64(y >> z)
}
//go:noinline
func arithRshConst_ssa() int64 {
y := int64(-0x8000000000000000)
z := uint64(1)
return int64(y >> z)
}
//go:noinline
func arithConstShift_ssa(x int64) int64 {
return x >> 100
}
// testArithConstShift tests that right shift by large constants preserve
// the sign of the input.
func testArithConstShift() {
want := int64(-1)
if got := arithConstShift_ssa(-1); want != got {
println("arithConstShift_ssa(-1) failed, wanted", want, "got", got)
failed = true
}
want = 0
if got := arithConstShift_ssa(1); want != got {
println("arithConstShift_ssa(1) failed, wanted", want, "got", got)
failed = true
}
}
// overflowConstShift_ssa verifes that constant folding for shift
// doesn't wrap (i.e. x << MAX_INT << 1 doesn't get folded to x << 0).
//go:noinline
func overflowConstShift64_ssa(x int64) int64 {
return x << uint64(0xffffffffffffffff) << uint64(1)
}
//go:noinline
func overflowConstShift32_ssa(x int64) int32 {
return int32(x) << uint32(0xffffffff) << uint32(1)
}
//go:noinline
func overflowConstShift16_ssa(x int64) int16 {
return int16(x) << uint16(0xffff) << uint16(1)
}
//go:noinline
func overflowConstShift8_ssa(x int64) int8 {
return int8(x) << uint8(0xff) << uint8(1)
}
func testOverflowConstShift() {
want := int64(0)
for x := int64(-127); x < int64(127); x++ {
got := overflowConstShift64_ssa(x)
if want != got {
fmt.Printf("overflowShift64 failed, wanted %d got %d\n", want, got)
}
got = int64(overflowConstShift32_ssa(x))
if want != got {
fmt.Printf("overflowShift32 failed, wanted %d got %d\n", want, got)
}
got = int64(overflowConstShift16_ssa(x))
if want != got {
fmt.Printf("overflowShift16 failed, wanted %d got %d\n", want, got)
}
got = int64(overflowConstShift8_ssa(x))
if want != got {
fmt.Printf("overflowShift8 failed, wanted %d got %d\n", want, got)
}
}
}
// test64BitConstMult tests that rewrite rules don't fold 64 bit constants
// into multiply instructions.
func test64BitConstMult() {
......@@ -275,6 +371,9 @@ func main() {
testLrot()
testShiftCX()
testSubConst()
testOverflowConstShift()
testArithConstShift()
testArithRshConst()
if failed {
panic("failed")
......
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This program generates a test to verify that the standard arithmetic
// operators properly handle const cases. The test file should be
// generated with a known working version of go.
// launch with `go run arithConstGen.go` a file called arithConst_ssa.go
// will be written into the parent directory containing the tests
package main
import (
"bytes"
"fmt"
"go/format"
"io/ioutil"
"log"
"strings"
"text/template"
)
type op struct {
name, symbol string
}
type szD struct {
name string
sn string
u []uint64
i []int64
}
var szs []szD = []szD{
szD{name: "uint64", sn: "64", u: []uint64{0, 1, 4294967296, 0xffffFFFFffffFFFF}},
szD{name: "int64", sn: "64", i: []int64{-0x8000000000000000, -0x7FFFFFFFFFFFFFFF,
-4294967296, -1, 0, 1, 4294967296, 0x7FFFFFFFFFFFFFFE, 0x7FFFFFFFFFFFFFFF}},
szD{name: "uint32", sn: "32", u: []uint64{0, 1, 4294967295}},
szD{name: "int32", sn: "32", i: []int64{-0x80000000, -0x7FFFFFFF, -1, 0,
1, 0x7FFFFFFF}},
szD{name: "uint16", sn: "16", u: []uint64{0, 1, 65535}},
szD{name: "int16", sn: "16", i: []int64{-32768, -32767, -1, 0, 1, 32766, 32767}},
szD{name: "uint8", sn: "8", u: []uint64{0, 1, 255}},
szD{name: "int8", sn: "8", i: []int64{-128, -127, -1, 0, 1, 126, 127}},
}
var ops []op = []op{op{"add", "+"}, op{"sub", "-"}, op{"div", "/"}, op{"mul", "*"},
op{"lsh", "<<"}, op{"rsh", ">>"}}
// compute the result of i op j, cast as type t.
func ansU(i, j uint64, t, op string) string {
var ans uint64
switch op {
case "+":
ans = i + j
case "-":
ans = i - j
case "*":
ans = i * j
case "/":
if j != 0 {
ans = i / j
}
case "<<":
ans = i << j
case ">>":
ans = i >> j
}
switch t {
case "uint32":
ans = uint64(uint32(ans))
case "uint16":
ans = uint64(uint16(ans))
case "uint8":
ans = uint64(uint8(ans))
}
return fmt.Sprintf("%d", ans)
}
// compute the result of i op j, cast as type t.
func ansS(i, j int64, t, op string) string {
var ans int64
switch op {
case "+":
ans = i + j
case "-":
ans = i - j
case "*":
ans = i * j
case "/":
if j != 0 {
ans = i / j
}
case "<<":
ans = i << uint64(j)
case ">>":
ans = i >> uint64(j)
}
switch t {
case "int32":
ans = int64(int32(ans))
case "int16":
ans = int64(int16(ans))
case "int8":
ans = int64(int8(ans))
}
return fmt.Sprintf("%d", ans)
}
func main() {
w := new(bytes.Buffer)
fmt.Fprintf(w, "package main;\n")
fmt.Fprintf(w, "import \"fmt\"\n")
fncCnst1, err := template.New("fnc").Parse(
`//go:noinline
func {{.Name}}_{{.Type_}}_{{.FNumber}}_ssa(a {{.Type_}}) {{.Type_}} {
return a {{.Symbol}} {{.Number}}
}
`)
if err != nil {
panic(err)
}
fncCnst2, err := template.New("fnc").Parse(
`//go:noinline
func {{.Name}}_{{.FNumber}}_{{.Type_}}_ssa(a {{.Type_}}) {{.Type_}} {
return {{.Number}} {{.Symbol}} a
}
`)
if err != nil {
panic(err)
}
type fncData struct {
Name, Type_, Symbol, FNumber, Number string
}
for _, s := range szs {
for _, o := range ops {
fd := fncData{o.name, s.name, o.symbol, "", ""}
// unsigned test cases
if len(s.u) > 0 {
for _, i := range s.u {
fd.Number = fmt.Sprintf("%d", i)
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
// avoid division by zero
if o.name != "div" || i != 0 {
fncCnst1.Execute(w, fd)
}
fncCnst2.Execute(w, fd)
}
}
// signed test cases
if len(s.i) > 0 {
// don't generate tests for shifts by signed integers
if o.name == "lsh" || o.name == "rsh" {
continue
}
for _, i := range s.i {
fd.Number = fmt.Sprintf("%d", i)
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
// avoid division by zero
if o.name != "div" || i != 0 {
fncCnst1.Execute(w, fd)
}
fncCnst2.Execute(w, fd)
}
}
}
}
fmt.Fprintf(w, "var failed bool\n\n")
fmt.Fprintf(w, "func main() {\n\n")
vrf1, _ := template.New("vrf1").Parse(`
if got := {{.Name}}_{{.FNumber}}_{{.Type_}}_ssa({{.Input}}); got != {{.Ans}} {
fmt.Printf("{{.Name}}_{{.Type_}} {{.Number}}{{.Symbol}}{{.Input}} = %d, wanted {{.Ans}}\n",got)
failed = true
}
`)
vrf2, _ := template.New("vrf2").Parse(`
if got := {{.Name}}_{{.Type_}}_{{.FNumber}}_ssa({{.Input}}); got != {{.Ans}} {
fmt.Printf("{{.Name}}_{{.Type_}} {{.Input}}{{.Symbol}}{{.Number}} = %d, wanted {{.Ans}}\n",got)
failed = true
}
`)
type cfncData struct {
Name, Type_, Symbol, FNumber, Number string
Ans, Input string
}
for _, s := range szs {
if len(s.u) > 0 {
for _, o := range ops {
fd := cfncData{o.name, s.name, o.symbol, "", "", "", ""}
for _, i := range s.u {
fd.Number = fmt.Sprintf("%d", i)
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
// unsigned
for _, j := range s.u {
if o.name != "div" || j != 0 {
fd.Ans = ansU(i, j, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j)
err = vrf1.Execute(w, fd)
if err != nil {
panic(err)
}
}
if o.name != "div" || i != 0 {
fd.Ans = ansU(j, i, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j)
err = vrf2.Execute(w, fd)
if err != nil {
panic(err)
}
}
}
}
}
}
// signed
if len(s.i) > 0 {
for _, o := range ops {
// don't generate tests for shifts by signed integers
if o.name == "lsh" || o.name == "rsh" {
continue
}
fd := cfncData{o.name, s.name, o.symbol, "", "", "", ""}
for _, i := range s.i {
fd.Number = fmt.Sprintf("%d", i)
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
for _, j := range s.i {
if o.name != "div" || j != 0 {
fd.Ans = ansS(i, j, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j)
err = vrf1.Execute(w, fd)
if err != nil {
panic(err)
}
}
if o.name != "div" || i != 0 {
fd.Ans = ansS(j, i, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j)
err = vrf2.Execute(w, fd)
if err != nil {
panic(err)
}
}
}
}
}
}
}
fmt.Fprintf(w, `if failed {
panic("tests failed")
}
`)
fmt.Fprintf(w, "}\n")
// gofmt result
b := w.Bytes()
src, err := format.Source(b)
if err != nil {
fmt.Printf("%s\n", b)
panic(err)
}
// write to file
err = ioutil.WriteFile("../arithConst_ssa.go", src, 0666)
if err != nil {
log.Fatalf("can't write output: %v\n", err)
}
}
......@@ -35,6 +35,19 @@
(Mul32 (Const32 [c]) (Const32 [d])) -> (Const32 [c*d])
(Mul64 (Const64 [c]) (Const64 [d])) -> (Const64 [c*d])
(Lsh64x64 (Const64 [c]) (Const64 [d])) -> (Const64 [c << uint64(d)])
(Rsh64x64 (Const64 [c]) (Const64 [d])) -> (Const64 [c >> uint64(d)])
(Rsh64Ux64 (Const64 [c]) (Const64 [d])) -> (Const64 [int64(uint64(c) >> uint64(d))])
(Lsh32x64 (Const32 [c]) (Const64 [d])) -> (Const32 [int64(int32(c) << uint64(d))])
(Rsh32x64 (Const32 [c]) (Const64 [d])) -> (Const32 [int64(int32(c) >> uint64(d))])
(Rsh32Ux64 (Const32 [c]) (Const64 [d])) -> (Const32 [int64(uint32(c) >> uint64(d))])
(Lsh16x64 (Const16 [c]) (Const64 [d])) -> (Const16 [int64(int16(c) << uint64(d))])
(Rsh16x64 (Const16 [c]) (Const64 [d])) -> (Const16 [int64(int16(c) >> uint64(d))])
(Rsh16Ux64 (Const16 [c]) (Const64 [d])) -> (Const16 [int64(uint16(c) >> uint64(d))])
(Lsh8x64 (Const8 [c]) (Const64 [d])) -> (Const8 [int64(int8(c) << uint64(d))])
(Rsh8x64 (Const8 [c]) (Const64 [d])) -> (Const8 [int64(int8(c) >> uint64(d))])
(Rsh8Ux64 (Const8 [c]) (Const64 [d])) -> (Const8 [int64(uint8(c) >> uint64(d))])
(IsInBounds (Const32 [c]) (Const32 [d])) -> (ConstBool [b2i(inBounds32(c,d))])
(IsInBounds (Const64 [c]) (Const64 [d])) -> (ConstBool [b2i(inBounds64(c,d))])
(IsSliceInBounds (Const32 [c]) (Const32 [d])) -> (ConstBool [b2i(sliceInBounds32(c,d))])
......@@ -79,6 +92,89 @@
(Sub16 x (Const16 <t> [c])) && x.Op != OpConst16 -> (Add16 (Const16 <t> [-c]) x)
(Sub8 x (Const8 <t> [c])) && x.Op != OpConst8 -> (Add8 (Const8 <t> [-c]) x)
// rewrite shifts of 8/16/32 bit consts into 64 bit consts to reduce
// the number of the other rewrite rules for const shifts
(Lsh64x32 <t> x (Const32 [c])) -> (Lsh64x64 x (Const64 <t> [int64(uint32(c))]))
(Lsh64x16 <t> x (Const16 [c])) -> (Lsh64x64 x (Const64 <t> [int64(uint16(c))]))
(Lsh64x8 <t> x (Const8 [c])) -> (Lsh64x64 x (Const64 <t> [int64(uint8(c))]))
(Rsh64x32 <t> x (Const32 [c])) -> (Rsh64x64 x (Const64 <t> [int64(uint32(c))]))
(Rsh64x16 <t> x (Const16 [c])) -> (Rsh64x64 x (Const64 <t> [int64(uint16(c))]))
(Rsh64x8 <t> x (Const8 [c])) -> (Rsh64x64 x (Const64 <t> [int64(uint8(c))]))
(Rsh64Ux32 <t> x (Const32 [c])) -> (Rsh64Ux64 x (Const64 <t> [int64(uint32(c))]))
(Rsh64Ux16 <t> x (Const16 [c])) -> (Rsh64Ux64 x (Const64 <t> [int64(uint16(c))]))
(Rsh64Ux8 <t> x (Const8 [c])) -> (Rsh64Ux64 x (Const64 <t> [int64(uint8(c))]))
(Lsh32x32 <t> x (Const32 [c])) -> (Lsh32x64 x (Const64 <t> [int64(uint32(c))]))
(Lsh32x16 <t> x (Const16 [c])) -> (Lsh32x64 x (Const64 <t> [int64(uint16(c))]))
(Lsh32x8 <t> x (Const8 [c])) -> (Lsh32x64 x (Const64 <t> [int64(uint8(c))]))
(Rsh32x32 <t> x (Const32 [c])) -> (Rsh32x64 x (Const64 <t> [int64(uint32(c))]))
(Rsh32x16 <t> x (Const16 [c])) -> (Rsh32x64 x (Const64 <t> [int64(uint16(c))]))
(Rsh32x8 <t> x (Const8 [c])) -> (Rsh32x64 x (Const64 <t> [int64(uint8(c))]))
(Rsh32Ux32 <t> x (Const32 [c])) -> (Rsh32Ux64 x (Const64 <t> [int64(uint32(c))]))
(Rsh32Ux16 <t> x (Const16 [c])) -> (Rsh32Ux64 x (Const64 <t> [int64(uint16(c))]))
(Rsh32Ux8 <t> x (Const8 [c])) -> (Rsh32Ux64 x (Const64 <t> [int64(uint8(c))]))
(Lsh16x32 <t> x (Const32 [c])) -> (Lsh16x64 x (Const64 <t> [int64(uint32(c))]))
(Lsh16x16 <t> x (Const16 [c])) -> (Lsh16x64 x (Const64 <t> [int64(uint16(c))]))
(Lsh16x8 <t> x (Const8 [c])) -> (Lsh16x64 x (Const64 <t> [int64(uint8(c))]))
(Rsh16x32 <t> x (Const32 [c])) -> (Rsh16x64 x (Const64 <t> [int64(uint32(c))]))
(Rsh16x16 <t> x (Const16 [c])) -> (Rsh16x64 x (Const64 <t> [int64(uint16(c))]))
(Rsh16x8 <t> x (Const8 [c])) -> (Rsh16x64 x (Const64 <t> [int64(uint8(c))]))
(Rsh16Ux32 <t> x (Const32 [c])) -> (Rsh16Ux64 x (Const64 <t> [int64(uint32(c))]))
(Rsh16Ux16 <t> x (Const16 [c])) -> (Rsh16Ux64 x (Const64 <t> [int64(uint16(c))]))
(Rsh16Ux8 <t> x (Const8 [c])) -> (Rsh16Ux64 x (Const64 <t> [int64(uint8(c))]))
(Lsh8x32 <t> x (Const32 [c])) -> (Lsh8x64 x (Const64 <t> [int64(uint32(c))]))
(Lsh8x16 <t> x (Const16 [c])) -> (Lsh8x64 x (Const64 <t> [int64(uint16(c))]))
(Lsh8x8 <t> x (Const8 [c])) -> (Lsh8x64 x (Const64 <t> [int64(uint8(c))]))
(Rsh8x32 <t> x (Const32 [c])) -> (Rsh8x64 x (Const64 <t> [int64(uint32(c))]))
(Rsh8x16 <t> x (Const16 [c])) -> (Rsh8x64 x (Const64 <t> [int64(uint16(c))]))
(Rsh8x8 <t> x (Const8 [c])) -> (Rsh8x64 x (Const64 <t> [int64(uint8(c))]))
(Rsh8Ux32 <t> x (Const32 [c])) -> (Rsh8Ux64 x (Const64 <t> [int64(uint32(c))]))
(Rsh8Ux16 <t> x (Const16 [c])) -> (Rsh8Ux64 x (Const64 <t> [int64(uint16(c))]))
(Rsh8Ux8 <t> x (Const8 [c])) -> (Rsh8Ux64 x (Const64 <t> [int64(uint8(c))]))
// shifts by zero
(Lsh64x64 x (Const64 [0])) -> x
(Rsh64x64 x (Const64 [0])) -> x
(Rsh64Ux64 x (Const64 [0])) -> x
(Lsh32x64 x (Const64 [0])) -> x
(Rsh32x64 x (Const64 [0])) -> x
(Rsh32Ux64 x (Const64 [0])) -> x
(Lsh16x64 x (Const64 [0])) -> x
(Rsh16x64 x (Const64 [0])) -> x
(Rsh16Ux64 x (Const64 [0])) -> x
(Lsh8x64 x (Const64 [0])) -> x
(Rsh8x64 x (Const64 [0])) -> x
(Rsh8Ux64 x (Const64 [0])) -> x
// large left shifts of all values, and right shifts of unsigned values
(Lsh64x64 _ (Const64 [c])) && uint64(c) >= 64 -> (Const64 [0])
(Rsh64Ux64 _ (Const64 [c])) && uint64(c) >= 64 -> (Const64 [0])
(Lsh32x64 _ (Const64 [c])) && uint64(c) >= 32 -> (Const64 [0])
(Rsh32Ux64 _ (Const64 [c])) && uint64(c) >= 32 -> (Const64 [0])
(Lsh16x64 _ (Const64 [c])) && uint64(c) >= 16 -> (Const64 [0])
(Rsh16Ux64 _ (Const64 [c])) && uint64(c) >= 16 -> (Const64 [0])
(Lsh8x64 _ (Const64 [c])) && uint64(c) >= 8 -> (Const64 [0])
(Rsh8Ux64 _ (Const64 [c])) && uint64(c) >= 8 -> (Const64 [0])
// combine const shifts
(Lsh64x64 <t> (Lsh64x64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Lsh64x64 x (Const64 <t> [c+d]))
(Lsh32x64 <t> (Lsh32x64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Lsh32x64 x (Const64 <t> [c+d]))
(Lsh16x64 <t> (Lsh16x64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Lsh16x64 x (Const64 <t> [c+d]))
(Lsh8x64 <t> (Lsh8x64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Lsh8x64 x (Const64 <t> [c+d]))
(Rsh64x64 <t> (Rsh64x64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Rsh64x64 x (Const64 <t> [c+d]))
(Rsh32x64 <t> (Rsh32x64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Rsh32x64 x (Const64 <t> [c+d]))
(Rsh16x64 <t> (Rsh16x64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Rsh16x64 x (Const64 <t> [c+d]))
(Rsh8x64 <t> (Rsh8x64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Rsh8x64 x (Const64 <t> [c+d]))
(Rsh64Ux64 <t> (Rsh64Ux64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Rsh64Ux64 x (Const64 <t> [c+d]))
(Rsh32Ux64 <t> (Rsh32Ux64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Rsh32Ux64 x (Const64 <t> [c+d]))
(Rsh16Ux64 <t> (Rsh16Ux64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Rsh16Ux64 x (Const64 <t> [c+d]))
(Rsh8Ux64 <t> (Rsh8Ux64 x (Const64 [c])) (Const64 [d])) && !uaddOvf(c,d) -> (Rsh8Ux64 x (Const64 <t> [c+d]))
// constant comparisons
(Eq64 (Const64 [c]) (Const64 [d])) -> (ConstBool [b2i(int64(c) == int64(d))])
(Eq32 (Const32 [c]) (Const32 [d])) -> (ConstBool [b2i(int32(c) == int32(d))])
......
......@@ -181,6 +181,11 @@ func f2i(f float64) int64 {
return int64(math.Float64bits(f))
}
// uaddOvf returns true if unsigned a+b would overflow.
func uaddOvf(a, b int64) bool {
return uint64(a)+uint64(b) < uint64(a)
}
// DUFFZERO consists of repeated blocks of 4 MOVUPSs + ADD,
// See runtime/mkduff.go.
const (
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment