Commit 47c9e139 authored by Keith Randall's avatar Keith Randall

cmd/compile: extend prove pass to handle constant comparisons

Find comparisons to constants and propagate that information
down the dominator tree.  Use it to resolve other constant
comparisons on the same variable.

So if we know x >= 7, then a x > 4 condition must return true.

This change allows us to use "_ = b[7]" hints to eliminate bounds checks.

Fixes #14900

Change-Id: Idbf230bd5b7da43de3ecb48706e21cf01bf812f7
Reviewed-on: https://go-review.googlesource.com/21008Reviewed-by: default avatarAlexandru Moșoi <alexandru@mosoi.ro>
parent f5bd3556
......@@ -4,6 +4,8 @@
package ssa
import "math"
type branch int
const (
......@@ -66,30 +68,109 @@ type fact struct {
r relation
}
// a limit records known upper and lower bounds for a value.
type limit struct {
min, max int64 // min <= value <= max, signed
umin, umax uint64 // umin <= value <= umax, unsigned
}
var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64}
// a limitFact is a limit known for a particular value.
type limitFact struct {
vid ID
limit limit
}
// factsTable keeps track of relations between pairs of values.
type factsTable struct {
facts map[pair]relation // current known set of relation
stack []fact // previous sets of relations
// known lower and upper bounds on individual values.
limits map[ID]limit
limitStack []limitFact // previous entries
}
// checkpointFact is an invalid value used for checkpointing
// and restoring factsTable.
var checkpointFact = fact{}
var checkpointBound = limitFact{}
func newFactsTable() *factsTable {
ft := &factsTable{}
ft.facts = make(map[pair]relation)
ft.stack = make([]fact, 4)
ft.limits = make(map[ID]limit)
ft.limitStack = make([]limitFact, 4)
return ft
}
// get returns the known possible relations between v and w.
// If v and w are not in the map it returns lt|eq|gt, i.e. any order.
func (ft *factsTable) get(v, w *Value, d domain) relation {
if v.isGenericIntConst() || w.isGenericIntConst() {
reversed := false
if v.isGenericIntConst() {
v, w = w, v
reversed = true
}
r := lt | eq | gt
lim, ok := ft.limits[v.ID]
if !ok {
return r
}
c := w.AuxInt
switch d {
case signed:
switch {
case c < lim.min:
r = gt
case c > lim.max:
r = lt
case c == lim.min && c == lim.max:
r = eq
case c == lim.min:
r = gt | eq
case c == lim.max:
r = lt | eq
}
case unsigned:
// TODO: also use signed data if lim.min >= 0?
var uc uint64
switch w.Op {
case OpConst64:
uc = uint64(c)
case OpConst32:
uc = uint64(uint32(c))
case OpConst16:
uc = uint64(uint16(c))
case OpConst8:
uc = uint64(uint8(c))
}
switch {
case uc < lim.umin:
r = gt
case uc > lim.umax:
r = lt
case uc == lim.umin && uc == lim.umax:
r = eq
case uc == lim.umin:
r = gt | eq
case uc == lim.umax:
r = lt | eq
}
}
if reversed {
return reverseBits[r]
}
return r
}
reversed := false
if lessByID(w, v) {
v, w = w, v
reversed = true
reversed = !reversed
}
p := pair{v, w, d}
......@@ -120,12 +201,106 @@ func (ft *factsTable) update(v, w *Value, d domain, r relation) {
oldR := ft.get(v, w, d)
ft.stack = append(ft.stack, fact{p, oldR})
ft.facts[p] = oldR & r
// Extract bounds when comparing against constants
if v.isGenericIntConst() {
v, w = w, v
r = reverseBits[r]
}
if v != nil && w.isGenericIntConst() {
c := w.AuxInt
// Note: all the +1/-1 below could overflow/underflow. Either will
// still generate correct results, it will just lead to imprecision.
// In fact if there is overflow/underflow, the corresponding
// code is unreachable because the known range is outside the range
// of the value's type.
old, ok := ft.limits[v.ID]
if !ok {
old = noLimit
}
lim := old
// Update lim with the new information we know.
switch d {
case signed:
switch r {
case lt:
if c-1 < lim.max {
lim.max = c - 1
}
case lt | eq:
if c < lim.max {
lim.max = c
}
case gt | eq:
if c > lim.min {
lim.min = c
}
case gt:
if c+1 > lim.min {
lim.min = c + 1
}
case lt | gt:
if c == lim.min {
lim.min++
}
if c == lim.max {
lim.max--
}
case eq:
lim.min = c
lim.max = c
}
case unsigned:
var uc uint64
switch w.Op {
case OpConst64:
uc = uint64(c)
case OpConst32:
uc = uint64(uint32(c))
case OpConst16:
uc = uint64(uint16(c))
case OpConst8:
uc = uint64(uint8(c))
}
switch r {
case lt:
if uc-1 < lim.umax {
lim.umax = uc - 1
}
case lt | eq:
if uc < lim.umax {
lim.umax = uc
}
case gt | eq:
if uc > lim.umin {
lim.umin = uc
}
case gt:
if uc+1 > lim.umin {
lim.umin = uc + 1
}
case lt | gt:
if uc == lim.umin {
lim.umin++
}
if uc == lim.umax {
lim.umax--
}
case eq:
lim.umin = uc
lim.umax = uc
}
}
ft.limitStack = append(ft.limitStack, limitFact{v.ID, old})
ft.limits[v.ID] = lim
}
}
// checkpoint saves the current state of known relations.
// Called when descending on a branch.
func (ft *factsTable) checkpoint() {
ft.stack = append(ft.stack, checkpointFact)
ft.limitStack = append(ft.limitStack, checkpointBound)
}
// restore restores known relation to the state just
......@@ -144,6 +319,18 @@ func (ft *factsTable) restore() {
ft.facts[old.p] = old.r
}
}
for {
old := ft.limitStack[len(ft.limitStack)-1]
ft.limitStack = ft.limitStack[:len(ft.limitStack)-1]
if old.vid == 0 { // checkpointBound
break
}
if old.limit == noLimit {
delete(ft.limits, old.vid)
} else {
ft.limits[old.vid] = old.limit
}
}
}
func lessByID(v, w *Value) bool {
......@@ -421,6 +608,7 @@ func simplifyBlock(ft *factsTable, b *Block) branch {
// to the upper bound than this is proven. Most useful in cases such as:
// if len(a) <= 1 { return }
// do something with a[1]
// TODO: use constant bounds to do isNonNegative.
if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && isNonNegative(c.Args[0]) {
m := ft.get(a0, a1, signed)
if m != 0 && tr.r&m == m {
......
......@@ -218,6 +218,11 @@ func (v *Value) Unimplementedf(msg string, args ...interface{}) {
v.Block.Func.Config.Unimplementedf(v.Line, msg, args...)
}
// isGenericIntConst returns whether v is a generic integer constant.
func (v *Value) isGenericIntConst() bool {
return v != nil && (v.Op == OpConst64 || v.Op == OpConst32 || v.Op == OpConst16 || v.Op == OpConst8)
}
// ExternSymbol is an aux value that encodes a variable's
// constant offset from the static base pointer.
type ExternSymbol struct {
......
......@@ -49,23 +49,23 @@ var BigEndian bigEndian
type littleEndian struct{}
func (littleEndian) Uint16(b []byte) uint16 {
b = b[:2:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[0]) | uint16(b[1])<<8
}
func (littleEndian) PutUint16(b []byte, v uint16) {
b = b[:2:len(b)] // early bounds check to guarantee safety of writes below
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
}
func (littleEndian) Uint32(b []byte) uint32 {
b = b[:4:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func (littleEndian) PutUint32(b []byte, v uint32) {
b = b[:4:len(b)] // early bounds check to guarantee safety of writes below
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
......@@ -73,13 +73,13 @@ func (littleEndian) PutUint32(b []byte, v uint32) {
}
func (littleEndian) Uint64(b []byte) uint64 {
b = b[:8:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
func (littleEndian) PutUint64(b []byte, v uint64) {
b = b[:8:len(b)] // early bounds check to guarantee safety of writes below
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
......@@ -97,23 +97,23 @@ func (littleEndian) GoString() string { return "binary.LittleEndian" }
type bigEndian struct{}
func (bigEndian) Uint16(b []byte) uint16 {
b = b[:2:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[1]) | uint16(b[0])<<8
}
func (bigEndian) PutUint16(b []byte, v uint16) {
b = b[:2:len(b)] // early bounds check to guarantee safety of writes below
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 8)
b[1] = byte(v)
}
func (bigEndian) Uint32(b []byte) uint32 {
b = b[:4:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}
func (bigEndian) PutUint32(b []byte, v uint32) {
b = b[:4:len(b)] // early bounds check to guarantee safety of writes below
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
......@@ -121,13 +121,13 @@ func (bigEndian) PutUint32(b []byte, v uint32) {
}
func (bigEndian) Uint64(b []byte) uint64 {
b = b[:8:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}
func (bigEndian) PutUint64(b []byte, v uint64) {
b = b[:8:len(b)] // early bounds check to guarantee safety of writes below
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
......
......@@ -3,12 +3,14 @@
package main
import "math"
func f0(a []int) int {
a[0] = 1
a[0] = 1 // ERROR "Proved boolean IsInBounds$"
a[6] = 1
a[6] = 1 // ERROR "Proved boolean IsInBounds$"
a[5] = 1
a[5] = 1 // ERROR "Proved IsInBounds$"
a[5] = 1 // ERROR "Proved boolean IsInBounds$"
return 13
}
......@@ -17,15 +19,25 @@ func f1(a []int) int {
if len(a) <= 5 {
return 18
}
a[0] = 1
a[0] = 1 // ERROR "Proved non-negative bounds IsInBounds$"
a[0] = 1 // ERROR "Proved boolean IsInBounds$"
a[6] = 1
a[6] = 1 // ERROR "Proved boolean IsInBounds$"
a[5] = 1 // ERROR "Proved non-negative bounds IsInBounds$"
a[5] = 1 // ERROR "Proved IsInBounds$"
a[5] = 1 // ERROR "Proved boolean IsInBounds$"
return 26
}
func f1b(a []int, i int, j uint) int {
if i >= 0 && i < len(a) { // TODO: handle this case
return a[i]
}
if j < uint(len(a)) {
return a[j] // ERROR "Proved IsInBounds"
}
return 0
}
func f2(a []int) int {
for i := range a {
a[i] = i
......@@ -245,6 +257,164 @@ func f12(a []int, b int) {
useSlice(a[:b])
}
func f13a(a, b, c int, x bool) int {
if a > 12 {
if x {
if a < 12 { // ERROR "Disproved Less64$"
return 1
}
}
if x {
if a <= 12 { // ERROR "Disproved Leq64$"
return 2
}
}
if x {
if a == 12 { // ERROR "Disproved Eq64$"
return 3
}
}
if x {
if a >= 12 { // ERROR "Proved Geq64$"
return 4
}
}
if x {
if a > 12 { // ERROR "Proved boolean Greater64$"
return 5
}
}
return 6
}
return 0
}
func f13b(a int, x bool) int {
if a == -9 {
if x {
if a < -9 { // ERROR "Disproved Less64$"
return 7
}
}
if x {
if a <= -9 { // ERROR "Proved Leq64$"
return 8
}
}
if x {
if a == -9 { // ERROR "Proved boolean Eq64$"
return 9
}
}
if x {
if a >= -9 { // ERROR "Proved Geq64$"
return 10
}
}
if x {
if a > -9 { // ERROR "Disproved Greater64$"
return 11
}
}
return 12
}
return 0
}
func f13c(a int, x bool) int {
if a < 90 {
if x {
if a < 90 { // ERROR "Proved boolean Less64$"
return 13
}
}
if x {
if a <= 90 { // ERROR "Proved Leq64$"
return 14
}
}
if x {
if a == 90 { // ERROR "Disproved Eq64$"
return 15
}
}
if x {
if a >= 90 { // ERROR "Disproved Geq64$"
return 16
}
}
if x {
if a > 90 { // ERROR "Disproved Greater64$"
return 17
}
}
return 18
}
return 0
}
func f13d(a int) int {
if a < 5 {
if a < 9 { // ERROR "Proved Less64$"
return 1
}
}
return 0
}
func f13e(a int) int {
if a > 9 {
if a > 5 { // ERROR "Proved Greater64$"
return 1
}
}
return 0
}
func f13f(a int64) int64 {
if a > math.MaxInt64 {
// Unreachable, but prove doesn't know that.
if a == 0 {
return 1
}
}
return 0
}
func f13g(a int) int {
if a < 3 {
return 5
}
if a > 3 {
return 6
}
if a == 3 { // ERROR "Proved Eq64$"
return 7
}
return 8
}
func f13h(a int) int {
if a < 3 {
if a > 1 {
if a == 2 { // ERROR "Proved Eq64$"
return 5
}
}
}
return 0
}
func f13i(a uint) int {
if a == 0 {
return 1
}
if a > 0 { // ERROR "Proved Greater64U$"
return 2
}
return 3
}
//go:noinline
func useInt(a int) {
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment