Commit c3a871fd authored by Giovanni Bajo's avatar Giovanni Bajo

cmd/compile: in poset, make constant handling more flexible

Currently, constants in posets, in addition to being stored in
a DAG, are also stored as SSA values in a slice. This allows to
quickly go through all stored constants, but it's not easy to search
for a specific constant.

Following CLs will benefit from being able to quickly find
a constants by value in the poset, so change the constants
structure to a map. Since we're at it, don't store it as
*ssa.Value: poset always uses dense uint32 indices when
referring a node, so just switch to it.

Using a map also forces us to have a single node per
constant value: this is a good thing in the first place,
so this CL also make sure we never create two nodes for
the same constant value.

Change-Id: I099814578af35f935ebf14bc4767d607021f5f8b
Reviewed-on: https://go-review.googlesource.com/c/go/+/196781
Run-TryBot: Giovanni Bajo <rasky@develer.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarDavid Chase <drchase@google.com>
parent 50f11577
...@@ -42,16 +42,17 @@ func (bs bitset) Test(idx uint32) bool { ...@@ -42,16 +42,17 @@ func (bs bitset) Test(idx uint32) bool {
type undoType uint8 type undoType uint8
const ( const (
undoInvalid undoType = iota undoInvalid undoType = iota
undoCheckpoint // a checkpoint to group undo passes undoCheckpoint // a checkpoint to group undo passes
undoSetChl // change back left child of undo.idx to undo.edge undoSetChl // change back left child of undo.idx to undo.edge
undoSetChr // change back right child of undo.idx to undo.edge undoSetChr // change back right child of undo.idx to undo.edge
undoNonEqual // forget that SSA value undo.ID is non-equal to undo.idx (another ID) undoNonEqual // forget that SSA value undo.ID is non-equal to undo.idx (another ID)
undoNewNode // remove new node created for SSA value undo.ID undoNewNode // remove new node created for SSA value undo.ID
undoAliasNode // unalias SSA value undo.ID so that it points back to node index undo.idx undoNewConstant // remove the constant node idx from the constants map
undoNewRoot // remove node undo.idx from root list undoAliasNode // unalias SSA value undo.ID so that it points back to node index undo.idx
undoChangeRoot // remove node undo.idx from root list, and put back undo.edge.Target instead undoNewRoot // remove node undo.idx from root list
undoMergeRoot // remove node undo.idx from root list, and put back its children instead undoChangeRoot // remove node undo.idx from root list, and put back undo.edge.Target instead
undoMergeRoot // remove node undo.idx from root list, and put back its children instead
) )
// posetUndo represents an undo pass to be performed. // posetUndo represents an undo pass to be performed.
...@@ -146,20 +147,20 @@ type posetNode struct { ...@@ -146,20 +147,20 @@ type posetNode struct {
// J K // J K
// //
type poset struct { type poset struct {
lastidx uint32 // last generated dense index lastidx uint32 // last generated dense index
flags uint8 // internal flags flags uint8 // internal flags
values map[ID]uint32 // map SSA values to dense indexes values map[ID]uint32 // map SSA values to dense indexes
constants []*Value // record SSA constants together with their value constants map[int64]uint32 // record SSA constants together with their value
nodes []posetNode // nodes (in all DAGs) nodes []posetNode // nodes (in all DAGs)
roots []uint32 // list of root nodes (forest) roots []uint32 // list of root nodes (forest)
noneq map[ID]bitset // non-equal relations noneq map[ID]bitset // non-equal relations
undo []posetUndo // undo chain undo []posetUndo // undo chain
} }
func newPoset() *poset { func newPoset() *poset {
return &poset{ return &poset{
values: make(map[ID]uint32), values: make(map[ID]uint32),
constants: make([]*Value, 0, 8), constants: make(map[int64]uint32, 8),
nodes: make([]posetNode, 1, 16), nodes: make([]posetNode, 1, 16),
roots: make([]uint32, 0, 4), roots: make([]uint32, 0, 4),
noneq: make(map[ID]bitset), noneq: make(map[ID]bitset),
...@@ -205,6 +206,11 @@ func (po *poset) upushalias(id ID, i2 uint32) { ...@@ -205,6 +206,11 @@ func (po *poset) upushalias(id ID, i2 uint32) {
po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2}) po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
} }
// upushconst pushes a new undo pass for a new constant
func (po *poset) upushconst(idx uint32, old uint32) {
po.undo = append(po.undo, posetUndo{typ: undoNewConstant, idx: idx, ID: ID(old)})
}
// addchild adds i2 as direct child of i1. // addchild adds i2 as direct child of i1.
func (po *poset) addchild(i1, i2 uint32, strict bool) { func (po *poset) addchild(i1, i2 uint32, strict bool) {
i1l, i1r := po.children(i1) i1l, i1r := po.children(i1)
...@@ -281,18 +287,33 @@ func (po *poset) newconst(n *Value) { ...@@ -281,18 +287,33 @@ func (po *poset) newconst(n *Value) {
panic("newconst on non-constant") panic("newconst on non-constant")
} }
// If this is the first constant, put it into a new root, as // If the same constant is already present in the poset through a different
// Value, just alias to it without allocating a new node.
val := n.AuxInt
if po.flags&posetFlagUnsigned != 0 {
val = int64(n.AuxUnsigned())
}
if c, found := po.constants[val]; found {
po.values[n.ID] = c
po.upushalias(n.ID, 0)
return
}
// Create the new node for this constant
i := po.newnode(n)
// If this is the first constant, put it as a new root, as
// we can't record an existing connection so we don't have // we can't record an existing connection so we don't have
// a specific DAG to add it to. Notice that we want all // a specific DAG to add it to. Notice that we want all
// constants to be in root #0, so make sure the new root // constants to be in root #0, so make sure the new root
// goes there. // goes there.
if len(po.constants) == 0 { if len(po.constants) == 0 {
idx := len(po.roots) idx := len(po.roots)
i := po.newnode(n)
po.roots = append(po.roots, i) po.roots = append(po.roots, i)
po.roots[0], po.roots[idx] = po.roots[idx], po.roots[0] po.roots[0], po.roots[idx] = po.roots[idx], po.roots[0]
po.upush(undoNewRoot, i, 0) po.upush(undoNewRoot, i, 0)
po.constants = append(po.constants, n) po.constants[val] = i
po.upushconst(i, 0)
return return
} }
...@@ -301,21 +322,20 @@ func (po *poset) newconst(n *Value) { ...@@ -301,21 +322,20 @@ func (po *poset) newconst(n *Value) {
// and the lower constant that is higher. // and the lower constant that is higher.
// The loop is duplicated to handle signed and unsigned comparison, // The loop is duplicated to handle signed and unsigned comparison,
// depending on how the poset was configured. // depending on how the poset was configured.
var lowerptr, higherptr *Value var lowerptr, higherptr uint32
if po.flags&posetFlagUnsigned != 0 { if po.flags&posetFlagUnsigned != 0 {
var lower, higher uint64 var lower, higher uint64
val1 := n.AuxUnsigned() val1 := n.AuxUnsigned()
for _, ptr := range po.constants { for val2, ptr := range po.constants {
val2 := ptr.AuxUnsigned() val2 := uint64(val2)
if val1 == val2 { if val1 == val2 {
po.aliasnode(ptr, n) panic("unreachable")
return
} }
if val2 < val1 && (lowerptr == nil || val2 > lower) { if val2 < val1 && (lowerptr == 0 || val2 > lower) {
lower = val2 lower = val2
lowerptr = ptr lowerptr = ptr
} else if val2 > val1 && (higherptr == nil || val2 < higher) { } else if val2 > val1 && (higherptr == 0 || val2 < higher) {
higher = val2 higher = val2
higherptr = ptr higherptr = ptr
} }
...@@ -323,23 +343,21 @@ func (po *poset) newconst(n *Value) { ...@@ -323,23 +343,21 @@ func (po *poset) newconst(n *Value) {
} else { } else {
var lower, higher int64 var lower, higher int64
val1 := n.AuxInt val1 := n.AuxInt
for _, ptr := range po.constants { for val2, ptr := range po.constants {
val2 := ptr.AuxInt
if val1 == val2 { if val1 == val2 {
po.aliasnode(ptr, n) panic("unreachable")
return
} }
if val2 < val1 && (lowerptr == nil || val2 > lower) { if val2 < val1 && (lowerptr == 0 || val2 > lower) {
lower = val2 lower = val2
lowerptr = ptr lowerptr = ptr
} else if val2 > val1 && (higherptr == nil || val2 < higher) { } else if val2 > val1 && (higherptr == 0 || val2 < higher) {
higher = val2 higher = val2
higherptr = ptr higherptr = ptr
} }
} }
} }
if lowerptr == nil && higherptr == nil { if lowerptr == 0 && higherptr == 0 {
// This should not happen, as at least one // This should not happen, as at least one
// other constant must exist if we get here. // other constant must exist if we get here.
panic("no constant found") panic("no constant found")
...@@ -350,18 +368,17 @@ func (po *poset) newconst(n *Value) { ...@@ -350,18 +368,17 @@ func (po *poset) newconst(n *Value) {
// of them, depending on what other constants are present in the poset. // of them, depending on what other constants are present in the poset.
// Notice that we always link constants together, so they // Notice that we always link constants together, so they
// are always part of the same DAG. // are always part of the same DAG.
i := po.newnode(n)
switch { switch {
case lowerptr != nil && higherptr != nil: case lowerptr != 0 && higherptr != 0:
// Both bounds are present, record lower < n < higher. // Both bounds are present, record lower < n < higher.
po.addchild(po.values[lowerptr.ID], i, true) po.addchild(lowerptr, i, true)
po.addchild(i, po.values[higherptr.ID], true) po.addchild(i, higherptr, true)
case lowerptr != nil: case lowerptr != 0:
// Lower bound only, record lower < n. // Lower bound only, record lower < n.
po.addchild(po.values[lowerptr.ID], i, true) po.addchild(lowerptr, i, true)
case higherptr != nil: case higherptr != 0:
// Higher bound only. To record n < higher, we need // Higher bound only. To record n < higher, we need
// a dummy root: // a dummy root:
// //
...@@ -373,7 +390,7 @@ func (po *poset) newconst(n *Value) { ...@@ -373,7 +390,7 @@ func (po *poset) newconst(n *Value) {
// \ / // \ /
// higher // higher
// //
i2 := po.values[higherptr.ID] i2 := higherptr
r2 := po.findroot(i2) r2 := po.findroot(i2)
if r2 != po.roots[0] { // all constants should be in root #0 if r2 != po.roots[0] { // all constants should be in root #0
panic("constant not in root #0") panic("constant not in root #0")
...@@ -386,7 +403,8 @@ func (po *poset) newconst(n *Value) { ...@@ -386,7 +403,8 @@ func (po *poset) newconst(n *Value) {
po.addchild(i, i2, true) po.addchild(i, i2, true)
} }
po.constants = append(po.constants, n) po.constants[val] = i
po.upushconst(i, 0)
} }
// aliasnode records that n2 is an alias of n1 // aliasnode records that n2 is an alias of n1
...@@ -422,6 +440,19 @@ func (po *poset) aliasnode(n1, n2 *Value) { ...@@ -422,6 +440,19 @@ func (po *poset) aliasnode(n1, n2 *Value) {
po.upushalias(k, i2) po.upushalias(k, i2)
} }
} }
if n2.isGenericIntConst() {
val := n2.AuxInt
if po.flags&posetFlagUnsigned != 0 {
val = int64(n2.AuxUnsigned())
}
if po.constants[val] != i2 {
panic("aliasing constant which is not registered")
}
po.constants[val] = i1
po.upushconst(i1, i2)
}
} else { } else {
// n2.ID wasn't seen before, so record it as alias to i1 // n2.ID wasn't seen before, so record it as alias to i1
po.values[n2.ID] = i1 po.values[n2.ID] = i1
...@@ -631,11 +662,7 @@ func (po *poset) CheckIntegrity() { ...@@ -631,11 +662,7 @@ func (po *poset) CheckIntegrity() {
// Record which index is a constant // Record which index is a constant
constants := newBitset(int(po.lastidx + 1)) constants := newBitset(int(po.lastidx + 1))
for _, c := range po.constants { for _, c := range po.constants {
if idx, ok := po.values[c.ID]; !ok { constants.Set(c)
panic("node missing for constant")
} else {
constants.Set(idx)
}
} }
// Verify that each node appears in a single DAG, and that // Verify that each node appears in a single DAG, and that
...@@ -732,15 +759,10 @@ func (po *poset) DotDump(fn string, title string) error { ...@@ -732,15 +759,10 @@ func (po *poset) DotDump(fn string, title string) error {
names[i] = s names[i] = s
} }
// Create constant mapping // Create reverse constant mapping
consts := make(map[uint32]int64) consts := make(map[uint32]int64)
for _, v := range po.constants { for val, idx := range po.constants {
idx := po.values[v.ID] consts[idx] = val
if po.flags&posetFlagUnsigned != 0 {
consts[idx] = int64(v.AuxUnsigned())
} else {
consts[idx] = v.AuxInt
}
} }
fmt.Fprintf(f, "digraph poset {\n") fmt.Fprintf(f, "digraph poset {\n")
...@@ -1171,10 +1193,25 @@ func (po *poset) Undo() { ...@@ -1171,10 +1193,25 @@ func (po *poset) Undo() {
po.nodes = po.nodes[:pass.idx] po.nodes = po.nodes[:pass.idx]
po.lastidx-- po.lastidx--
// If it was the last inserted constant, remove it case undoNewConstant:
nc := len(po.constants) // FIXME: remove this O(n) loop
if nc > 0 && po.constants[nc-1].ID == pass.ID { var val int64
po.constants = po.constants[:nc-1] var i uint32
for val, i = range po.constants {
if i == pass.idx {
break
}
}
if i != pass.idx {
panic("constant not found in undo pass")
}
if pass.ID == 0 {
delete(po.constants, val)
} else {
// Restore previous index as constant node
// (also restoring the invariant on correct bounds)
oldidx := uint32(pass.ID)
po.constants[val] = oldidx
} }
case undoAliasNode: case undoAliasNode:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment