Commit 060501dc authored by Alexandru Moșoi's avatar Alexandru Moșoi Committed by Alexandru Moșoi

cmd/compile: constant fold modulo

Fixes #15079

Change-Id: Ib4dd9eab322da39234008e040100e75cb58761b3
Reviewed-on: https://go-review.googlesource.com/21501Reviewed-by: default avatarDavid Chase <drchase@google.com>
Run-TryBot: Alexandru Moșoi <alexandru@mosoi.ro>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 68325b56
......@@ -29,6 +29,11 @@ func b(i uint, j uint) uint {
return i / j
}
//go:noinline
func c(i int) int {
return 7 / (i - i)
}
func main() {
if got := checkDivByZero(func() { b(7, 0) }); !got {
fmt.Printf("expected div by zero for b(7, 0), got no error\n")
......@@ -42,6 +47,10 @@ func main() {
fmt.Printf("expected div by zero for a(4, nil), got no error\n")
failed = true
}
if got := checkDivByZero(func() { c(5) }); !got {
fmt.Printf("expected div by zero for c(5), got no error\n")
failed = true
}
if failed {
panic("tests failed")
......
......@@ -47,7 +47,7 @@ var szs []szD = []szD{
}
var ops []op = []op{op{"add", "+"}, op{"sub", "-"}, op{"div", "/"}, op{"mul", "*"},
op{"lsh", "<<"}, op{"rsh", ">>"}}
op{"lsh", "<<"}, op{"rsh", ">>"}, op{"mod", "%"}}
// compute the result of i op j, cast as type t.
func ansU(i, j uint64, t, op string) string {
......@@ -63,6 +63,10 @@ func ansU(i, j uint64, t, op string) string {
if j != 0 {
ans = i / j
}
case "%":
if j != 0 {
ans = i % j
}
case "<<":
ans = i << j
case ">>":
......@@ -93,6 +97,10 @@ func ansS(i, j int64, t, op string) string {
if j != 0 {
ans = i / j
}
case "%":
if j != 0 {
ans = i % j
}
case "<<":
ans = i << uint64(j)
case ">>":
......@@ -151,7 +159,7 @@ func main() {
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
// avoid division by zero
if o.name != "div" || i != 0 {
if o.name != "mod" && o.name != "div" || i != 0 {
fncCnst1.Execute(w, fd)
}
......@@ -170,7 +178,7 @@ func main() {
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
// avoid division by zero
if o.name != "div" || i != 0 {
if o.name != "mod" && o.name != "div" || i != 0 {
fncCnst1.Execute(w, fd)
}
fncCnst2.Execute(w, fd)
......@@ -184,14 +192,14 @@ func main() {
vrf1, _ := template.New("vrf1").Parse(`
if got := {{.Name}}_{{.FNumber}}_{{.Type_}}_ssa({{.Input}}); got != {{.Ans}} {
fmt.Printf("{{.Name}}_{{.Type_}} {{.Number}}{{.Symbol}}{{.Input}} = %d, wanted {{.Ans}}\n",got)
fmt.Printf("{{.Name}}_{{.Type_}} {{.Number}}%s{{.Input}} = %d, wanted {{.Ans}}\n", ` + "`{{.Symbol}}`" + `, got)
failed = true
}
`)
vrf2, _ := template.New("vrf2").Parse(`
if got := {{.Name}}_{{.Type_}}_{{.FNumber}}_ssa({{.Input}}); got != {{.Ans}} {
fmt.Printf("{{.Name}}_{{.Type_}} {{.Input}}{{.Symbol}}{{.Number}} = %d, wanted {{.Ans}}\n",got)
fmt.Printf("{{.Name}}_{{.Type_}} {{.Input}}%s{{.Number}} = %d, wanted {{.Ans}}\n", ` + "`{{.Symbol}}`" + `, got)
failed = true
}
`)
......@@ -211,7 +219,7 @@ func main() {
// unsigned
for _, j := range s.u {
if o.name != "div" || j != 0 {
if o.name != "mod" && o.name != "div" || j != 0 {
fd.Ans = ansU(i, j, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j)
err = vrf1.Execute(w, fd)
......@@ -220,7 +228,7 @@ func main() {
}
}
if o.name != "div" || i != 0 {
if o.name != "mod" && o.name != "div" || i != 0 {
fd.Ans = ansU(j, i, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j)
err = vrf2.Execute(w, fd)
......@@ -247,7 +255,7 @@ func main() {
fd.Number = fmt.Sprintf("%d", i)
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
for _, j := range s.i {
if o.name != "div" || j != 0 {
if o.name != "mod" && o.name != "div" || j != 0 {
fd.Ans = ansS(i, j, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j)
err = vrf1.Execute(w, fd)
......@@ -256,7 +264,7 @@ func main() {
}
}
if o.name != "div" || i != 0 {
if o.name != "mod" && o.name != "div" || i != 0 {
fd.Ans = ansS(j, i, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j)
err = vrf2.Execute(w, fd)
......
......@@ -66,6 +66,16 @@
(Const32F [f2i(float64(i2f32(c) * i2f32(d)))])
(Mul64F (Const64F [c]) (Const64F [d])) -> (Const64F [f2i(i2f(c) * i2f(d))])
(Mod8 (Const8 [c]) (Const8 [d])) && d != 0-> (Const8 [int64(int8(c % d))])
(Mod16 (Const16 [c]) (Const16 [d])) && d != 0-> (Const16 [int64(int16(c % d))])
(Mod32 (Const32 [c]) (Const32 [d])) && d != 0-> (Const32 [int64(int32(c % d))])
(Mod64 (Const64 [c]) (Const64 [d])) && d != 0-> (Const64 [c % d])
(Mod8u (Const8 [c]) (Const8 [d])) && d != 0-> (Const8 [int64(uint8(c) % uint8(d))])
(Mod16u (Const16 [c]) (Const16 [d])) && d != 0-> (Const16 [int64(uint16(c) % uint16(d))])
(Mod32u (Const32 [c]) (Const32 [d])) && d != 0-> (Const32 [int64(uint32(c) % uint32(d))])
(Mod64u (Const64 [c]) (Const64 [d])) && d != 0-> (Const64 [int64(uint64(c) % uint64(d))])
(Lsh64x64 (Const64 [c]) (Const64 [d])) -> (Const64 [c << uint64(d)])
(Rsh64x64 (Const64 [c]) (Const64 [d])) -> (Const64 [c >> uint64(d)])
(Rsh64Ux64 (Const64 [c]) (Const64 [d])) -> (Const64 [int64(uint64(c) >> uint64(d))])
......@@ -728,5 +738,5 @@
// A%B = A-(A/B*B).
// This implements % with two * and a bunch of ancillary ops.
// One of the * is free if the user's code also computes A/B.
(Mod64 <t> x (Const64 [c])) && smagic64ok(c) -> (Sub64 x (Mul64 <t> (Div64 <t> x (Const64 <t> [c])) (Const64 <t> [c])))
(Mod64u <t> x (Const64 [c])) && umagic64ok(c) -> (Sub64 x (Mul64 <t> (Div64u <t> x (Const64 <t> [c])) (Const64 <t> [c])))
(Mod64 <t> x (Const64 [c])) && x.Op != OpConst64 && smagic64ok(c) -> (Sub64 x (Mul64 <t> (Div64 <t> x (Const64 <t> [c])) (Const64 <t> [c])))
(Mod64u <t> x (Const64 [c])) && x.Op != OpConst64 && umagic64ok(c) -> (Sub64 x (Mul64 <t> (Div64u <t> x (Const64 <t> [c])) (Const64 <t> [c])))
......@@ -174,10 +174,22 @@ func rewriteValuegeneric(v *Value, config *Config) bool {
return rewriteValuegeneric_OpLsh8x64(v, config)
case OpLsh8x8:
return rewriteValuegeneric_OpLsh8x8(v, config)
case OpMod16:
return rewriteValuegeneric_OpMod16(v, config)
case OpMod16u:
return rewriteValuegeneric_OpMod16u(v, config)
case OpMod32:
return rewriteValuegeneric_OpMod32(v, config)
case OpMod32u:
return rewriteValuegeneric_OpMod32u(v, config)
case OpMod64:
return rewriteValuegeneric_OpMod64(v, config)
case OpMod64u:
return rewriteValuegeneric_OpMod64u(v, config)
case OpMod8:
return rewriteValuegeneric_OpMod8(v, config)
case OpMod8u:
return rewriteValuegeneric_OpMod8u(v, config)
case OpMul16:
return rewriteValuegeneric_OpMul16(v, config)
case OpMul32:
......@@ -4409,11 +4421,136 @@ func rewriteValuegeneric_OpLsh8x8(v *Value, config *Config) bool {
}
return false
}
func rewriteValuegeneric_OpMod16(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod16 (Const16 [c]) (Const16 [d]))
// cond: d != 0
// result: (Const16 [int64(int16(c % d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst16 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst16 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst16)
v.AuxInt = int64(int16(c % d))
return true
}
return false
}
func rewriteValuegeneric_OpMod16u(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod16u (Const16 [c]) (Const16 [d]))
// cond: d != 0
// result: (Const16 [int64(uint16(c) % uint16(d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst16 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst16 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst16)
v.AuxInt = int64(uint16(c) % uint16(d))
return true
}
return false
}
func rewriteValuegeneric_OpMod32(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod32 (Const32 [c]) (Const32 [d]))
// cond: d != 0
// result: (Const32 [int64(int32(c % d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst32 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst32 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst32)
v.AuxInt = int64(int32(c % d))
return true
}
return false
}
func rewriteValuegeneric_OpMod32u(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod32u (Const32 [c]) (Const32 [d]))
// cond: d != 0
// result: (Const32 [int64(uint32(c) % uint32(d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst32 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst32 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst32)
v.AuxInt = int64(uint32(c) % uint32(d))
return true
}
return false
}
func rewriteValuegeneric_OpMod64(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod64 (Const64 [c]) (Const64 [d]))
// cond: d != 0
// result: (Const64 [c % d])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst64 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst64 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst64)
v.AuxInt = c % d
return true
}
// match: (Mod64 <t> x (Const64 [c]))
// cond: smagic64ok(c)
// cond: x.Op != OpConst64 && smagic64ok(c)
// result: (Sub64 x (Mul64 <t> (Div64 <t> x (Const64 <t> [c])) (Const64 <t> [c])))
for {
t := v.Type
......@@ -4423,7 +4560,7 @@ func rewriteValuegeneric_OpMod64(v *Value, config *Config) bool {
break
}
c := v_1.AuxInt
if !(smagic64ok(c)) {
if !(x.Op != OpConst64 && smagic64ok(c)) {
break
}
v.reset(OpSub64)
......@@ -4446,6 +4583,27 @@ func rewriteValuegeneric_OpMod64(v *Value, config *Config) bool {
func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod64u (Const64 [c]) (Const64 [d]))
// cond: d != 0
// result: (Const64 [int64(uint64(c) % uint64(d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst64 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst64 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst64)
v.AuxInt = int64(uint64(c) % uint64(d))
return true
}
// match: (Mod64u <t> n (Const64 [c]))
// cond: isPowerOfTwo(c)
// result: (And64 n (Const64 <t> [c-1]))
......@@ -4468,7 +4626,7 @@ func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool {
return true
}
// match: (Mod64u <t> x (Const64 [c]))
// cond: umagic64ok(c)
// cond: x.Op != OpConst64 && umagic64ok(c)
// result: (Sub64 x (Mul64 <t> (Div64u <t> x (Const64 <t> [c])) (Const64 <t> [c])))
for {
t := v.Type
......@@ -4478,7 +4636,7 @@ func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool {
break
}
c := v_1.AuxInt
if !(umagic64ok(c)) {
if !(x.Op != OpConst64 && umagic64ok(c)) {
break
}
v.reset(OpSub64)
......@@ -4498,6 +4656,58 @@ func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool {
}
return false
}
func rewriteValuegeneric_OpMod8(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod8 (Const8 [c]) (Const8 [d]))
// cond: d != 0
// result: (Const8 [int64(int8(c % d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst8 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst8 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst8)
v.AuxInt = int64(int8(c % d))
return true
}
return false
}
func rewriteValuegeneric_OpMod8u(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod8u (Const8 [c]) (Const8 [d]))
// cond: d != 0
// result: (Const8 [int64(uint8(c) % uint8(d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst8 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst8 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst8)
v.AuxInt = int64(uint8(c) % uint8(d))
return true
}
return false
}
func rewriteValuegeneric_OpMul16(v *Value, config *Config) bool {
b := v.Block
_ = b
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment